Commit 4edc0cc
Adjustments to SIRS_hybrid notebook
1 parent 95da4b9

File tree

5 files changed (+406, -225 lines)


models/SIRS/SIRS_demo.ipynb

Lines changed: 238 additions & 224 deletions
Large diffs are not rendered by default.

models/SIRS/SIRS_hybrid_experiments.ipynb

Lines changed: 1 addition & 1 deletion
@@ -179,7 +179,7 @@
   "cell_type": "markdown",
   "source": [
    "### Parameters only\n",
-   "This is the first level of hybridisation."
+   "This is the most basic level of hybridisation."
   ],
   "id": "9201ffe6ae2deacb"
  },

models/SIRS/hybrid_models/sweep.py

Lines changed: 119 additions & 0 deletions
import argparse
import numpy
import os
import pickle
import torch
import tqdm

from typing import Literal, Union

# Local imports
from parameters_only import get_params_NN
from hybrid_1 import get_hybrid_1_NN
from hybrid_2 import get_hybrid_2_NN
from black_box import get_bb_NN
from utils import epoch

def train(*,
          L: int = 30,
          n: int = 1,
          key: Literal['params', 'hybrid_1', 'hybrid_2', 'bb'] = 'params',
          seed: int = None,
          dt: Union[float, torch.Tensor] = 0.2,
          training_data: str,
          recursive: bool = False,
          out_dir: str = None,
          N_epochs: int = None
          ):
    # Set the seed, if passed
    if seed is not None:
        numpy.random.seed(seed)
        torch.manual_seed(seed)

    # Load the training data and set the initial condition
    Y = torch.load(training_data, weights_only=True)
    y0 = Y[:, 0, :]

    # Set up a dictionary with all the required data
    data = {
        'Y_target': Y[:n],
        'X_input': Y[:n, :(L+1)].flatten(start_dim=1),
        'loss': []
    }

    # Add dataset identifier if sweeping over more than one training dataset
    if n > 1:
        if recursive:
            data['Y_input'] = Y[:n, :L]
            data['z'] = Y[:n, :3, :2].flatten(start_dim=1)
        else:
            data['Y_input'] = torch.cat([
                Y[:n, :L, :], Y[:n, :3, :2].flatten(start_dim=1)[:, None, :].repeat(1, L, 1)
            ], dim=2)
    else:
        data['Y_input'] = Y[:n, :(L+1)]

    # Get the neural network
    if key == 'params':
        data['NN'] = get_params_NN(input_size=data['X_input'].shape[1], z=6 if n > 1 else 0)
    elif key == 'hybrid_1':
        data['NN'] = get_hybrid_1_NN(input_size=data['X_input'].shape[1], z=6 if n > 1 else 0)
    elif key == 'hybrid_2':
        data['NN'] = get_hybrid_2_NN(z=6 if n > 1 else 0)
    elif key == 'bb':
        data['NN'] = get_bb_NN(z=6 if n > 1 else 0)

    # Set up the output folder; if a previous run is already stored there,
    # resume from the saved network weights and loss history
    if out_dir is not None:
        path_name = f"{key}__n_{n}__L_{L}"
        if recursive:
            path_name += "__recursive"
        if seed is not None:
            path_name += f"__seed_{seed}"
        path_name = os.path.expanduser(os.path.join(out_dir, path_name))
        os.makedirs(path_name, exist_ok=True)
        if key != 'hybrid_1' and os.path.isfile(f"{path_name}/NN.pt"):
            data['NN'].load_state_dict(torch.load(f"{path_name}/NN.pt", weights_only=True))
            data['NN'].eval()
        elif key == 'hybrid_1' and os.path.isfile(f"{path_name}/NN_const_params.pt"):
            data['NN']['const_params'].load_state_dict(torch.load(f"{path_name}/NN_const_params.pt", weights_only=True))
            data['NN']['time_dep_params'].load_state_dict(torch.load(f"{path_name}/NN_time_dep_params.pt", weights_only=True))
            data['NN']['const_params'].eval()
            data['NN']['time_dep_params'].eval()
        if os.path.isfile(f"{path_name}/loss.pickle"):
            with open(f"{path_name}/loss.pickle", "rb") as f:
                data['loss'] = pickle.load(f)

    # Train for N_epochs, minus any epochs already completed in a resumed run
    if N_epochs is None:
        N_epochs = 10000 if key == 'params' else 20000

    N_epochs -= len(data['loss'])
    print(f'Remaining: {N_epochs}')
    for i in tqdm.tqdm(range(N_epochs)):
        epoch(key=key, NN=data['NN'], X_input=data['X_input'], Y_target=data['Y_target'][:, :L], Y_input=data['Y_input'],
              dt=dt, y0=y0, t_span=(0, (L-1)*dt), recursive=recursive, z=data.get('z', None), loss_array=data['loss'])

        # Store the results every 100 epochs
        if ((i > 0 and i % 100 == 0) or i == N_epochs-1) and out_dir is not None:
            if key != 'hybrid_1':
                torch.save(data['NN'].state_dict(), f"{path_name}/NN.pt")
            else:
                torch.save(data['NN']['const_params'].state_dict(), f"{path_name}/NN_const_params.pt")
                torch.save(data['NN']['time_dep_params'].state_dict(), f"{path_name}/NN_time_dep_params.pt")
            with open(f"{path_name}/loss.pickle", "wb") as f:
                pickle.dump(data['loss'], f)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--L", type=int, default=30, help="Length of training time series")
    parser.add_argument("--n", type=int, default=1, help="Number of training datasets to use")
    parser.add_argument("--key", type=str, default='params', help="Model to use")
    parser.add_argument("--training_data", type=str, help="Path to training data")
    parser.add_argument("--recursive", action="store_true", help="Whether to generate predictions recursively")
    parser.add_argument("--seed", type=int, default=None, help="Set the seed")
    parser.add_argument("--N_epochs", type=int, default=None, help="Number of training epochs")
    parser.add_argument("--out_dir", type=str, help="Output directory")
    args = parser.parse_args()

    train(key=args.key, n=args.n, L=args.L, training_data=args.training_data, recursive=args.recursive, seed=args.seed,
          N_epochs=args.N_epochs, out_dir=args.out_dir)
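
For a quick sanity check, the new training entry point can also be called directly from Python rather than through Slurm. The snippet below is not part of the commit; it is a minimal sketch in which the data path and output directory are assumptions taken from the sweep scripts further down, and it assumes the script is run from the same working directory those scripts use.

# Hypothetical single run of train() from sweep.py, outside of Slurm; paths are assumptions.
from sweep import train

train(
    key='hybrid_2',            # one of: params, hybrid_1, hybrid_2, bb
    n=1,                       # number of training datasets
    L=30,                      # length of the training time series
    training_data='data/SIRS/hybrid_training_data.pt',
    recursive=False,
    seed=0,
    out_dir='~/SIRS_sweep/',   # results land in <out_dir>/hybrid_2__n_1__L_30__seed_0/
)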
Lines changed: 24 additions & 0 deletions
#!/bin/bash
for seed in {0..9}; do
  for L in 20 30 40 50 60 70 80; do
    for key in params hybrid_1 hybrid_2 bb; do
      for recursive in true false; do
        if [ "$key" = "params" ] && [ "$recursive" = true ]; then
          continue
        fi
        if [ "$recursive" = true ]; then
          rec_flag="--recursive"
        else
          rec_flag=""
        fi
        sbatch \
          -p <partition_name> \
          -N 1 \
          --ntasks=1 \
          --output=logs/slurm-%A.out \
          --job-name="${key}__L_${L}__rec_${recursive}__seed_${seed}" \
          --wrap="python sweep.py --key ${key} --n 1 --L ${L} ${rec_flag} --training_data 'data/SIRS/hybrid_training_data.pt' --seed ${seed} --out_dir='~/SIRS_sweep/'"
      done
    done
  done
done
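
This script sweeps the training-series length L at a single training dataset (n = 1), submitting 10 seeds x 7 values of L x 7 valid (model, recursive) combinations = 490 jobs, since the params + recursive pairing is skipped. Not part of the commit, the following is a rough serial equivalent of the same grid without Slurm, assuming sweep.py is importable from the working directory and the same paths apply.

# Hypothetical serial version of the L-sweep above (no sbatch); paths are assumptions.
from sweep import train

for seed in range(10):
    for L in [20, 30, 40, 50, 60, 70, 80]:
        for key in ['params', 'hybrid_1', 'hybrid_2', 'bb']:
            for recursive in (False, True):
                if key == 'params' and recursive:
                    continue  # the parameters-only model is not run recursively in the sweep
                train(key=key, n=1, L=L, recursive=recursive, seed=seed,
                      training_data='data/SIRS/hybrid_training_data.pt',
                      out_dir='~/SIRS_sweep/')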
Lines changed: 24 additions & 0 deletions
#!/bin/bash
for seed in {0..9}; do
  for n in 2 4 6 8 10 12 14; do
    for key in params hybrid_1 hybrid_2 bb; do
      for recursive in true false; do
        if [ "$key" = "params" ] && [ "$recursive" = true ]; then
          continue
        fi
        if [ "$recursive" = true ]; then
          rec_flag="--recursive"
        else
          rec_flag=""
        fi
        sbatch \
          -p <partition_name> \
          -N 1 \
          --ntasks=1 \
          --output=logs/slurm-%A.out \
          --job-name="${key}__n_${n}__rec_${recursive}__seed_${seed}" \
          --wrap="python sweep.py --key ${key} --n ${n} --L 101 ${rec_flag} --training_data 'data/SIRS/hybrid_training_data.pt' --seed ${seed} --out_dir='~/SIRS_sweep/'"
      done
    done
  done
done
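
The second script instead sweeps the number of training datasets n at a fixed series length (L = 101), again 10 x 7 x 7 = 490 jobs. Once either sweep has run, each job's outputs sit in a folder named by the path_name convention built in sweep.py. The helper below is not part of the commit; it is a minimal sketch for collecting the saved loss histories, assuming the directory layout that sweep.py writes.

# Hypothetical helper for reading loss histories written by sweep.py; the folder
# naming mirrors the path_name construction in sweep.py above.
import os
import pickle

def load_loss(out_dir, key, n, L, seed=None, recursive=False):
    name = f"{key}__n_{n}__L_{L}"
    if recursive:
        name += "__recursive"
    if seed is not None:
        name += f"__seed_{seed}"
    with open(os.path.join(os.path.expanduser(out_dir), name, "loss.pickle"), "rb") as f:
        return pickle.load(f)  # the loss history accumulated by epoch()

# e.g. the n-sweep run for hybrid_2 with 4 training datasets and seed 0:
losses = load_loss("~/SIRS_sweep/", key="hybrid_2", n=4, L=101, seed=0)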
