|
139 | 139 | " directory.mkdir(parents=True, exist_ok=True)\n", |
140 | 140 | "\n", |
141 | 141 | "# Add `notebook_connected` to the default renderer to ensure the figure are rendered in the the documentation\n", |
142 | | - "plotly_renderer = f\"notebook_connected+{pio.renderers.default}\"" |
| 142 | + "plotly_renderer = f\"notebook_connected+{pio.renderers.default}\"\n", |
| 143 | + "\n", |
| 144 | + "# Set device\n", |
| 145 | + "device_type=\"cuda\" if torch.cuda.is_available() else \"cpu\"" |
143 | 146 | ] |
144 | 147 | }, |
145 | 148 | { |
|
1670 | 1673 | " sample_file_name = f\"porous_network_graph_{porous_network_id:04d}.pt\"\n", |
1671 | 1674 | " # Set sample file path\n", |
1672 | 1675 | " sample_file_path = directories[\"processed_data\"] / sample_file_name\n", |
1673 | | - " # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", |
1674 | 1676 | " # Save graph sample file\n", |
1675 | 1677 | " torch.save(pyg_graph, sample_file_path)\n", |
1676 | 1678 | " # Save graph sample file path\n", |
|
1817 | 1819 | " dec_node_hidden_activ_type=\"leakyrelu\",\n", |
1818 | 1820 | " dec_node_output_activ_type=\"identity\",\n", |
1819 | 1821 | " # Set device\n", |
1820 | | - " device_type=\"cuda\" if torch.cuda.is_available() else \"cpu\",\n", |
| 1822 | + " device_type=device_type,\n", |
1821 | 1823 | ")" |
1822 | 1824 | ] |
1823 | 1825 | }, |
|
1883 | 1885 | " # Set verbosity\n", |
1884 | 1886 | " is_verbose=True,\n", |
1885 | 1887 | " tqdm_flavor=\"notebook\",\n", |
| 1888 | + " # Save loss history\n", |
| 1889 | + " save_loss_every=1,\n", |
| 1890 | + " # Set device\n", |
| 1891 | + " device_type=device_type,\n", |
1886 | 1892 | ")\n", |
1887 | 1893 | "\n", |
1888 | 1894 | "# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", |
|
2930 | 2936 | } |
2931 | 2937 | ], |
2932 | 2938 | "source": [ |
2933 | | - "from plotly import graph_objects as go\n", |
2934 | | - "\n", |
2935 | 2939 | "with open(directories[\"model\"] / \"loss_history_record.pkl\", \"rb\") as f:\n", |
2936 | 2940 | " loss_history_record = pickle.load(f)\n", |
2937 | 2941 | "\n", |
|
2998 | 3002 | " loss_kwargs={},\n", |
2999 | 3003 | " is_normalized_loss=False,\n", |
3000 | 3004 | " dataset_file_path=test_dataset_file_path,\n", |
3001 | | - " device_type=\"cuda\" if torch.cuda.is_available() else \"cpu\",\n", |
| 3005 | + " device_type=device_type,\n", |
3002 | 3006 | " seed=None,\n", |
3003 | 3007 | " is_verbose=True,\n", |
3004 | 3008 | " tqdm_flavor=\"notebook\",\n", |
|
0 commit comments