Skip to content

Commit e274128

Browse files
authored
TabPFN Example and Alignment with Paper (#299)
* adds tabpfn example * updated average_prediction * updated average_prediction * updated notebook * updated notebook * updated notebook * added data * removes warning when class_index is not provided * adds TODO for xgboost tests * allows games to be initialized from values and be not normalized * updated tabpfn notebook * ran tree notebooks * adds lgbm to tabular notebooks * renames game_fun to game in ExactComputer and closes #297 * ran and updated notebooks * ran TabPFN notebook
1 parent 99e0d77 commit e274128

File tree

20 files changed

+1746
-450
lines changed

20 files changed

+1746
-450
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
## Changelog
22

33
### Development
4+
- renames the `game_fun` parameter in `shapiq.ExactComputer` to `game` [#297](https://github.com/mmschlk/shapiq/issues/297)
5+
- adds a TabPFN example notebook to the documentation
6+
- removes the warning raised when `class_index` is not provided in explainers [#298](https://github.com/mmschlk/shapiq/issues/298)
47
- adds the `sentence_plot` function to the `plot` module to visualize the contributions of words to a language model prediction in a sentence-like format
58
- makes abbreviations in the `plot` module optional [#281](https://github.com/mmschlk/shapiq/issues/281)
69
- adds the `upset_plot` function to the `plot` module to visualize higher-order interactions [#290](https://github.com/mmschlk/shapiq/issues/290)

docs/source/notebooks/basics_notebooks/custom_games.ipynb

Lines changed: 66 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,23 @@
1212
{
1313
"metadata": {
1414
"ExecuteTime": {
15-
"end_time": "2024-12-17T14:23:19.696179Z",
16-
"start_time": "2024-12-17T14:23:18.268301Z"
15+
"end_time": "2025-01-10T12:14:05.982266Z",
16+
"start_time": "2025-01-10T12:14:04.426262Z"
1717
}
1818
},
1919
"cell_type": "code",
2020
"source": [
2121
"import shapiq\n",
2222
"import numpy as np\n",
23+
"import os\n",
2324
"\n",
2425
"shapiq.__version__"
2526
],
2627
"outputs": [
2728
{
2829
"data": {
2930
"text/plain": [
30-
"'1.1.1'"
31+
"'1.1.1.dev'"
3132
]
3233
},
3334
"execution_count": 1,
@@ -86,8 +87,8 @@
8687
{
8788
"metadata": {
8889
"ExecuteTime": {
89-
"end_time": "2024-12-17T14:23:19.711170Z",
90-
"start_time": "2024-12-17T14:23:19.698170Z"
90+
"end_time": "2025-01-10T12:14:05.997215Z",
91+
"start_time": "2025-01-10T12:14:05.985240Z"
9192
}
9293
},
9394
"cell_type": "code",
@@ -147,8 +148,8 @@
147148
{
148149
"metadata": {
149150
"ExecuteTime": {
150-
"end_time": "2024-12-17T14:23:19.727173Z",
151-
"start_time": "2024-12-17T14:23:19.713181Z"
151+
"end_time": "2025-01-10T12:14:06.013212Z",
152+
"start_time": "2025-01-10T12:14:06.000205Z"
152153
}
153154
},
154155
"cell_type": "code",
@@ -174,8 +175,8 @@
174175
{
175176
"metadata": {
176177
"ExecuteTime": {
177-
"end_time": "2024-12-17T14:23:19.742179Z",
178-
"start_time": "2024-12-17T14:23:19.730173Z"
178+
"end_time": "2025-01-10T12:14:06.029218Z",
179+
"start_time": "2025-01-10T12:14:06.014204Z"
179180
}
180181
},
181182
"cell_type": "code",
@@ -207,8 +208,8 @@
207208
{
208209
"metadata": {
209210
"ExecuteTime": {
210-
"end_time": "2024-12-17T14:23:19.758170Z",
211-
"start_time": "2024-12-17T14:23:19.745174Z"
211+
"end_time": "2025-01-10T12:14:06.045214Z",
212+
"start_time": "2025-01-10T12:14:06.033206Z"
212213
}
213214
},
214215
"cell_type": "code",
@@ -245,8 +246,8 @@
245246
{
246247
"metadata": {
247248
"ExecuteTime": {
248-
"end_time": "2024-12-17T14:23:19.789577Z",
249-
"start_time": "2024-12-17T14:23:19.760172Z"
249+
"end_time": "2025-01-10T12:14:06.061209Z",
250+
"start_time": "2025-01-10T12:14:06.046217Z"
250251
}
251252
},
252253
"cell_type": "code",
@@ -280,7 +281,7 @@
280281
"application/vnd.jupyter.widget-view+json": {
281282
"version_major": 2,
282283
"version_minor": 0,
283-
"model_id": "e6e0bc19180b4969bae2cbcabef70fdf"
284+
"model_id": "218e4aac6918408d8a38f1c9646509fb"
284285
}
285286
},
286287
"metadata": {},
@@ -308,40 +309,29 @@
308309
{
309310
"metadata": {
310311
"ExecuteTime": {
311-
"end_time": "2024-12-17T14:23:20.357939Z",
312-
"start_time": "2024-12-17T14:23:19.792499Z"
312+
"end_time": "2025-01-10T12:14:06.076763Z",
313+
"start_time": "2025-01-10T12:14:06.063214Z"
313314
}
314315
},
315316
"cell_type": "code",
316317
"source": [
317318
"# save the precomputed values to a file\n",
318-
"cooking_game.save_values(\"data/cooking_game_values.npz\")\n",
319+
"save_path = os.path.join(\"..\", \"data\", \"cooking_game_values.npz\")\n",
320+
"cooking_game.save_values(save_path)\n",
319321
"\n",
320322
"# load the precomputed values from the file\n",
321323
"empty_cooking_game = CookingGame()\n",
322324
"print(\"Values stored before loading: \", empty_cooking_game.value_storage)\n",
323-
"empty_cooking_game.load_values(\"cooking_game_values.npz\")\n",
325+
"empty_cooking_game.load_values(save_path)\n",
324326
"print(\"Values stored after loading: \", empty_cooking_game.value_storage)"
325327
],
326328
"outputs": [
327329
{
328330
"name": "stdout",
329331
"output_type": "stream",
330332
"text": [
331-
"Values stored before loading: []\n"
332-
]
333-
},
334-
{
335-
"ename": "FileNotFoundError",
336-
"evalue": "[Errno 2] No such file or directory: 'cooking_game_values.npz'",
337-
"output_type": "error",
338-
"traceback": [
339-
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
340-
"\u001B[1;31mFileNotFoundError\u001B[0m Traceback (most recent call last)",
341-
"Cell \u001B[1;32mIn[7], line 7\u001B[0m\n\u001B[0;32m 5\u001B[0m empty_cooking_game \u001B[38;5;241m=\u001B[39m CookingGame()\n\u001B[0;32m 6\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mValues stored before loading: \u001B[39m\u001B[38;5;124m\"\u001B[39m, empty_cooking_game\u001B[38;5;241m.\u001B[39mvalue_storage)\n\u001B[1;32m----> 7\u001B[0m \u001B[43mempty_cooking_game\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mload_values\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mcooking_game_values.npz\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[0;32m 8\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mValues stored after loading: \u001B[39m\u001B[38;5;124m\"\u001B[39m, empty_cooking_game\u001B[38;5;241m.\u001B[39mvalue_storage)\n",
342-
"File \u001B[1;32mC:\\1_Workspaces\\1_Phd_Projects\\shapiq\\shapiq\\games\\base.py:426\u001B[0m, in \u001B[0;36mGame.load_values\u001B[1;34m(self, path, precomputed)\u001B[0m\n\u001B[0;32m 423\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m path\u001B[38;5;241m.\u001B[39mendswith(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m.npz\u001B[39m\u001B[38;5;124m\"\u001B[39m):\n\u001B[0;32m 424\u001B[0m path \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m.npz\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m--> 426\u001B[0m data \u001B[38;5;241m=\u001B[39m \u001B[43mnp\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mload\u001B[49m\u001B[43m(\u001B[49m\u001B[43mpath\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 427\u001B[0m n_players \u001B[38;5;241m=\u001B[39m data[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mn_players\u001B[39m\u001B[38;5;124m\"\u001B[39m]\n\u001B[0;32m 428\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mn_players \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m \u001B[38;5;129;01mand\u001B[39;00m n_players \u001B[38;5;241m!=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mn_players:\n",
343-
"File \u001B[1;32mC:\\1_Workspaces\\1_Phd_Projects\\shapiq\\venv\\lib\\site-packages\\numpy\\lib\\npyio.py:427\u001B[0m, in \u001B[0;36mload\u001B[1;34m(file, mmap_mode, allow_pickle, fix_imports, encoding, max_header_size)\u001B[0m\n\u001B[0;32m 425\u001B[0m own_fid \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mFalse\u001B[39;00m\n\u001B[0;32m 426\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m--> 427\u001B[0m fid \u001B[38;5;241m=\u001B[39m stack\u001B[38;5;241m.\u001B[39menter_context(\u001B[38;5;28;43mopen\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mos_fspath\u001B[49m\u001B[43m(\u001B[49m\u001B[43mfile\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mrb\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m)\u001B[49m)\n\u001B[0;32m 428\u001B[0m own_fid \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mTrue\u001B[39;00m\n\u001B[0;32m 430\u001B[0m \u001B[38;5;66;03m# Code to distinguish from NumPy binary files and pickles.\u001B[39;00m\n",
344-
"\u001B[1;31mFileNotFoundError\u001B[0m: [Errno 2] No such file or directory: 'cooking_game_values.npz'"
333+
"Values stored before loading: []\n",
334+
"Values stored after loading: [ 0. 4. 3. 2. 9. 8. 7. 15.]\n"
345335
]
346336
}
347337
],
@@ -356,19 +346,42 @@
356346
]
357347
},
358348
{
359-
"metadata": {},
349+
"metadata": {
350+
"ExecuteTime": {
351+
"end_time": "2025-01-10T12:14:06.092763Z",
352+
"start_time": "2025-01-10T12:14:06.077767Z"
353+
}
354+
},
360355
"cell_type": "code",
361356
"source": [
362357
"# initialize a game object directly from precomputed values\n",
363-
"game = shapiq.Game(path_to_values=\"data/cooking_game_values.npz\")\n",
358+
"game = shapiq.Game(path_to_values=save_path)\n",
364359
"print(game)\n",
365360
"\n",
366361
"# query the value function of the game for the same coalitions as before\n",
367362
"coals = np.array([[0, 0, 0], [1, 1, 0], [1, 0, 1], [0, 1, 1], [1, 1, 1]])\n",
368363
"game(coals)"
369364
],
370-
"outputs": [],
371-
"execution_count": null
365+
"outputs": [
366+
{
367+
"name": "stdout",
368+
"output_type": "stream",
369+
"text": [
370+
"Game(3 players, normalize=False, normalization_value=0.0, precomputed=True)\n"
371+
]
372+
},
373+
{
374+
"data": {
375+
"text/plain": [
376+
"array([ 0., 9., 8., 7., 15.])"
377+
]
378+
},
379+
"execution_count": 8,
380+
"metadata": {},
381+
"output_type": "execute_result"
382+
}
383+
],
384+
"execution_count": 8
372385
},
373386
{
374387
"metadata": {},
@@ -379,7 +392,12 @@
379392
]
380393
},
381394
{
382-
"metadata": {},
395+
"metadata": {
396+
"ExecuteTime": {
397+
"end_time": "2025-01-10T12:14:06.108755Z",
398+
"start_time": "2025-01-10T12:14:06.095753Z"
399+
}
400+
},
383401
"cell_type": "code",
384402
"source": [
385403
"print(cooking_game.characteristic_function)\n",
@@ -388,8 +406,17 @@
388406
"except AttributeError as e:\n",
389407
" print(\"AttributeError:\", e)"
390408
],
391-
"outputs": [],
392-
"execution_count": null
409+
"outputs": [
410+
{
411+
"name": "stdout",
412+
"output_type": "stream",
413+
"text": [
414+
"{(): 0, (0,): 4, (1,): 3, (2,): 2, (0, 1): 9, (0, 2): 8, (1, 2): 7, (0, 1, 2): 15}\n",
415+
"AttributeError: 'Game' object has no attribute 'characteristic_function'\n"
416+
]
417+
}
418+
],
419+
"execution_count": 9
393420
}
394421
],
395422
"metadata": {

docs/source/notebooks/basics_notebooks/data_valuation.ipynb

Lines changed: 184 additions & 168 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)