Skip to content

Commit effb855

Browse files
committed
ICML revisions
1 parent 60cb605 commit effb855

26 files changed

+1468
-678
lines changed

manify/predictors/kappa_gcn.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,11 @@ def fit(
417417
my_tqdm.update(1)
418418
my_tqdm.set_description(f"Epoch {i+1}/{epochs}, Loss: {loss.item():.4f}")
419419

420+
# Early termination for nan loss
421+
if torch.isnan(loss):
422+
print("Loss is NaN, stopping training.")
423+
break
424+
420425
if use_tqdm:
421426
my_tqdm.close()
422427

notebooks/26_benchmark_single_curvature_gaussian.ipynb

Lines changed: 44 additions & 160 deletions
Large diffs are not rendered by default.

notebooks/42_pc_basic_visualization.ipynb

Lines changed: 122 additions & 451 deletions
Large diffs are not rendered by default.

notebooks/49_plotly_dashboard.ipynb

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
},
1313
{
1414
"cell_type": "code",
15-
"execution_count": 17,
15+
"execution_count": 2,
1616
"metadata": {},
1717
"outputs": [],
1818
"source": [
@@ -25,7 +25,7 @@
2525
"import matplotlib\n",
2626
"import matplotlib.cm\n",
2727
"\n",
28-
"from embedders.predictors.tree_new import _angular_greater\n",
28+
"from manify.predictors.decision_tree import _angular_greater\n",
2929
"\n",
3030
"######################\n",
3131
"# HELPER FUNCTIONS\n",
@@ -316,10 +316,9 @@
316316
" color = class2color[cl] # consistent across the entire tree\n",
317317
" idx = ymask == cl\n",
318318
" fig.add_trace(\n",
319-
" # go.Scatter(x=x_in[idx], y=y_in[idx], mode=\"markers\", marker=dict(color=color), name=f\"Class {cl}\")\n",
319+
" # go.Scatter(x=x_in[idx], y=y_in[idx], mode=\"markers\", marker=dict(color=color), name=f\"Class {cl}\")\n",
320320
" go.Scatter(x=y_in[idx], y=x_in[idx], mode=\"markers\", marker=dict(color=color), name=f\"Class {cl}\")\n",
321321
" )\n",
322-
" \n",
323322
"\n",
324323
" # d) Finally set the axis range to the bounding box of the data (so we don't zoom out).\n",
325324
" x_min, x_max = x_all.min(), x_all.max()\n",
@@ -343,7 +342,6 @@
343342
"######################\n",
344343
"\n",
345344
"\n",
346-
"\n",
347345
"def create_dashboard(pdt, X, y):\n",
348346
" node2id, edges = build_edges_and_id_map(pdt)\n",
349347
" node_masks = compute_node_masks(pdt, X, y, node2id)\n",
@@ -403,9 +401,7 @@
403401
" fig = go.Figure()\n",
404402
" y_in_leaf = y[mask]\n",
405403
" if len(y_in_leaf) == 0:\n",
406-
" fig.add_annotation(\n",
407-
" x=0.5, y=0.5, xref=\"paper\", yref=\"paper\", text=\"Empty Leaf Node\", showarrow=False\n",
408-
" )\n",
404+
" fig.add_annotation(x=0.5, y=0.5, xref=\"paper\", yref=\"paper\", text=\"Empty Leaf Node\", showarrow=False)\n",
409405
" else:\n",
410406
" counts = Counter(y_in_leaf.tolist())\n",
411407
" labels = list(counts.keys())\n",
@@ -431,24 +427,22 @@
431427
},
432428
{
433429
"cell_type": "code",
434-
"execution_count": 21,
430+
"execution_count": null,
435431
"metadata": {},
436432
"outputs": [
437433
{
438434
"name": "stdout",
439435
"output_type": "stream",
440436
"text": [
441-
"0.6250\n"
437+
"0.7050\n"
442438
]
443439
},
444440
{
445441
"name": "stderr",
446442
"output_type": "stream",
447443
"text": [
448-
"/var/folders/ck/0ybgtq694jnd4mbjw_0rm6dh0000gp/T/ipykernel_21314/3256736440.py:153: MatplotlibDeprecationWarning:\n",
449-
"\n",
450-
"The get_cmap function was deprecated in Matplotlib 3.7 and will be removed two minor releases later. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap(obj)`` instead.\n",
451-
"\n"
444+
"/tmp/ipykernel_89533/2928454690.py:153: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed two minor releases later. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap(obj)`` instead.\n",
445+
" tab10 = matplotlib.cm.get_cmap(\"tab10\")\n"
452446
]
453447
},
454448
{
@@ -466,27 +460,27 @@
466460
" "
467461
],
468462
"text/plain": [
469-
"<IPython.lib.display.IFrame at 0x36881dd80>"
463+
"<IPython.lib.display.IFrame at 0x757bbfd23670>"
470464
]
471465
},
472466
"metadata": {},
473467
"output_type": "display_data"
474468
}
475469
],
476470
"source": [
477-
"import embedders\n",
471+
"import manify\n",
478472
"from sklearn.model_selection import train_test_split\n",
479473
"\n",
480-
"pm = embedders.manifolds.ProductManifold(signature=[(1, 2)])\n",
481-
"X, y = embedders.gaussian_mixture.gaussian_mixture(pm, 1000, num_classes=4)\n",
474+
"pm = manify.ProductManifold(signature=[(1, 2)])\n",
475+
"X, y = pm.gaussian_mixture(1000, num_classes=4)\n",
482476
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)\n",
483477
"\n",
484-
"pdt = embedders.predictors.tree_new.ProductSpaceDT(pm=pm, n_features=\"d_choose_2\", max_depth=5)\n",
478+
"pdt = manify.ProductSpaceDT(pm=pm, n_features=\"d_choose_2\", max_depth=5)\n",
485479
"pdt.fit(X_train, y_train)\n",
486480
"\n",
487481
"print(f\"{pdt.score(X_test, y_test).float().mean().item():.4f}\")\n",
488482
"\n",
489-
"create_dashboard(pdt, X_test, y_test).run_server(debug=True)"
483+
"create_dashboard(pdt, X_test, y_test).run(debug=True)"
490484
]
491485
},
492486
{
@@ -499,7 +493,7 @@
499493
],
500494
"metadata": {
501495
"kernelspec": {
502-
"display_name": "base",
496+
"display_name": "manify",
503497
"language": "python",
504498
"name": "python3"
505499
},
@@ -513,7 +507,7 @@
513507
"name": "python",
514508
"nbconvert_exporter": "python",
515509
"pygments_lexer": "ipython3",
516-
"version": "3.10.14"
510+
"version": "3.10.0"
517511
}
518512
},
519513
"nbformat": 4,

0 commit comments

Comments
 (0)