
Commit 58399fa

Remove RNN from SIRS_demo notebook

1 parent a2e91af

File tree: 11 files changed, +571 -258 lines

include/graph.py

Lines changed: 3 additions & 1 deletion
@@ -1,6 +1,7 @@
 import h5py as h5
 import networkx as nx
 import numpy as np
+from typing import Literal

 """ Network generation function """

@@ -9,7 +10,7 @@ def generate_graph(
     *,
     N: int,
     mean_degree: int = None,
-    type: str,
+    type: Literal["random", "BarabasiAlbert", "BollobasRiordan", "WattsStrogatz", "Star", "Regular"],
     seed: int = None,
     graph_props: dict = None,
 ) -> nx.Graph:
@@ -22,6 +23,7 @@ def generate_graph(
     :param seed: the random seed to use for the graph generation (ensuring the graphs are always the same)
     :param graph_props: dictionary containing the type-specific parameters
     :return: the networkx graph object. All graphs are fully connected
+    TODO: graph can have isolated components!
     """

     def _connect_isolates(G: nx.Graph) -> nx.Graph:
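
For context, the new Literal annotation restricts the `type` argument to the six supported graph families, so a static type checker can flag a misspelled name before runtime. A hypothetical call, assuming the repository root is on sys.path and using an illustrative `graph_props` key (the real type-specific keys are defined in include/graph.py), might look like this:

# Hypothetical usage sketch; the graph_props key shown is an assumption,
# since the type-specific parameters are documented in include/graph.py.
from include.graph import generate_graph

G = generate_graph(
    N=100,
    mean_degree=4,
    type="WattsStrogatz",            # must now be one of the Literal options
    seed=42,
    graph_props={"p_rewire": 0.2},   # illustrative key only
)
print(G.number_of_nodes(), G.number_of_edges())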

models/Kuramoto/Kuramoto_demo.ipynb

Lines changed: 181 additions & 0 deletions
Large diffs are not rendered by default.

models/Kuramoto/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from model import Kuramoto_euler, Kuramoto_dopri5, Kuramoto_rk4_adj

models/Kuramoto/ensemble_training/NN.py

Lines changed: 3 additions & 2 deletions
@@ -7,11 +7,12 @@
 import torch
 from dantro import logging
 from dantro._import_tools import import_module_from_path
+from typing import Any

 sys.path.append(up(up(__file__)))
 sys.path.append(up(up(up(__file__))))

-Kuramoto = import_module_from_path(mod_path=up(up(__file__)), mod_str="Kuramoto")
+Kuramoto = import_module_from_path(mod_path=up(__file__), mod_str="ensemble_training")
 base = import_module_from_path(mod_path=up(up(up(__file__))), mod_str="include")

 log = logging.getLogger(__name__)
@@ -23,7 +24,7 @@ def __init__(
         *,
         rng: np.random.Generator,
         output_data_group: h5.Group,
-        neural_net: base.NeuralNet,
+        neural_net: Any,
         loss_function: dict,
         ABM: Kuramoto.Kuramoto_ABM,
         true_network: torch.Tensor = None,

models/Kuramoto/ensemble_training/run.py

Lines changed: 3 additions & 5 deletions
@@ -11,10 +11,8 @@
 from dantro import logging
 from dantro._import_tools import import_module_from_path

-sys.path.append(up(up(__file__)))
-sys.path.append(up(up(up(__file__))))
-
-Kuramoto = import_module_from_path(mod_path=up(up(__file__)), mod_str="Kuramoto")
+sys.path.extend([up(up(__file__)), up(up(up(__file__)))])
+Kuramoto = import_module_from_path(mod_path=up(__file__), mod_str="ensemble_training")
 base = import_module_from_path(mod_path=up(up(up(__file__))), mod_str="include")

 log = logging.getLogger(__name__)
@@ -88,7 +86,7 @@
     f" Initializing the neural net; input size: {num_agents}, output size: {output_size} ..."
 )

-net = base.NeuralNet(
+net = base.FeedForwardNN(
     input_size=num_agents, output_size=output_size, **model_cfg["NeuralNet"]
 ).to(device)


models/Kuramoto/model.py

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
+"""
+This file contains numerical solvers for the Kuramoto model. You can add your own solvers here (write your own decorator
+and add it to include.solvers if you wish).
+"""
+from typing import Union
+import torch
+
+# Import the solver module (located in `include`)
+import sys
+from os.path import dirname as up
+from dantro._import_tools import import_module_from_path
+sys.path.append(up(up(__file__)))
+include = import_module_from_path(mod_path=up(up(__file__)), mod_str="include")
+from include.solvers import torchdiffeq_solver
+
+
+def Kuramoto_rhs(
+    t,
+    state,
+    *,
+    adjacency_matrix: torch.Tensor,
+    eigen_frequencies: torch.Tensor,
+    kappa: Union[torch.Tensor, float],
+    beta: Union[torch.Tensor, float],
+    alpha: Union[torch.Tensor, float] = 0.0,
+):
+    """
+    Compute the right-hand side (RHS) of the Kuramoto model.
+
+    Parameters
+    ----------
+    t : float or torch.Tensor
+        Current time (unused in the autonomous Kuramoto equations, but kept
+        for compatibility with generic ODE solver decorators).
+    state : torch.Tensor
+        Current system state.
+        - If `alpha == 0` (first-order dynamics): shape (N, 1),
+          containing oscillator phases θ_i.
+        - If `alpha != 0` (second-order dynamics): shape (N, 2),
+          with [:, 0] = phases θ_i and [:, 1] = velocities dθ_i/dt.
+    adjacency_matrix : torch.Tensor, shape (N, N)
+        Coupling matrix describing network connections.
+    eigen_frequencies : torch.Tensor, shape (N, 1)
+        Natural frequencies ω_i of the oscillators.
+    kappa : float
+        Global coupling strength.
+    alpha : float
+        Inertia parameter. If 0, reduces to the standard first-order
+        Kuramoto model. Otherwise, adds a second-order (inertial) term.
+    beta : float
+        Damping parameter. Used in both first- and second-order cases.
+    sigma : float
+        Noise strength (unused here — stochastic increments should be
+        added at the solver level, not inside the deterministic RHS).
+    device : torch.device
+        Torch device for tensor operations.
+
+    Returns
+    -------
+    torch.Tensor
+        Derivatives of the state with the same shape as `state`:
+        - First-order (alpha == 0): shape (N, 1), dθ/dt
+        - Second-order (alpha != 0): shape (N, 2),
+          [dθ/dt, d²θ/dt²]
+    """
+    phases = state[:, 0]
+    if alpha != 0:  # second-order case
+        velocities = state[:, 1]
+    else:
+        velocities = None
+
+    # Pairwise phase differences
+    diffs = torch.sin(phases - phases.reshape((len(phases),)))
+
+    # Coupling contribution
+    coupling = torch.matmul(kappa * adjacency_matrix, diffs).diag()
+
+    if alpha == 0:
+        # First-order Kuramoto
+        dtheta = (eigen_frequencies.squeeze() + coupling) / beta
+        return torch.stack([dtheta], dim=1)
+    else:
+        # Second-order Kuramoto
+        dtheta = velocities
+        dvel = (eigen_frequencies.squeeze() + coupling - beta * velocities) / alpha
+        return torch.stack([dtheta, dvel], dim=1)
+
+# Euler solver
+@torchdiffeq_solver(method="euler", adjoint=False)
+def Kuramoto_euler(t, state, **params):
+    return Kuramoto_rhs(t, state, **params)
+
+
+# Dopri5 solver
+@torchdiffeq_solver(method="dopri5", adjoint=False)
+def Kuramoto_dopri5(t, state, **params):
+    return Kuramoto_rhs(t, state, **params)
+
+
+# Runge-Kutta 4th order solver
+@torchdiffeq_solver(method="rk4", adjoint=False)
+def Kuramoto_rk4_adj(t, state, **params):
+    return Kuramoto_rhs(t, state, **params)
+
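
For orientation, the dynamics encoded by Kuramoto_rhs correspond to the following equations (written in the standard Kuramoto sign convention, which is assumed here):

    \beta \,\dot{\theta}_i = \omega_i + \kappa \sum_j A_{ij} \sin(\theta_j - \theta_i)                                  (first-order, alpha == 0)
    \alpha \,\ddot{\theta}_i + \beta \,\dot{\theta}_i = \omega_i + \kappa \sum_j A_{ij} \sin(\theta_j - \theta_i)        (second-order, alpha != 0)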

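The torchdiffeq_solver decorator itself lives in include/solvers.py and is not part of this commit. A minimal sketch of how such a decorator could wrap torchdiffeq.odeint is given below; the wrapper name `solve`, its argument names, and the handling of the `adjoint` flag are assumptions for illustration, not the repository's actual implementation.

# Minimal sketch, assuming the decorator wraps torchdiffeq.odeint.
# The real include/solvers.py implementation may differ in signature and features.
import functools
import torch
from torchdiffeq import odeint  # pip install torchdiffeq


def torchdiffeq_solver(method: str, adjoint: bool = False):
    """Turn an RHS f(t, state, **params) into a solver over a time grid t."""
    # The adjoint flag is accepted for parity with the decorated solvers above,
    # but this sketch only covers the non-adjoint path (all of them pass adjoint=False).
    def decorator(rhs):
        @functools.wraps(rhs)
        def solve(initial_state: torch.Tensor, t: torch.Tensor, **params):
            # torchdiffeq expects a callable f(t, y); close over the model parameters
            return odeint(
                lambda t_, y: rhs(t_, y, **params), initial_state, t, method=method
            )
        return solve
    return decorator

# With this sketch, a decorated solver would be called roughly as
#   phases = Kuramoto_euler(initial_state, t, adjacency_matrix=A,
#                           eigen_frequencies=omega, kappa=1.0, beta=1.0)
# returning the integrated states at the requested time points.
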
models/Neurotransmission/model.py

Lines changed: 1 addition & 0 deletions
@@ -79,6 +79,7 @@ def initial_condition(*, n_v: torch.Tensor,

     return init_state

+# TODO: use solvers from solvers.py
 def solve_ODE(*, init_state: torch.Tensor,
               k_R: torch.Tensor,
               k_F: torch.Tensor,

models/SIRS/SIRS_demo.ipynb

Lines changed: 160 additions & 244 deletions
Large diffs are not rendered by default.

models/SIRS/ensemble_training/SIRS_cfg.yml

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ Data:
   parameters:
     lower: 0.0
     upper: 0.3
-    noise: 0.03 # Multiplicative noise variance
+    noise: 0.025 # Multiplicative noise variance

   # Time range; together with dt this determines the number of time steps
   t_span: [0, 50]

models/SIRS/ensemble_training/cfgs/test_performance/eval.yml

Lines changed: 111 additions & 2 deletions
@@ -17,8 +17,10 @@ loss:
   tag: data
   x: n_train
   y: mean
+  col: kind
   yerr: std
-  hue: kind
+  hue: noise
+  sharey: False

 # True and predicted data
 data:
@@ -47,4 +49,111 @@ data:
   row: n_train
   hue: Compartment
   scatter_kwargs:
-    s: 0.1
+    s: 0.1
+
+SIRS_accuracy_2:
+  add_legend: false
+  col: kind
+  creator: multiverse
+  dag_options:
+    define:
+      _page_width: 7.5
+      c_darkblue: '#2F7194'
+      c_darkgreen: '#48675A'
+      c_darkgrey: '#3D4244'
+      c_lightblue: '#97c3d0'
+      c_lightbrown: '#C6BFA2'
+      c_lightgreen: '#AFD8BC'
+      c_lightgrey: '#AFC1B9'
+      c_orange: '#EC9F7E'
+      c_pink: '#F8A6A6'
+      c_purple: '#A07CB0'
+      c_red: '#ec7070'
+      c_yellow: '#F5DDA9'
+      fifth_width:
+        - div: [!dag_tag _page_width, 5]
+      full_width: !dag_tag _page_width
+      half_width:
+        - div: [!dag_tag _page_width, 2]
+      quarter_width:
+        - div: [!dag_tag _page_width, 4]
+      third_width:
+        - div: [!dag_tag _page_width, 3]
+      two_thirds_width:
+        - div: [!dag_tag _page_width, 3]
+        - mul: [!dag_node -1, 2]
+    meta_operations:
+      .isel_with_drop:
+        - .isel: [!arg 0, !arg 1]
+          kwargs: {drop: true}
+      .sel_with_drop:
+        - .sel: [!arg 0, !arg 1]
+          kwargs: {drop: true}
+      neg_exp:
+        - mul: [!arg 0, -1]
+        - np.exp: [!dag_node -1]
+  dag_visualization: {enabled: false}
+  figsize: [6.25598, 1.56]
+  file_ext: pdf
+  helpers:
+    axis_specific:
+      0:
+        axis: [1, 0]
+        set_legend: {framealpha: 0, handlelength: 1, ncols: 2, use_legend: true}
+        set_scales: {y: linear}
+    save_figure: {bbox_inches: tight, dpi: 900, pad_inches: 0, transparent: true}
+    set_labels: {x: Number of training sets, y: ''}
+    set_title: {title: null}
+    setup_figure: {ncols: 2}
+  hue: noise
+  kind: errorbars
+  module: dantro.plot.funcs.generic
+  plot_func: facet_grid
+  select_and_combine:
+    base_path: data/SIRS
+    fields:
+      loss:
+        path: loss
+        transform:
+          - .isel_with_drop:
+              - !dag_node -1
+              - {epoch: -1}
+  sharey: false
+  style: {axes.grid: true, axes.labelsize: 7.5, axes.prop_cycle: 'cycler(''color'',
+    [''#AFD8BC'', ''#ec7070'', ''#2F7194'', ''#48675A'', ''#C6BFA2'', ''#EC9F7E'',
+    ''#F5DDA9'', ''#3D4244'', ''#F8A6A6'', ''#A07CB0'', ''#AFC1B9'' ])', axes.spines.right: false,
+    axes.spines.top: false, axes.titlesize: 7.5, base_style: null, figure.dpi: 254,
+    font.family: Arial, font.size: 7.5, grid.alpha: 0.5, grid.linewidth: 0.5, legend.fontsize: 7.5,
+    lines.linewidth: 1.2, mathtext.fontset: cm, savefig.bbox: tight, text.latex.preamble: '\usepackage{amssymb}
+    \usepackage{amsmath}', text.usetex: false, xtick.labelsize: 7.5, ytick.labelsize: 7.5}
+  transform:
+    - .mean: [!dag_tag loss, seed]
+    - .std: [!dag_tag loss, seed]
+    - mul: [!dag_prev , 1]
+    - xr.Dataset:
+        - {mean: !dag_node -3, std: !dag_node -1}
+    - .assign_coords:
+        - !dag_node -1
+        - noise: [$\sigma=0$, $\sigma=0.01$, $\sigma=0.025$, $\sigma=0.05$]
+  tag: data
+  use_bands: true
+  x: n_train
+  y: mean
+  yerr: std
+
+overfitting:
+  based_on:
+    - .creator.multiverse
+    - .plot.facet_grid.line
+  select_and_combine:
+    fields:
+      data: loss
+    subspace:
+      noise: 0.01
+      seed: 0
+  col: kind
+  hue: n_train
+  sharey: False
+  helpers:
+    set_scales:
+      y: log

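The transform block of SIRS_accuracy_2 above averages the loss over the seed dimension, computes its standard deviation, bundles both into a dataset, and attaches readable labels to the noise coordinate. A rough Python/xarray equivalent of that DAG pipeline, assuming `loss` is an xr.DataArray with 'seed', 'noise', and 'n_train' dimensions, would be:

# Rough xarray equivalent of the SIRS_accuracy_2 transform pipeline above;
# `loss` is assumed to be an xr.DataArray with 'seed', 'noise' and 'n_train' dims.
import xarray as xr


def accuracy_dataset(loss: xr.DataArray) -> xr.Dataset:
    mean = loss.mean("seed")   # .mean: [!dag_tag loss, seed]
    std = loss.std("seed")     # .std:  [!dag_tag loss, seed]
    # The `mul: [!dag_prev , 1]` step multiplies by 1 and is a no-op, so it is omitted here.
    data = xr.Dataset({"mean": mean, "std": std})
    # .assign_coords: attach LaTeX labels to the noise sweep values
    return data.assign_coords(
        noise=[r"$\sigma=0$", r"$\sigma=0.01$", r"$\sigma=0.025$", r"$\sigma=0.05$"]
    )
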