|
505 | 505 | }, |
506 | 506 | { |
507 | 507 | "cell_type": "code", |
508 | | - "execution_count": 9, |
| 508 | + "execution_count": null, |
509 | 509 | "id": "7ba18357", |
510 | 510 | "metadata": {}, |
511 | 511 | "outputs": [ |
|
544 | 544 | "from kooplearn.jax.nn import spectral_contrastive_loss, vamp_loss\n", |
545 | 545 | "\n", |
546 | 546 | "_vamp_loss = partial(vamp_loss, center_covariances=False)\n", |
547 | | - "for name, criterion in zip([\"VAMPNets\", \"Spectral Contrastive Loss\"], [ _vamp_loss, spectral_contrastive_loss]):\n", |
| 547 | + "for name, criterion in zip([\"VAMPNets\", \"Spectral Contrastive Loss\"],\n", |
| 548 | + " [ _vamp_loss, spectral_contrastive_loss]):\n", |
548 | 549 | " print(f\"Fitting {name}\")\n", |
549 | 550 | " trained_models[name] = train_encoder_only(criterion)" |
550 | 551 | ] |
|
784 | 785 | }, |
785 | 786 | { |
786 | 787 | "cell_type": "code", |
787 | | - "execution_count": 13, |
| 788 | + "execution_count": null, |
788 | 789 | "id": "9a43f6cd", |
789 | 790 | "metadata": {}, |
790 | 791 | "outputs": [ |
|
802 | 803 | "source": [ |
803 | 804 | "nun_models = len(report)\n", |
804 | 805 | "num_cols = len(report['Linear']['times'])\n", |
805 | | - "fig, axes = plt.subplots(nun_models, num_cols, figsize=(num_cols, nun_models), sharex=True, sharey=True)\n", |
| 806 | + "fig, axes = plt.subplots(\n", |
| 807 | + " nun_models, num_cols, figsize=(num_cols, nun_models), sharex=True, sharey=True\n", |
| 808 | + " )\n", |
806 | 809 | "\n", |
807 | 810 | "test_seed_idx = 0\n", |
808 | 811 | "# Remove margins between columns\n", |
|
815 | 818 | " ax.set_axis_off()\n", |
816 | 819 | " for prediction_step in range(num_cols - 1):\n", |
817 | 820 | " pred_label = report[model_name]['label'][prediction_step][test_seed_idx]\n", |
818 | | - " true_label = (test_labels[test_seed_idx] + report[model_name]['times'][prediction_step])%num_digits\n", |
| 821 | + " true_label = (\n", |
| 822 | + " test_labels[test_seed_idx] + report[model_name]['times'][prediction_step]\n", |
| 823 | + " )%num_digits\n", |
819 | 824 | " img = report[model_name]['image'][prediction_step][test_seed_idx]\n", |
820 | 825 | " logit = report[model_name]['logits'][prediction_step][test_seed_idx]\n", |
821 | 826 | "\n", |
|
848 | 853 | "\n", |
849 | 854 | "# Display the model names on the left of each row\n", |
850 | 855 | "for model_idx, model_name in enumerate(report.keys()):\n", |
851 | | - " axes[model_idx, 0].text(-0.1, 0.5, model_name.replace('_', ' '), fontsize=14, ha='right', va='center', transform=axes[model_idx, 0].transAxes)\n", |
| 856 | + " axes[model_idx, 0].text(\n", |
| 857 | + " -0.1,\n", |
| 858 | + " 0.5,\n", |
| 859 | + " model_name.replace('_', ' '),\n", |
| 860 | + " fontsize=14,\n", |
| 861 | + " ha='right',\n", |
| 862 | + " va='center',\n", |
| 863 | + " transform=axes[model_idx, 0].transAxes\n", |
| 864 | + " )\n", |
852 | 865 | "\n", |
853 | 866 | "for class_idx in range(num_cols):\n", |
854 | 867 | " title = (test_labels[test_seed_idx] + class_idx)%num_digits\n", |
|
879 | 892 | }, |
880 | 893 | { |
881 | 894 | "cell_type": "code", |
882 | | - "execution_count": 14, |
| 895 | + "execution_count": null, |
883 | 896 | "id": "617e1987", |
884 | 897 | "metadata": {}, |
885 | 898 | "outputs": [ |
|
906 | 919 | " ax.title.set_text(model_name.replace('_', ' '))\n", |
907 | 920 | " fitted_model = trained_models[model_name]['model']\n", |
908 | 921 | " embedder = trained_models[model_name]['embedder']\n", |
909 | | - " vals, lfuncs, rfuncs = fitted_model.eig(eval_right_on=embedder.transform(test_data), eval_left_on=embedder.transform(test_data))\n", |
| 922 | + " vals, lfuncs, rfuncs = fitted_model.eig(\n", |
| 923 | + " eval_right_on=embedder.transform(test_data),\n", |
| 924 | + " eval_left_on=embedder.transform(test_data)\n", |
| 925 | + " )\n", |
910 | 926 | "\n", |
911 | | - " unique_vals, idx_start = np.unique(np.abs(vals), return_index=True) # returns the unique values and the index of the first occurrence of a value\n", |
| 927 | + " unique_vals, idx_start = np.unique(np.abs(vals), return_index=True) # returns the unique values\n", |
| 928 | + " # and the index of the first occurrence of a value\n", |
912 | 929 | "\n", |
913 | 930 | " vals, lfuncs, rfuncs = vals[idx_start], lfuncs[:, idx_start], rfuncs[:, idx_start]\n", |
914 | 931 | " top_vals, top_indices = stable_topk(np.abs(vals), 2)\n", |
|
923 | 940 | "\n", |
924 | 941 | "# remove last axis and add legend\n", |
925 | 942 | "ax = axes[n_models-1]\n", |
926 | | - "legend = ax.legend(*scatter.legend_elements(num=4), title=\"Digits\", frameon=True, bbox_to_anchor=(1.3, 1))\n", |
| 943 | + "legend = ax.legend(\n", |
| 944 | + " *scatter.legend_elements(num=4), title=\"Digits\", frameon=True, bbox_to_anchor=(1.3, 1)\n", |
| 945 | + " )\n", |
927 | 946 | "ax.add_artist(legend)\n", |
928 | 947 | "fig.delaxes(axes[n_models])\n", |
929 | 948 | "\n", |
|
943 | 962 | ], |
944 | 963 | "metadata": { |
945 | 964 | "kernelspec": { |
946 | | - "display_name": "kooplearn", |
| 965 | + "display_name": ".venv", |
947 | 966 | "language": "python", |
948 | 967 | "name": "python3" |
949 | 968 | }, |
|
957 | 976 | "name": "python", |
958 | 977 | "nbconvert_exporter": "python", |
959 | 978 | "pygments_lexer": "ipython3", |
960 | | - "version": "3.13.1" |
| 979 | + "version": "3.13.7" |
961 | 980 | } |
962 | 981 | }, |
963 | 982 | "nbformat": 4, |
|
0 commit comments