|
12 | 12 | }, |
13 | 13 | { |
14 | 14 | "cell_type": "code", |
15 | | - "execution_count": 17, |
| 15 | + "execution_count": 2, |
16 | 16 | "metadata": {}, |
17 | 17 | "outputs": [], |
18 | 18 | "source": [ |
|
25 | 25 | "import matplotlib\n", |
26 | 26 | "import matplotlib.cm\n", |
27 | 27 | "\n", |
28 | | - "from embedders.predictors.tree_new import _angular_greater\n", |
| 28 | + "from manify.predictors.decision_tree import _angular_greater\n", |
29 | 29 | "\n", |
30 | 30 | "######################\n", |
31 | 31 | "# HELPER FUNCTIONS\n", |
|
316 | 316 | " color = class2color[cl] # consistent across the entire tree\n", |
317 | 317 | " idx = ymask == cl\n", |
318 | 318 | " fig.add_trace(\n", |
319 | | - " # go.Scatter(x=x_in[idx], y=y_in[idx], mode=\"markers\", marker=dict(color=color), name=f\"Class {cl}\")\n", |
| 319 | + " # go.Scatter(x=x_in[idx], y=y_in[idx], mode=\"markers\", marker=dict(color=color), name=f\"Class {cl}\")\n", |
320 | 320 | " go.Scatter(x=y_in[idx], y=x_in[idx], mode=\"markers\", marker=dict(color=color), name=f\"Class {cl}\")\n", |
321 | 321 | " )\n", |
322 | | - " \n", |
323 | 322 | "\n", |
324 | 323 | " # d) Finally set the axis range to the bounding box of the data (so we don't zoom out).\n", |
325 | 324 | " x_min, x_max = x_all.min(), x_all.max()\n", |
|
343 | 342 | "######################\n", |
344 | 343 | "\n", |
345 | 344 | "\n", |
346 | | - "\n", |
347 | 345 | "def create_dashboard(pdt, X, y):\n", |
348 | 346 | " node2id, edges = build_edges_and_id_map(pdt)\n", |
349 | 347 | " node_masks = compute_node_masks(pdt, X, y, node2id)\n", |
|
403 | 401 | " fig = go.Figure()\n", |
404 | 402 | " y_in_leaf = y[mask]\n", |
405 | 403 | " if len(y_in_leaf) == 0:\n", |
406 | | - " fig.add_annotation(\n", |
407 | | - " x=0.5, y=0.5, xref=\"paper\", yref=\"paper\", text=\"Empty Leaf Node\", showarrow=False\n", |
408 | | - " )\n", |
| 404 | + " fig.add_annotation(x=0.5, y=0.5, xref=\"paper\", yref=\"paper\", text=\"Empty Leaf Node\", showarrow=False)\n", |
409 | 405 | " else:\n", |
410 | 406 | " counts = Counter(y_in_leaf.tolist())\n", |
411 | 407 | " labels = list(counts.keys())\n", |
|
431 | 427 | }, |
432 | 428 | { |
433 | 429 | "cell_type": "code", |
434 | | - "execution_count": 21, |
| 430 | + "execution_count": null, |
435 | 431 | "metadata": {}, |
436 | 432 | "outputs": [ |
437 | 433 | { |
438 | 434 | "name": "stdout", |
439 | 435 | "output_type": "stream", |
440 | 436 | "text": [ |
441 | | - "0.6250\n" |
| 437 | + "0.7050\n" |
442 | 438 | ] |
443 | 439 | }, |
444 | 440 | { |
445 | 441 | "name": "stderr", |
446 | 442 | "output_type": "stream", |
447 | 443 | "text": [ |
448 | | - "/var/folders/ck/0ybgtq694jnd4mbjw_0rm6dh0000gp/T/ipykernel_21314/3256736440.py:153: MatplotlibDeprecationWarning:\n", |
449 | | - "\n", |
450 | | - "The get_cmap function was deprecated in Matplotlib 3.7 and will be removed two minor releases later. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap(obj)`` instead.\n", |
451 | | - "\n" |
| 444 | + "/tmp/ipykernel_89533/2928454690.py:153: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed two minor releases later. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap(obj)`` instead.\n", |
| 445 | + " tab10 = matplotlib.cm.get_cmap(\"tab10\")\n" |
452 | 446 | ] |
453 | 447 | }, |
454 | 448 | { |
|
466 | 460 | " " |
467 | 461 | ], |
468 | 462 | "text/plain": [ |
469 | | - "<IPython.lib.display.IFrame at 0x36881dd80>" |
| 463 | + "<IPython.lib.display.IFrame at 0x757bbfd23670>" |
470 | 464 | ] |
471 | 465 | }, |
472 | 466 | "metadata": {}, |
473 | 467 | "output_type": "display_data" |
474 | 468 | } |
475 | 469 | ], |
476 | 470 | "source": [ |
477 | | - "import embedders\n", |
| 471 | + "import manify\n", |
478 | 472 | "from sklearn.model_selection import train_test_split\n", |
479 | 473 | "\n", |
480 | | - "pm = embedders.manifolds.ProductManifold(signature=[(1, 2)])\n", |
481 | | - "X, y = embedders.gaussian_mixture.gaussian_mixture(pm, 1000, num_classes=4)\n", |
| 474 | + "pm = manify.ProductManifold(signature=[(1, 2)])\n", |
| 475 | + "X, y = pm.gaussian_mixture(1000, num_classes=4)\n", |
482 | 476 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)\n", |
483 | 477 | "\n", |
484 | | - "pdt = embedders.predictors.tree_new.ProductSpaceDT(pm=pm, n_features=\"d_choose_2\", max_depth=5)\n", |
| 478 | + "pdt = manify.ProductSpaceDT(pm=pm, n_features=\"d_choose_2\", max_depth=5)\n", |
485 | 479 | "pdt.fit(X_train, y_train)\n", |
486 | 480 | "\n", |
487 | 481 | "print(f\"{pdt.score(X_test, y_test).float().mean().item():.4f}\")\n", |
488 | 482 | "\n", |
489 | | - "create_dashboard(pdt, X_test, y_test).run_server(debug=True)" |
| 483 | + "create_dashboard(pdt, X_test, y_test).run(debug=True)" |
490 | 484 | ] |
491 | 485 | }, |
492 | 486 | { |
|
499 | 493 | ], |
500 | 494 | "metadata": { |
501 | 495 | "kernelspec": { |
502 | | - "display_name": "base", |
| 496 | + "display_name": "manify", |
503 | 497 | "language": "python", |
504 | 498 | "name": "python3" |
505 | 499 | }, |
|
513 | 507 | "name": "python", |
514 | 508 | "nbconvert_exporter": "python", |
515 | 509 | "pygments_lexer": "ipython3", |
516 | | - "version": "3.10.14" |
| 510 | + "version": "3.10.0" |
517 | 511 | } |
518 | 512 | }, |
519 | 513 | "nbformat": 4, |
|
0 commit comments