Skip to content

Commit 04daf65

Browse files
Added LNCDE docstring
1 parent 999bbf7 commit 04daf65

File tree

10 files changed

+87
-44
lines changed

10 files changed

+87
-44
lines changed

.github/workflows/pre-commit.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@ jobs:
1313
uses: actions/checkout@v2
1414

1515
- name: Checks with pre-commit
16-
uses: pre-commit/action@v2.0.3
16+
uses: pre-commit/action@v3.0.0

.pre-commit-config.yaml

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,25 @@
11
repos:
22
- repo: https://github.com/ambv/black
3-
rev: 22.3.0
3+
rev: 24.3.0
44
hooks:
5-
- id: black
5+
- id: black
6+
67
- repo: https://github.com/nbQA-dev/nbQA
7-
rev: 1.2.3
8+
rev: 1.7.0
89
hooks:
9-
- id: nbqa-black
10-
- id: nbqa-isort
11-
- id: nbqa-flake8
10+
- id: nbqa-black
11+
additional_dependencies: [black, setuptools]
12+
- id: nbqa-isort
13+
additional_dependencies: [isort, setuptools]
14+
- id: nbqa-flake8
15+
additional_dependencies: [flake8, setuptools]
16+
1217
- repo: https://github.com/PyCQA/isort
13-
rev: 5.12.0
18+
rev: 5.13.2
1419
hooks:
15-
- id: isort
20+
- id: isort
21+
1622
- repo: https://github.com/pycqa/flake8
17-
rev: 4.0.1
23+
rev: 6.1.0
1824
hooks:
19-
- id: flake8
25+
- id: flake8

data_dir/dataloaders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class Dataloader:
4242
def __init__(self, data, labels, inmemory=True):
4343
self.data = data
4444
self.labels = labels
45-
if type(self.data) == tuple:
45+
if isinstance(self.data, tuple):
4646
if len(data[1][0].shape) > 2:
4747
self.data_is_coeffs = True
4848
else:

models/LRU.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,7 @@ def __init__(self, N, H, r_min=0, r_max=1, max_phase=6.28, *, key):
6464
# between r_min and r_max, with phase in [0, max_phase].
6565
u1 = jr.uniform(u1_key, shape=(N,))
6666
u2 = jr.uniform(u2_key, shape=(N,))
67-
self.nu_log = jnp.log(
68-
-0.5 * jnp.log(u1 * (r_max**2 - r_min**2) + r_min**2)
69-
)
67+
self.nu_log = jnp.log(-0.5 * jnp.log(u1 * (r_max**2 - r_min**2) + r_min**2))
7068
self.theta_log = jnp.log(max_phase * u2)
7169

7270
# Glorot initialized Input/Output projection matrices

models/LinearNeuralCDEs.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,28 @@
1+
"""
2+
This module implements the `LogLinearCDE` class using JAX and Equinox. The model is a
3+
block-diagonal Linear Controlled Differential Equation (CDE), where the output is
4+
approximated during training using the Log-ODE method.
5+
6+
Attributes of the `LogLinearCDE` model:
7+
- `init_layer`: The linear layer used to initialize the hidden state $h_0$ from the input $x_0$.
8+
- `out_layer`: The linear layer used to produce final predictions from the hidden state.
9+
- `vf_A`: Learnable parameters for the linear vector field, shaped as flattened block matrices.
10+
- `hidden_dim`: The dimension of the hidden state $h_t$.
11+
- `block_size`: Size of each square block in the block-diagonal vector field.
12+
- `num_blocks`: Number of blocks, computed as `hidden_dim // block_size`.
13+
- `parallel_steps`: Number of log-flow matrices composed in parallel (using associative scan).
14+
- `logsig_depth`: The depth of the log-signature used in the Log-ODE method.
15+
- `basis_list`: The list of basis elements of the free Lie algebra up to the specified depth.
16+
- `lambd`: Regularization parameter applied to vector field scaling.
17+
- `w_init_std`: Standard deviation for the initial weights of the vector field.
18+
- `classification`: Boolean indicating if the model is used for classification tasks.
19+
20+
The class includes:
21+
- `log_ode`: Method for computing the iterated Lie brackets of the linear vector fields.
22+
- `__call__`: Performs the forward pass, where flows are composed and applied to the hidden state
23+
either step-by-step or in parallel (using associative scan), followed by output projection.
24+
"""
25+
126
from __future__ import annotations
227

328
from typing import List, Tuple
@@ -28,10 +53,6 @@ def depth(b):
2853

