Skip to content

Commit 223b3d5

Browse files
committed
.
1 parent d177e4c commit 223b3d5

File tree

3 files changed

+354
-127
lines changed

3 files changed

+354
-127
lines changed

pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/grid_search.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
optimal configurations for epidemiological surrogate models.
2626
"""
2727

28-
import os
2928
from pathlib import Path
3029

3130
import pandas as pd
@@ -127,6 +126,9 @@ def perform_grid_search(
127126
raise ValueError(
128127
"Parameter grid must contain at least one configuration.")
129128

129+
# Convert save_dir to Path if it's a string
130+
save_dir = Path(save_dir) / "saves"
131+
130132
# Determine output dimension from data
131133
output_dim = data[0].y.shape[-1]
132134

@@ -230,10 +232,10 @@ def perform_grid_search(
230232
# Save intermediate results after each configuration
231233
results_df.to_csv(results_file, index=False)
232234
print(
233-
f"Configuration complete. Results saved to {results_file}")
235+
f"Configuration complete. Results saved to {results_file}")
234236

235237
except Exception as e:
236-
print(f"Error training configuration: {e}")
238+
print(f"Error training configuration: {e}")
237239
# Continue with next configuration rather than failing entire search
238240

239241
finally:
@@ -248,7 +250,7 @@ def perform_grid_search(
248250

249251
# Print best configuration
250252
if len(results_df) > 0:
251-
best_idx = results_df['mean_val_loss'].idxmin()
253+
best_idx = results_df['mean_validation_loss'].idxmin()
252254
best_config = results_df.loc[best_idx]
253255
print(f"\nBest Configuration:")
254256
print(f" Model: {best_config['model']}")

0 commit comments

Comments
 (0)