Skip to content

Commit 9ea6d63

Browse files
MGN benchmark
1 parent c6c400a commit 9ea6d63

File tree

7 files changed

+8120
-61
lines changed

7 files changed

+8120
-61
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ benchmarks/cfd/gnn_porous_medium/*
178178
!benchmarks/cfd/gnn_porous_medium/gnn_porous_medium.ipynb
179179
!benchmarks/cfd/gnn_porous_medium/_assets
180180
benchmarks/cfd/mgn_cylinder_flow/*
181-
!benchmarks/cfd/mgn_cylinder_flow/gnn_cylinder_flow.ipynb
181+
!benchmarks/cfd/mgn_cylinder_flow/mgn_cylinder_flow.ipynb
182182
!benchmarks/cfd/mgn_cylinder_flow/_assets
183183

184184
# Sphinx docs

benchmarks/cfd/gnn_porous_medium/gnn_porous_medium.ipynb

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,10 @@
139139
" directory.mkdir(parents=True, exist_ok=True)\n",
140140
"\n",
141141
"# Add `notebook_connected` to the default renderer to ensure the figure are rendered in the the documentation\n",
142-
"plotly_renderer = f\"notebook_connected+{pio.renderers.default}\""
142+
"plotly_renderer = f\"notebook_connected+{pio.renderers.default}\"\n",
143+
"\n",
144+
"# Set device\n",
145+
"device_type=\"cuda\" if torch.cuda.is_available() else \"cpu\""
143146
]
144147
},
145148
{
@@ -1670,7 +1673,6 @@
16701673
" sample_file_name = f\"porous_network_graph_{porous_network_id:04d}.pt\"\n",
16711674
" # Set sample file path\n",
16721675
" sample_file_path = directories[\"processed_data\"] / sample_file_name\n",
1673-
" # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
16741676
" # Save graph sample file\n",
16751677
" torch.save(pyg_graph, sample_file_path)\n",
16761678
" # Save graph sample file path\n",
@@ -1817,7 +1819,7 @@
18171819
" dec_node_hidden_activ_type=\"leakyrelu\",\n",
18181820
" dec_node_output_activ_type=\"identity\",\n",
18191821
" # Set device\n",
1820-
" device_type=\"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
1822+
" device_type=device_type,\n",
18211823
")"
18221824
]
18231825
},
@@ -1883,6 +1885,10 @@
18831885
" # Set verbosity\n",
18841886
" is_verbose=True,\n",
18851887
" tqdm_flavor=\"notebook\",\n",
1888+
" # Save loss history\n",
1889+
" save_loss_every=1,\n",
1890+
" # Set device\n",
1891+
" device_type=device_type,\n",
18861892
")\n",
18871893
"\n",
18881894
"# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
@@ -2930,8 +2936,6 @@
29302936
}
29312937
],
29322938
"source": [
2933-
"from plotly import graph_objects as go\n",
2934-
"\n",
29352939
"with open(directories[\"model\"] / \"loss_history_record.pkl\", \"rb\") as f:\n",
29362940
" loss_history_record = pickle.load(f)\n",
29372941
"\n",
@@ -2998,7 +3002,7 @@
29983002
" loss_kwargs={},\n",
29993003
" is_normalized_loss=False,\n",
30003004
" dataset_file_path=test_dataset_file_path,\n",
3001-
" device_type=\"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
3005+
" device_type=device_type,\n",
30023006
" seed=None,\n",
30033007
" is_verbose=True,\n",
30043008
" tqdm_flavor=\"notebook\",\n",
1.95 MB
Binary file not shown.

0 commit comments

Comments
 (0)