2954

3055
class LogLinearCDE(eqx.Module):
31-
"""
32-
Block‑diagonal Linear Controlled Differential Equation layer.
33-
"""
34-
3556
init_layer: eqx.nn.Linear
3657
out_layer: eqx.nn.Linear
3758
vf_A: jnp.ndarray
@@ -41,10 +62,10 @@ class LogLinearCDE(eqx.Module):
4162
parallel_steps: int
4263
logsig_depth: int
4364
basis_list: List[Tuple[int, ...]]
44-
stepsize: int
4565
lambd: float
66+
w_init_std: float
67+
classification: bool
4668

47-
classification: bool = True
4869
lip2: bool = True
4970
nondeterministic: bool = False
5071
stateful: bool = False
@@ -57,10 +78,10 @@ def __init__(
5778
label_dim: int,
5879
block_size: int,
5980
logsig_depth: int,
60-
stepsize: int,
6181
lambd: float = 1.0,
6282
w_init_std: float = 0.25,
6383
parallel_steps: int = 128,
84+
classification: bool = True,
6485
key,
6586
):
6687
if hidden_dim % block_size != 0:
@@ -70,24 +91,25 @@ def __init__(
7091
self.num_blocks = hidden_dim // block_size
7192
self.parallel_steps = parallel_steps
7293
self.logsig_depth = logsig_depth
73-
self.stepsize = stepsize
7494
ctx = rp.get_context(width=data_dim, depth=self.logsig_depth, coeffs=rp.DPReal)
7595
basis = ctx.lie_basis
7696
basis_list = []
7797
for i in range(basis.size(self.logsig_depth)):
7898
basis_list.append(eval(str(basis.index_to_key(i))))
7999
self.basis_list = basis_list
80100
self.lambd = lambd
101+
self.w_init_std = w_init_std
81102

82103
k_init, k_A, k_B = jr.split(key, 3)
83104
self.init_layer = eqx.nn.Linear(data_dim, hidden_dim, key=k_init)
84105
self.out_layer = eqx.nn.Linear(hidden_dim, label_dim, key=k_B)
85106

86107
self.vf_A = (
87108
jr.normal(k_A, (data_dim + 1, self.num_blocks * block_size * block_size))
88-
* w_init_std
109+
* self.w_init_std
89110
/ jnp.sqrt(block_size)
90111
)
112+
self.classification = classification
91113

92114
def log_ode(self, vf):
93115

@@ -111,7 +133,7 @@ def log_ode(self, vf):
111133

112134
left_indices = []
113135
right_indices = []
114-
for (i_b, b) in curr_elements:
136+
for i_b, b in curr_elements:
115137
u_tuple = to_tuple(b[0])
116138
v_tuple = to_tuple(b[1])
117139
i_u = basis_index[u_tuple]
@@ -184,7 +206,12 @@ def parallel_step(y, flows):
184206
inp_rem = flows[-remainder:]
185207
_, y_rem = jax.lax.scan(step, ys[-1], inp_rem)
186208
ys = jnp.vstack([ys, y_rem])
187-
ys = jnp.mean(ys, axis=0)
188-
ys = self.out_layer(ys)
189-
preds = jax.nn.softmax(ys)
209+
210+
if self.classification:
211+
ys = jnp.mean(ys, axis=0)
212+
preds = jax.nn.softmax(self.out_layer(ys))
213+
else:
214+
ys = jax.vmap(self.out_layer)(ys)
215+
preds = jnp.tanh(ys)
216+
190217
return preds

models/generate_model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ def create_model(
7272
max_steps=16**4,
7373
scale=1.0,
7474
lambd=0.0,
75-
stepsize=1,
7675
w_init_std=0.25,
7776
*,
7877
key,
@@ -113,9 +112,9 @@ def create_model(
113112
label_dim=label_dim,
114113
block_size=block_size,
115114
logsig_depth=logsig_depth,
116-
stepsize=stepsize,
117115
lambd=lambd,
118116
w_init_std=w_init_std,
117+
classification=classification,
119118
key=key,
120119
),
121120
None,

results/analyse_results.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,12 @@
107107
for tr_idx in train_idxs:
108108
idx = idxs[tr_idx]
109109
print(
110-
f"{model[:-1]} {dataset[:-1]} {exps[idx]} {100*val_metrics[idx]}"
110+
f"{model[:-1]} {dataset[:-1]} {exps[idx]} {100 * val_metrics[idx]}"
111111
)
112112

113113
elif experiment == "repeats":
114114
test_metrics = np.array(test_metrics)
115115
print(
116116
f"{model[:-1]} {dataset[:-1]} {np.mean([len(x) for x in val_metrics])} "
117-
f"{100*np.mean(test_metrics)} {100*np.std(test_metrics)}"
117+
f"{100 * np.mean(test_metrics)} {100 * np.std(test_metrics)}"
118118
)

run_experiment.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,6 @@ def run_experiments(model_names, dataset_names, experiment_folder, pytorch_exper
177177
"stepsize_controller": diffrax.ConstantStepSize(),
178178
"scale": scale,
179179
"lambd": lambd,
180-
"stepsize": stepsize,
181180
}
182181
run_args = {
183182
"data_dir": data_dir,
@@ -218,11 +217,11 @@ def run_experiments(model_names, dataset_names, experiment_folder, pytorch_exper
218217
model_names = ["mamba", "S6"]
219218
else:
220219
model_names = [
221-
"ncde",
222-
"log_ncde",
223-
"nrde",
224-
"S5",
225-
"lru",
220+
# "ncde",
221+
# "log_ncde",
222+
# "nrde",
223+
# "S5",
224+
# "lru",
226225
"bd_linear_ncde",
227226
"dense_linear_ncde",
228227
"diagonal_linear_ncde",

simple_example.ipynb

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@
7676
"outputs": [],
7777
"source": [
7878
"# Define the NeuralCDE class\n",
79+
"\n",
80+
"\n",
7981
"class NeuralCDE(eqx.Module):\n",
8082
" vf: eqx.nn.MLP # Vector field parameterised as an MLP\n",
8183
" data_dim: int # Dimension of the input data\n",
@@ -164,6 +166,8 @@
164166
"outputs": [],
165167
"source": [
166168
"# Define the LogNeuralCDE class, which is identical to NeuralCDE, except for the get_ode method\n",
169+
"\n",
170+
"\n",
167171
"class LogNeuralCDE(NeuralCDE):\n",
168172
" stepsize: int # The interval size for the Log-ODE method\n",
169173
" depth: int # The log-signature truncation depth for the Log-ODE method\n",
@@ -307,6 +311,8 @@
307311
],
308312
"source": [
309313
"# Generate synthetic multivariate time series data using JAX\n",
314+
"\n",
315+
"\n",
310316
"def generate_multivariate_sine_wave(frequencies, length, noise_level, key):\n",
311317
" t = jnp.linspace(0, 1, length) # Generate time points\n",
312318
" num_features = len(frequencies) # Number of features in the data\n",
@@ -422,6 +428,8 @@
422428
"outputs": [],
423429
"source": [
424430
"# Define the Dataloader class\n",
431+
"\n",
432+
"\n",
425433
"class Dataloader:\n",
426434
" data: jnp.ndarray # Data array\n",
427435
" labels: jnp.ndarray # Labels array\n",
@@ -464,6 +472,8 @@
464472
"outputs": [],
465473
"source": [
466474
"# Define the classification loss function with gradient calculation\n",
475+
"\n",
476+
"\n",
467477
"@eqx.filter_value_and_grad\n",
468478
"def classification_loss(model, X, y):\n",
469479
" # Predict the output using the model\n",
@@ -492,6 +502,8 @@
492502
"outputs": [],
493503
"source": [
494504
"# Define the training function for the model\n",
505+
"\n",
506+
"\n",
495507
"def train_model(\n",
496508
" model,\n",
497509
" num_steps=2000, # Number of training steps\n",

train.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,14 @@ def regression_loss(diff_model, static_model, X, y, state, key):
102102
pred_y = pred_y[:, :, 0]
103103
norm = 0
104104
if model.lip2:
105-
for layer in model.vf.mlp.layers:
106-
norm += jnp.mean(
107-
jnp.linalg.norm(layer.weight, axis=-1)
108-
+ jnp.linalg.norm(layer.bias, axis=-1)
109-
)
110-
norm *= model.lambd
105+
if hasattr(model, "vf"):
106+
for layer in model.vf.mlp.layers:
107+
norm += jnp.mean(
108+
jnp.linalg.norm(layer.weight, axis=-1)
109+
+ jnp.linalg.norm(layer.bias, axis=-1)
110+
)
111+
else:
112+
norm += jnp.mean(jnp.linalg.norm(model.vf_A, axis=-1))
111113
return (
112114
jnp.mean(jnp.mean((pred_y - y) ** 2, axis=1)) + norm,
113115
state,

0 commit comments

Comments
 (0)