
Commit 1b90974

rework some files

1 parent ec9c932 commit 1b90974


4 files changed (+169, -210 lines)

docs/source/python/m-surrogate.rst

Lines changed: 115 additions & 149 deletions
@@ -185,28 +185,32 @@ Module Structure
 The GNN module is located in `pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN <https://github.com/SciCompMod/memilio/tree/main/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN>`_ and consists of:

 - **data_generation.py**: Generates training and evaluation data by simulating epidemiological scenarios with the mechanistic SECIR model
-- **network_architectures.py**: Defines various GNN architectures (GCN, GAT, GIN) with configurable layers and preprocessing
+- **network_architectures.py**: Defines various GNN architectures (ARMAConv, GCSConv, GATConv, GCNConv, APPNPConv) with configurable depth and channels
 - **evaluate_and_train.py**: Implements training and evaluation pipelines for GNN models
 - **grid_search.py**: Provides hyperparameter optimization through systematic grid search
 - **GNN_utils.py**: Contains utility functions for data preprocessing, graph construction, and population data handling

 Data Generation
 ~~~~~~~~~~~~~~~

-The data generation process in ``data_generation.py`` creates graph-structured training data through mechanistic simulations:
+The data generation process in ``data_generation.py`` creates graph-structured training data through mechanistic simulations. Use ``generate_data`` to run multiple simulations and persist a pickle with inputs, labels, damping info, and contact matrices:

 .. code-block:: python

     from memilio.surrogatemodel.GNN import data_generation
-
-    # Generate training dataset
-    dataset = data_generation.generate_dataset(
-        num_runs=1000,                    # Number of simulation scenarios
-        num_days=30,                      # Simulation horizon
-        num_age_groups=6,                 # Age stratification
-        data_dir='path/to/contact_data',  # Contact matrices location
-        mobility_dir='path/to/mobility',  # Mobility data location
-        save_path='gnn_training_data.pickle'
+    import memilio.simulation as mio
+
+    data = data_generation.generate_data(
+        num_runs=5,
+        data_dir="/path/to/memilio/data",
+        output_path="/tmp/generated_datasets",
+        input_width=5,
+        label_width=30,
+        start_date=mio.Date(2020, 10, 1),
+        end_date=mio.Date(2021, 10, 31),
+        mobility_file="commuter_mobility.txt",  # or commuter_mobility_2022.txt
+        transform=True,
+        save_data=True
     )

 **Data Generation Workflow:**
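
As a quick check of the reworked API, the sketch below loads the pickle written by ``generate_data`` and prints its contents. The dict-of-arrays layout is an assumption for illustration; see ``data_generation.py`` for the actual structure.

.. code-block:: python

    import pickle

    # Load the dataset written by generate_data above. Assumes a dict-like
    # payload; adjust to the actual layout in data_generation.py.
    with open("/tmp/generated_datasets/GNN_data_30days_3dampings_classic5.pickle", "rb") as f:
        data = pickle.load(f)

    for key, value in data.items():
        print(key, getattr(value, "shape", type(value)))
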
@@ -240,28 +244,18 @@ The data generation process in ``data_generation.py`` creates graph-structured training data
 Network Architectures
 ~~~~~~~~~~~~~~~~~~~~~

-The ``network_architectures.py`` module provides flexible GNN model construction for different layer types.
+The ``network_architectures.py`` module provides flexible GNN model construction for supported layer types (ARMAConv, GCSConv, GATConv, GCNConv, APPNPConv).

 .. code-block:: python

     from memilio.surrogatemodel.GNN import network_architectures
-
-    # Define GNN architecture
-    model_config = {
-        'layer_type': 'GCN',      # GNN layer type
-        'num_layers': 3,          # Network depth
-        'hidden_dim': 64,         # Hidden layer dimensions
-        'activation': 'relu',     # Activation function
-        'dropout_rate': 0.2,      # Dropout for regularization
-        'use_batch_norm': True,   # Batch normalization
-        'aggregation': 'mean',    # Neighborhood aggregation method
-    }
-
-    # Build model
-    model = network_architectures.build_gnn_model(
-        config=model_config,
-        input_shape=(num_timesteps, num_features),
-        output_dim=num_compartments * num_age_groups
+
+    model = network_architectures.get_model(
+        layer_type="GCNConv",
+        num_layers=3,
+        num_channels=64,
+        activation="relu",
+        num_output=48  # outputs per node
     )

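
In the example above, ``num_output=48`` is the per-node output size; the previous version of the docs derived 48 as 6 age groups × 8 compartments. A one-line sanity check (variable names are illustrative):

.. code-block:: python

    # Per-node output size passed as num_output above.
    num_age_groups = 6
    num_compartments = 8
    assert num_age_groups * num_compartments == 48
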
@@ -272,45 +266,42 @@ The ``evaluate_and_train.py`` module provides the training functionality:

 .. code-block:: python

-    from memilio.surrogatemodel.GNN import evaluate_and_train
-
-    # Load training data
-    with open('gnn_training_data.pickle', 'rb') as f:
-        dataset = pickle.load(f)
-
-    # Define training configuration
-    training_config = {
-        'epochs': 100,
-        'batch_size': 32,
-        'learning_rate': 0.001,
-        'optimizer': 'adam',
-        'loss_function': 'mse',
-        'early_stopping_patience': 10,
-        'validation_split': 0.2
-    }
-
-    # Train model
-    history = evaluate_and_train.train_gnn_model(
-        model=model,
-        dataset=dataset,
-        config=training_config,
-        save_weights='best_gnn_model.h5'
+    from tensorflow.keras.losses import MeanAbsolutePercentageError
+    from tensorflow.keras.optimizers import Adam
+    from memilio.surrogatemodel.GNN import evaluate_and_train, network_architectures
+
+    dataset = evaluate_and_train.load_gnn_dataset(
+        "/tmp/generated_datasets/GNN_data_30days_3dampings_classic5.pickle",
+        "/path/to/memilio/data/Germany/mobility",
+        number_of_nodes=400
     )
-
-    # Evaluate on test set
-    metrics = evaluate_and_train.evaluate_model(
+
+    model = network_architectures.get_model(
+        layer_type="GCNConv",
+        num_layers=3,
+        num_channels=32,
+        activation="relu",
+        num_output=48
+    )
+
+    results = evaluate_and_train.train_and_evaluate(
+        data=dataset,
+        batch_size=32,
+        epochs=50,
         model=model,
-        test_data=test_dataset,
-        metrics=['mae', 'mape', 'r2']
+        loss_fn=MeanAbsolutePercentageError(),
+        optimizer=Adam(learning_rate=0.001),
+        es_patience=10,
+        save_dir="/tmp/model_results",
+        save_name="gnn_model"
     )

 **Training Features:**

 1. **Mini-batch Training**: Graph batching for efficient training on large datasets
 2. **Custom Loss Functions**: MSE, MAE, MAPE, or custom compartment-weighted losses
 3. **Early Stopping**: Monitors validation loss to prevent overfitting
-4. **Learning Rate Scheduling**: Adaptive learning rate reduction on plateaus
-5. **Save Best Weights**: Saves best model weights based on validation performance
+4. **Save Best Weights**: Saves best model weights based on validation performance

 **Evaluation Metrics:**

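
The ``es_patience`` argument above corresponds to the early stopping described under **Training Features**. In plain Keras terms this is typically the following callback; whether ``evaluate_and_train`` wires it up exactly like this is an assumption:

.. code-block:: python

    from tensorflow.keras.callbacks import EarlyStopping

    # Stop after 10 epochs without validation-loss improvement and
    # restore the best weights seen so far (mirrors es_patience=10).
    early_stopping = EarlyStopping(
        monitor="val_loss",
        patience=10,
        restore_best_weights=True
    )
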
@@ -331,31 +322,31 @@ The ``grid_search.py`` module enables systematic exploration of hyperparameter space

 .. code-block:: python

-    from memilio.surrogatemodel.GNN import grid_search
-
-    # Define search space
-    param_grid = {
-        'layer_type': ['GCN', 'GAT', 'GIN'],
-        'num_layers': [2, 3, 4, 5],
-        'hidden_dim': [32, 64, 128, 256],
-        'learning_rate': [0.001, 0.0005, 0.0001],
-        'dropout_rate': [0.0, 0.1, 0.2, 0.3],
-        'batch_size': [16, 32, 64],
-        'activation': ['relu', 'elu', 'tanh']
-    }
-
-    # Run grid search with cross-validation
-    results = grid_search.run_hyperparameter_search(
-        param_grid=param_grid,
-        data_path='gnn_training_data.pickle',
-        cv_folds=5,
-        metric='mae',
-        save_results='grid_search_results.csv'
+    from pathlib import Path
+    from memilio.surrogatemodel.GNN import grid_search, evaluate_and_train
+
+    data = evaluate_and_train.create_dataset(
+        "/tmp/generated_datasets/GNN_data_30days_3dampings_classic5.pickle",
+        "/path/to/memilio/data/Germany/mobility",
+        number_of_nodes=400
+    )
+
+    parameter_grid = grid_search.generate_parameter_grid(
+        layer_types=["GCNConv", "GATConv"],
+        num_layers_options=[2, 3],
+        num_channels_options=[16, 32],
+        activation_functions=["relu", "elu"]
+    )
+
+    grid_search.perform_grid_search(
+        data=data,
+        parameter_grid=parameter_grid,
+        save_dir=str(Path("/tmp/grid_results")),
+        batch_size=32,
+        max_epochs=50,
+        es_patience=10,
+        learning_rate=0.001
     )
-
-    # Analyze best configuration
-    best_config = grid_search.get_best_configuration(results)
-    print(f"Best configuration: {best_config}")

 Utility Functions
 ~~~~~~~~~~~~~~~~~
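
The example grid above spans 2 × 2 × 2 × 2 = 16 configurations. Assuming ``generate_parameter_grid`` enumerates the full Cartesian product, it is equivalent to this sketch:

.. code-block:: python

    from itertools import product

    # Cartesian product over the four search axes: 16 combinations.
    grid = list(product(["GCNConv", "GATConv"], [2, 3],
                        [16, 32], ["relu", "elu"]))
    print(len(grid))  # 16
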
@@ -404,80 +395,55 @@ Here is a complete example workflow from data generation to model evaluation:

 .. code-block:: python

-    import pickle
-    from pathlib import Path
+    import memilio.simulation as mio
+    from tensorflow.keras.losses import MeanAbsolutePercentageError
+    from tensorflow.keras.optimizers import Adam
     from memilio.surrogatemodel.GNN import (
-        data_generation,
-        network_architectures,
+        data_generation,
+        network_architectures,
         evaluate_and_train
     )
-
-    # Step 1: Generate training data
-    print("Generating training data...")
-    dataset = data_generation.generate_dataset(
-        num_runs=5000,
-        num_days=30,
-        num_age_groups=6,
-        data_dir='/path/to/memilio/data/Germany',
-        mobility_dir='/path/to/mobility_data',
-        save_path='gnn_dataset_5000.pickle'
+
+    # Step 1: Generate and save training data
+    data_generation.generate_data(
+        num_runs=100,
+        data_dir="/path/to/memilio/data",
+        output_path="/tmp/generated_datasets",
+        input_width=5,
+        label_width=30,
+        start_date=mio.Date(2020, 10, 1),
+        end_date=mio.Date(2021, 10, 31),
+        save_data=True,
+        mobility_file="commuter_mobility.txt"
     )
-
-    # Step 2: Define and build GNN model
-    print("Building GNN model...")
-    model_config = {
-        'layer_type': 'GCN',
-        'num_layers': 4,
-        'hidden_dim': 128,
-        'activation': 'relu',
-        'dropout_rate': 0.2,
-        'use_batch_norm': True
-    }
-
-    model = network_architectures.build_gnn_model(
-        config=model_config,
-        input_shape=(1, 48),  # 6 age groups × 8 compartments
-        output_dim=48         # Predict all compartments
+
+    # Step 2: Load dataset and build model
+    dataset = evaluate_and_train.load_gnn_dataset(
+        "/tmp/generated_datasets/GNN_data_30days_3dampings_classic100.pickle",
+        "/path/to/memilio/data/Germany/mobility",
+        number_of_nodes=400
     )
-
-    # Step 3: Train the model
-    print("Training model...")
-    training_config = {
-        'epochs': 200,
-        'batch_size': 32,
-        'learning_rate': 0.001,
-        'optimizer': 'adam',
-        'loss_function': 'mae',
-        'early_stopping_patience': 20,
-        'validation_split': 0.2
-    }
-
-    history = evaluate_and_train.train_gnn_model(
-        model=model,
-        dataset=dataset,
-        config=training_config,
-        save_weights='gnn_weights_best.h5'
+
+    model = network_architectures.get_model(
+        layer_type="GCNConv",
+        num_layers=4,
+        num_channels=128,
+        activation="relu",
+        num_output=48
     )
-
-    # Step 4: Evaluate on test data
-    print("Evaluating model...")
-    test_metrics = evaluate_and_train.evaluate_model(
+
+    # Step 3: Train and evaluate
+    results = evaluate_and_train.train_and_evaluate(
+        data=dataset,
+        batch_size=32,
+        epochs=100,
         model=model,
-        test_data='gnn_test_data.pickle',
-        metrics=['mae', 'mape', 'r2']
+        loss_fn=MeanAbsolutePercentageError(),
+        optimizer=Adam(learning_rate=0.001),
+        es_patience=20,
+        save_dir="/tmp/model_results",
+        save_name="gnn_weights_best"
     )
-
-    # Print results
-    print(f"Test MAE: {test_metrics['mae']:.4f}")
-    print(f"Test MAPE: {test_metrics['mape']:.2f}%")
-    print(f"Test R²: {test_metrics['r2']:.4f}")
-
-    # Step 5: Make predictions on new scenarios
-    with open('new_scenario.pickle', 'rb') as f:
-        new_data = pickle.load(f)
-
-    predictions = model.predict(new_data)
-    print(f"Predictions shape: {predictions.shape}")

 **GPU Acceleration:**

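
The training code is built on TensorFlow/Keras (see the imports above), so GPU usage follows standard TensorFlow behavior; a minimal device check is:

.. code-block:: python

    import tensorflow as tf

    # List GPUs visible to TensorFlow; training falls back to CPU if none.
    print("GPUs available:", len(tf.config.list_physical_devices("GPU")))
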
pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/GNN_utils.py

Lines changed: 11 additions & 45 deletions
@@ -175,56 +175,22 @@ def scale_data(
     if not np.issubdtype(labels_array.dtype, np.number):
         raise ValueError("Label data must be numeric.")

-    # Calculate number of age groups from data shape
-    num_groups = int(inputs_array.shape[-1] / num_compartments)
-
     # Initialize transformer (log1p for numerical stability)
     transformer = FunctionTransformer(np.log1p, validate=True)

-    # Process inputs
-    # Reshape: [samples, timesteps, nodes, features] -> [nodes, samples, timesteps, features]
-    #       -> [nodes * compartments, samples * timesteps * age_groups]
-    inputs_reshaped = inputs_array.transpose(2, 0, 1, 3).reshape(
-        num_groups * num_compartments, -1
-    )
-
     if transform:
-        inputs_transformed = transformer.transform(inputs_reshaped)
+        inputs_scaled = transformer.transform(
+            inputs_array.reshape(-1, inputs_array.shape[-1])
+        ).reshape(inputs_array.shape)
+        labels_scaled = transformer.transform(
+            labels_array.reshape(-1, labels_array.shape[-1])
+        ).reshape(labels_array.shape)
     else:
-        inputs_transformed = inputs_reshaped
-
-    original_shape_input = inputs_array.shape
-
-    # Reverse reshape to separate dimensions
-    inputs_back = inputs_transformed.reshape(
-        original_shape_input[2],
-        original_shape_input[0],
-        original_shape_input[1],
-        original_shape_input[3]
-    )
-
-    # Reverse transpose and reorder to [samples, features, timesteps, nodes]
-    scaled_inputs = inputs_back.transpose(1, 2, 0, 3).transpose(0, 3, 1, 2)
-
-    # Process labels with same procedure
-    labels_reshaped = labels_array.transpose(2, 0, 1, 3).reshape(
-        num_groups * num_compartments, -1
-    )
-
-    if transform:
-        labels_transformed = transformer.transform(labels_reshaped)
-    else:
-        labels_transformed = labels_reshaped
-
-    original_shape_labels = labels_array.shape
-
-    labels_back = labels_transformed.reshape(
-        original_shape_labels[2],
-        original_shape_labels[0],
-        original_shape_labels[1],
-        original_shape_labels[3]
-    )
+        inputs_scaled = inputs_array
+        labels_scaled = labels_array

-    scaled_labels = labels_back.transpose(1, 2, 0, 3).transpose(0, 3, 1, 2)
+    # Reorder to [samples, features, timesteps, nodes]
+    scaled_inputs = inputs_scaled.transpose(0, 3, 1, 2)
+    scaled_labels = labels_scaled.transpose(0, 3, 1, 2)

     return scaled_inputs, scaled_labels
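
Since ``log1p`` is applied elementwise, the reworked ``scale_data`` only needs a flatten-and-restore reshape before the final axis reorder from ``[samples, timesteps, nodes, features]`` to ``[samples, features, timesteps, nodes]``. A self-contained numpy sketch of the same steps (shapes chosen to match the examples above):

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng(0)
    # Dummy inputs shaped [samples, timesteps, nodes, features].
    inputs = rng.random((2, 5, 400, 48))

    # Elementwise log1p via flatten-and-restore, as in scale_data.
    scaled = np.log1p(inputs.reshape(-1, inputs.shape[-1])).reshape(inputs.shape)

    # Reorder to [samples, features, timesteps, nodes].
    scaled = scaled.transpose(0, 3, 1, 2)
    print(scaled.shape)  # (2, 48, 5, 400)
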
