Commit d3f6ca4 ("add rtd for gnn"), parent 223b3d5
2 files changed: +343 -57 lines changed
docs/source/python/m-surrogate.rst
Lines changed: 343 additions & 3 deletions

@@ -163,7 +163,347 @@ The `grid_search.py` and `hyperparameter_tuning.py` modules provide tools for sy

Unchanged context:

  - Visualization of hyperparameter importance
  - Selection of optimal model configurations

Removed (placeholder section):

  SECIR Groups Model
  ------------------

  To be added...

Added (new GNN section):

Graph Neural Network (GNN) Surrogate Models
--------------------------------------------

The Graph Neural Network (GNN) module provides advanced surrogate models that leverage spatial connectivity and age-stratified epidemiological dynamics. These models are designed for immediate and reliable pandemic response by combining mechanistic expert knowledge with machine learning efficiency.

Overview and Scientific Foundation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The GNN surrogate models are based on the research presented in:

|Graph_Neural_Network_Surrogates|

The implementation leverages the mechanistic ODE-SECIR model (see :doc:`ODE-SECIR documentation <../models/ode_secir>`) as the underlying expert model, using Python bindings to the C++ backend for efficient simulation during data generation.

Module Structure
~~~~~~~~~~~~~~~~

The GNN module is located in `pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN <https://github.com/SciCompMod/memilio/tree/main/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN>`_ and consists of:

- **data_generation.py**: Generates training and evaluation data by simulating epidemiological scenarios with the mechanistic SECIR model
- **network_architectures.py**: Defines various GNN architectures (GCN, GAT, GIN) with configurable layers and preprocessing
- **evaluate_and_train.py**: Implements training and evaluation pipelines for GNN models
- **grid_search.py**: Provides hyperparameter optimization through systematic grid search
- **GNN_utils.py**: Contains utility functions for data preprocessing, graph construction, and population data handling

Data Generation
~~~~~~~~~~~~~~~

The data generation process in ``data_generation.py`` creates graph-structured training data through mechanistic simulations:

.. code-block:: python

    from memilio.surrogatemodel.GNN import data_generation

    # Generate training dataset
    dataset = data_generation.generate_dataset(
        num_runs=1000,                    # Number of simulation scenarios
        num_days=30,                      # Simulation horizon
        num_age_groups=6,                 # Age stratification
        data_dir='path/to/contact_data',  # Contact matrices location
        mobility_dir='path/to/mobility',  # Mobility data location
        save_path='gnn_training_data.pickle'
    )

**Data Generation Workflow:**

1. **Parameter Sampling**: Randomly sample epidemiological parameters (transmission rates, incubation periods, recovery rates) from predefined distributions to create diverse scenarios.

2. **Compartment Initialization**: Initialize the epidemic compartments for each age group in each region based on realistic demographic data, using shared base factors across compartments.

3. **Mobility Graph Construction**: Build a spatial graph in which:

   - Nodes represent geographic regions (e.g., German counties)
   - Edges represent mobility connections, weighted by commuting data
   - Node features include age-stratified population sizes

4. **Contact Matrix Configuration**: Load and configure baseline contact matrices for different location types (home, school, work, other), stratified by age groups.

5. **Damping Application**: Apply time-varying dampings to the contact matrices to simulate non-pharmaceutical interventions (NPIs):

   - Multiple damping periods with random start days
   - Location-specific damping factors (e.g., strong school closures, moderate workplace restrictions)
   - Realistic parameter ranges based on observed intervention strengths

6. **Simulation Execution**: Run the mechanistic ODE-SECIR model using MEmilio's C++ backend through Python bindings to generate the dataset.

7. **Data Processing**: Transform the simulation results into a graph-structured format (a minimal sketch follows this list):

   - Extract compartment time series for each node (region) and age group
   - Apply a logarithmic transformation for numerical stability
   - Store graph topology, node features, and temporal sequences
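
As a minimal illustration of the data-processing step, the sketch below applies the logarithmic transformation and packs one training sample. The array shapes, the ``epsilon`` constant, and the dictionary keys are illustrative assumptions, not the exact layout produced by ``data_generation.py``.

.. code-block:: python

    import numpy as np

    # Illustrative dimensions: 401 regions (nodes), 30 days, 6 age groups x 8 compartments
    num_nodes, num_days, num_features = 401, 30, 48

    # Stand-in for the ODE-SECIR simulation output (region x day x feature)
    simulation_output = np.random.rand(num_nodes, num_days, num_features) * 1e4

    # Logarithmic transformation for numerical stability; epsilon avoids log(0)
    epsilon = 1e-6
    node_features = np.log(simulation_output + epsilon)

    # Graph topology from mobility data: edge list plus commuter weights (assumed format)
    edge_index = np.array([[0, 1], [1, 0], [1, 2]])
    edge_weights = np.array([120.0, 95.0, 40.0])

    # One graph-structured training sample
    sample = {
        'node_features': node_features,  # temporal sequences per region and age group
        'edge_index': edge_index,
        'edge_weights': edge_weights,
    }
    print(sample['node_features'].shape)  # (401, 30, 48)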

Network Architectures
~~~~~~~~~~~~~~~~~~~~~

The ``network_architectures.py`` module provides flexible GNN model construction for different layer types.

.. code-block:: python

    from memilio.surrogatemodel.GNN import network_architectures

    # Define GNN architecture
    model_config = {
        'layer_type': 'GCN',      # GNN layer type (GCN, GAT, GIN)
        'num_layers': 3,          # Network depth
        'hidden_dim': 64,         # Hidden layer dimensions
        'activation': 'relu',     # Activation function
        'dropout_rate': 0.2,      # Dropout for regularization
        'use_batch_norm': True,   # Batch normalization
        'aggregation': 'mean',    # Neighborhood aggregation method
    }

    # Build model
    model = network_architectures.build_gnn_model(
        config=model_config,
        input_shape=(num_timesteps, num_features),
        output_dim=num_compartments * num_age_groups
    )
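
For orientation, the following sketch shows how a comparable GCN stack could be assembled with Keras and Spektral (the library used for the GNN layers). It is not the actual implementation of ``build_gnn_model``; the layer sizes, batch-mode inputs, and the per-node dense output head are assumptions.

.. code-block:: python

    import tensorflow as tf
    from spektral.layers import GCNConv

    num_nodes, num_features, output_dim = 401, 48, 48

    # Batch-mode inputs: node features and a dense adjacency matrix per graph.
    # In practice the adjacency should be normalized, e.g. with GCNConv.preprocess.
    x_in = tf.keras.Input(shape=(num_nodes, num_features))
    a_in = tf.keras.Input(shape=(num_nodes, num_nodes))

    # Three graph-convolution layers with dropout, roughly mirroring model_config above
    h = GCNConv(64, activation='relu')([x_in, a_in])
    h = tf.keras.layers.Dropout(0.2)(h)
    h = GCNConv(64, activation='relu')([h, a_in])
    h = tf.keras.layers.Dropout(0.2)(h)
    h = GCNConv(64, activation='relu')([h, a_in])

    # Per-node output: predicted compartment values for all age groups
    out = tf.keras.layers.Dense(output_dim)(h)

    model = tf.keras.Model(inputs=[x_in, a_in], outputs=out)
    model.compile(optimizer='adam', loss='mse')
    model.summary()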

Training and Evaluation
~~~~~~~~~~~~~~~~~~~~~~~

The ``evaluate_and_train.py`` module provides the training functionality:

.. code-block:: python

    import pickle

    from memilio.surrogatemodel.GNN import evaluate_and_train

    # Load training data
    with open('gnn_training_data.pickle', 'rb') as f:
        dataset = pickle.load(f)

    # Define training configuration
    training_config = {
        'epochs': 100,
        'batch_size': 32,
        'learning_rate': 0.001,
        'optimizer': 'adam',
        'loss_function': 'mse',
        'early_stopping_patience': 10,
        'validation_split': 0.2
    }

    # Train model
    history = evaluate_and_train.train_gnn_model(
        model=model,
        dataset=dataset,
        config=training_config,
        save_weights='best_gnn_model.h5'
    )

    # Evaluate on test set
    metrics = evaluate_and_train.evaluate_model(
        model=model,
        test_data=test_dataset,
        metrics=['mae', 'mape', 'r2']
    )

**Training Features** (the corresponding Keras callbacks are sketched after this list):

1. **Mini-batch Training**: Graph batching for efficient training on large datasets
2. **Custom Loss Functions**: MSE, MAE, MAPE, or custom compartment-weighted losses
3. **Early Stopping**: Monitors the validation loss to prevent overfitting
4. **Learning Rate Scheduling**: Adaptive learning-rate reduction on plateaus
5. **Best-Weight Saving**: Saves the best model weights based on validation performance
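
Early stopping, learning-rate scheduling, and best-weight saving map naturally onto standard Keras callbacks. The sketch below shows plausible settings; it is an assumption about the configuration, not a copy of what ``train_gnn_model`` does internally.

.. code-block:: python

    import tensorflow as tf

    callbacks = [
        # Stop when the validation loss has not improved for 10 epochs
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss', patience=10, restore_best_weights=True),
        # Halve the learning rate when the validation loss plateaus
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6),
        # Keep only the best weights seen during training
        tf.keras.callbacks.ModelCheckpoint(
            'best_gnn_model.h5', monitor='val_loss',
            save_best_only=True, save_weights_only=True),
    ]

    # Passed to Keras as, e.g.: model.fit(..., validation_split=0.2, callbacks=callbacks)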

**Evaluation Metrics** (computed, for example, as in the sketch below):

- **Mean Absolute Error (MAE)**: Average absolute prediction error per compartment
- **Mean Absolute Percentage Error (MAPE)**: Average absolute prediction error relative to the true value, expressed in percent
- **R² Score**: Coefficient of determination, measuring overall prediction quality
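
For reference, these metrics can be computed with scikit-learn as sketched here; whether ``evaluate_model`` uses scikit-learn internally is an assumption.

.. code-block:: python

    import numpy as np
    from sklearn.metrics import (mean_absolute_error,
                                 mean_absolute_percentage_error, r2_score)

    y_true = np.array([[100.0, 20.0], [80.0, 15.0]])  # illustrative compartment values
    y_pred = np.array([[95.0, 22.0], [85.0, 14.0]])

    mae = mean_absolute_error(y_true, y_pred)
    mape = mean_absolute_percentage_error(y_true, y_pred) * 100  # in percent
    r2 = r2_score(y_true, y_pred)
    print(f"MAE: {mae:.3f}, MAPE: {mape:.2f}%, R2: {r2:.3f}")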

**Data Splitting** (see the index-shuffling sketch below):

- **Training Set (70%)**: For model parameter optimization
- **Validation Set (15%)**: For hyperparameter tuning and early stopping
- **Test Set (15%)**: For final performance evaluation
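
A 70/15/15 split can be produced, for example, by shuffling sample indices as shown here; the exact splitting logic used by the module may differ.

.. code-block:: python

    import numpy as np

    num_samples = 5000
    rng = np.random.default_rng(seed=42)
    indices = rng.permutation(num_samples)

    n_train = int(0.70 * num_samples)
    n_val = int(0.15 * num_samples)

    train_idx = indices[:n_train]
    val_idx = indices[n_train:n_train + n_val]
    test_idx = indices[n_train + n_val:]
    print(len(train_idx), len(val_idx), len(test_idx))  # 3500 750 750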

Hyperparameter Optimization
~~~~~~~~~~~~~~~~~~~~~~~~~~~

The ``grid_search.py`` module enables systematic exploration of the hyperparameter space:

.. code-block:: python

    from memilio.surrogatemodel.GNN import grid_search

    # Define search space
    param_grid = {
        'layer_type': ['GCN', 'GAT', 'GIN'],
        'num_layers': [2, 3, 4, 5],
        'hidden_dim': [32, 64, 128, 256],
        'learning_rate': [0.001, 0.0005, 0.0001],
        'dropout_rate': [0.0, 0.1, 0.2, 0.3],
        'batch_size': [16, 32, 64],
        'activation': ['relu', 'elu', 'tanh']
    }

    # Run grid search with cross-validation
    results = grid_search.run_hyperparameter_search(
        param_grid=param_grid,
        data_path='gnn_training_data.pickle',
        cv_folds=5,
        metric='mae',
        save_results='grid_search_results.csv'
    )

    # Analyze best configuration
    best_config = grid_search.get_best_configuration(results)
    print(f"Best configuration: {best_config}")

Note that the grid above already spans 3 · 4 · 4 · 3 · 4 · 3 · 3 = 5,184 parameter combinations, each trained ``cv_folds`` times, so a full search is computationally demanding.
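
If needed, the saved CSV can also be inspected directly, for instance with pandas as below; the column names are assumptions about the file layout.

.. code-block:: python

    import pandas as pd

    # Load and rank grid-search results by validation MAE (column names assumed)
    results_df = pd.read_csv('grid_search_results.csv')
    top5 = results_df.sort_values('val_mae').head(5)
    print(top5[['layer_type', 'num_layers', 'hidden_dim', 'learning_rate', 'val_mae']])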

Utility Functions
~~~~~~~~~~~~~~~~~

The ``GNN_utils.py`` module provides essential helper functions used throughout the GNN workflow:

**Data Preprocessing:**

.. code-block:: python

    from memilio.surrogatemodel.GNN import GNN_utils

    # Remove confirmed compartments (simplify model)
    simplified_data = GNN_utils.remove_confirmed_compartments(
        dataset_entries=dataset,
        num_groups=6
    )

    # Apply logarithmic scaling
    scaled_data = GNN_utils.scale_data(
        data=dataset,
        method='log',
        epsilon=1e-6  # Small constant to avoid log(0)
    )

    # Load population data
    population = GNN_utils.load_population_data(
        data_dir='path/to/demographics',
        age_groups=[0, 5, 15, 35, 60, 80]
    )

**Graph Construction:**

.. code-block:: python

    # Create mobility graph from commuting data
    graph = GNN_utils.create_mobility_graph(
        mobility_dir='path/to/mobility',
        num_regions=401,           # German counties
        county_ids=county_list,
        models=models_per_region   # SECIR models for each region
    )

    # Get baseline contact matrix
    contact_matrix = GNN_utils.get_baseline_contact_matrix(
        data_dir='path/to/contact_matrices'
    )

Practical Usage Example
~~~~~~~~~~~~~~~~~~~~~~~

Here is a complete example workflow from data generation to model evaluation:

.. code-block:: python

    import pickle
    from pathlib import Path

    from memilio.surrogatemodel.GNN import (
        data_generation,
        network_architectures,
        evaluate_and_train
    )

    # Step 1: Generate training data
    print("Generating training data...")
    dataset = data_generation.generate_dataset(
        num_runs=5000,
        num_days=30,
        num_age_groups=6,
        data_dir='/path/to/memilio/data/Germany',
        mobility_dir='/path/to/mobility_data',
        save_path='gnn_dataset_5000.pickle'
    )

    # Step 2: Define and build GNN model
    print("Building GNN model...")
    model_config = {
        'layer_type': 'GCN',
        'num_layers': 4,
        'hidden_dim': 128,
        'activation': 'relu',
        'dropout_rate': 0.2,
        'use_batch_norm': True
    }

    model = network_architectures.build_gnn_model(
        config=model_config,
        input_shape=(1, 48),  # 6 age groups × 8 compartments
        output_dim=48         # Predict all compartments
    )

    # Step 3: Train the model
    print("Training model...")
    training_config = {
        'epochs': 200,
        'batch_size': 32,
        'learning_rate': 0.001,
        'optimizer': 'adam',
        'loss_function': 'mae',
        'early_stopping_patience': 20,
        'validation_split': 0.2
    }

    history = evaluate_and_train.train_gnn_model(
        model=model,
        dataset=dataset,
        config=training_config,
        save_weights='gnn_weights_best.h5'
    )

    # Step 4: Evaluate on test data
    print("Evaluating model...")
    test_metrics = evaluate_and_train.evaluate_model(
        model=model,
        test_data='gnn_test_data.pickle',
        metrics=['mae', 'mape', 'r2']
    )

    # Print results
    print(f"Test MAE: {test_metrics['mae']:.4f}")
    print(f"Test MAPE: {test_metrics['mape']:.2f}%")
    print(f"Test R²: {test_metrics['r2']:.4f}")

    # Step 5: Make predictions on new scenarios
    with open('new_scenario.pickle', 'rb') as f:
        new_data = pickle.load(f)

    predictions = model.predict(new_data)
    print(f"Predictions shape: {predictions.shape}")

**GPU Acceleration:**

- TensorFlow automatically uses a GPU when one is available (see the check below)
- Spektral layers are optimized for GPU execution
- Training time can be reduced substantially with appropriate GPU hardware
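
To verify that a GPU is visible to TensorFlow before starting a long training run, a check such as the following can be used:

.. code-block:: python

    import tensorflow as tf

    # List GPUs visible to TensorFlow; an empty list means training runs on CPU
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        print(f"Training will run on {len(gpus)} GPU(s): {[g.name for g in gpus]}")
    else:
        print("No GPU found; training will fall back to CPU.")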

Additional Resources
~~~~~~~~~~~~~~~~~~~~

**Code and Examples:**

- `GNN Module <https://github.com/SciCompMod/memilio/tree/main/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN>`_
- `GNN README <https://github.com/SciCompMod/memilio/blob/main/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/README.md>`_
- `Test Scripts <https://github.com/SciCompMod/memilio/tree/main/pycode/memilio-surrogatemodel/memilio/surrogatemodel_test>`_

**Related Documentation:**

- :doc:`ODE-SECIR Model <../models/ode_secir>`
- :doc:`MEmilio Simulation Package <m-simulation>`
- :doc:`Python Bindings <python_bindings>`