2 changes: 1 addition & 1 deletion .github/workflows/pre-commit.yml
@@ -13,4 +13,4 @@ jobs:
uses: actions/checkout@v2

- name: Checks with pre-commit
uses: pre-commit/action@v2.0.3
uses: pre-commit/action@v3.0.0
26 changes: 16 additions & 10 deletions .pre-commit-config.yaml
@@ -1,19 +1,25 @@
repos:
- repo: https://github.com/ambv/black
rev: 22.3.0
rev: 24.3.0
hooks:
- id: black
- id: black

- repo: https://github.com/nbQA-dev/nbQA
rev: 1.2.3
rev: 1.7.0
hooks:
- id: nbqa-black
- id: nbqa-isort
- id: nbqa-flake8
- id: nbqa-black
additional_dependencies: [black, setuptools]
- id: nbqa-isort
additional_dependencies: [isort, setuptools]
- id: nbqa-flake8
additional_dependencies: [flake8, setuptools]

- repo: https://github.com/PyCQA/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
- id: isort

- repo: https://github.com/pycqa/flake8
rev: 4.0.1
rev: 6.1.0
hooks:
- id: flake8
- id: flake8
30 changes: 29 additions & 1 deletion README.md
@@ -7,6 +7,33 @@ method for training NCDEs.

---

## Update – 22nd May 2025

This repository now supports **Structured Linear Controlled Differential Equations** (SLiCEs), which replace the non-linear vector fields of NCDEs and Log-NCDEs with structured linear vector fields, retaining the same maximal expressivity whilst being significantly more efficient.

SLiCEs are defined by

$$
h_t = h_0 + \int_0^t \sum_{i=1}^{d_X} A^i_{\theta}\, h_s \,\mathrm{d}X^i_s,
$$

where each $A^i_{\theta} \in \mathbb{R}^{d_h \times d_h}$ is a trainable matrix acting on the hidden state. When the $A^i_{\theta}$ are dense, this system is known as a **Linear Neural CDE (LNCDE)**, and such models are *maximally expressive* (i.e., universal); see [here](https://github.com/Benjamin-Walker/selective-ssms-and-linear-cdes). However, the computational cost and number of parameters when using dense matrices scale as $\mathcal{O}(d_h^3)$, making them impractical for large models.
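As a concrete illustration, a single explicit-Euler step of the linear CDE above can be sketched as follows. This is a minimal sketch only; the function name and shapes are illustrative and not part of this repository:

```python
import numpy as np

def lncde_euler_step(h, A, dX):
    """One Euler step of h_{k+1} = h_k + sum_i A_i h_k dX_i.

    h:  (d_h,)          hidden state
    A:  (d_X, d_h, d_h) one dense vector-field matrix per control channel
    dX: (d_X,)          increment of the control path over the step
    """
    # einsum applies each A_i to h, weights the result by dX_i, and sums over i.
    return h + np.einsum("ijk,k,i->j", A, h, dX)
```

Iterating this step over the increments of an interpolated path recovers a discretisation of the integral equation above.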

SLiCEs offer a solution: they retain the maximal expressivity **while reducing computational and memory costs** by structuring the $A^i_{\theta}$ matrices. This repository includes three SLiCE variants:
- **D-LNCDE**: diagonal matrices; the fastest variant, but with limited expressivity.
- **BD-LNCDE**: block-diagonal matrices; maximally expressive and efficient.
- **DE-LNCDE**: fully dense matrices; maximally expressive, but computationally expensive.
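The efficiency gain from the block-diagonal structure is easy to see in isolation: with block size $b$, a block-diagonal matrix-vector product costs $\mathcal{O}(d_h b)$ rather than $\mathcal{O}(d_h^2)$. A minimal sketch (names and shapes are illustrative, not this repository's API):

```python
import numpy as np

def blockdiag_matvec(blocks, h):
    """Multiply a block-diagonal matrix by h without materialising the full matrix.

    blocks: (n_blocks, b, b) diagonal blocks, with d_h = n_blocks * b
    h:      (d_h,)
    Cost is O(n_blocks * b^2) = O(d_h * b), versus O(d_h^2) for a dense matrix.
    """
    n, b, _ = blocks.shape
    # Split h into per-block chunks, apply each block, and flatten back.
    return np.einsum("nij,nj->ni", blocks, h.reshape(n, b)).reshape(-1)
```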

**In practice**: Replacing the non-linear vector field of a Log-NCDE with the block-diagonal vector field of a BD-LNCDE leads to **20× faster training** per step on the UEA multivariate time-series tasks whilst achieving the same average test accuracy. The figure below compares models on their average test accuracy, average time per 1000 training steps, and average GPU memory, which is represented by the area of each circle.

<p align="center">
<img class="center" src="./assets/time_vs_acc.png" width="800"/>
</p>

For further details and an expansive comparison with other state-of-the-art sequence models, see the [official SLiCE repository](https://github.com/Benjamin-Walker/structured-linear-cdes).

---

## Introduction

Neural controlled differential equations (NCDEs) treat time series data as observations from a control path $X_t$,
@@ -57,6 +84,7 @@ The code for preprocessing the datasets, training S5, LRU, NCDE, NRDE, and Log-N
- `optax` for neural network optimisers.
- `diffrax` for differential equation solvers.
- `signax` for calculating the signature.
- `roughpy` for calculating the Hall basis.
- `sktime` for handling time series data in ARFF format.
- `tqdm` for progress bars.
- `matplotlib` for plotting.
@@ -67,7 +95,7 @@ conda create -n Log-NCDE python=3.10
conda activate Log-NCDE
conda install pre-commit=3.7.1 sktime=0.30.1 tqdm=4.66.4 matplotlib=3.8.4 -c conda-forge
# Substitute the correct JAX pip install for your system: https://jax.readthedocs.io/en/latest/installation.html
pip install -U "jax[cuda12]" "jaxlib[cuda12]" equinox==0.11.8 optax==0.2.2 diffrax==0.6.0 signax==0.1.1
pip install -U "jax[cuda12]" "jaxlib[cuda12]" equinox==0.12.2 optax==0.2.4 diffrax==0.7.0 signax==0.1.1 roughpy==0.2.0
```

If running `data_dir/process_uea.py` throws this error: No module named 'packaging'
Binary file added assets/time_vs_acc.png
2 changes: 1 addition & 1 deletion data_dir/dataloaders.py
@@ -42,7 +42,7 @@ class Dataloader:
def __init__(self, data, labels, inmemory=True):
self.data = data
self.labels = labels
if type(self.data) == tuple:
if isinstance(self.data, tuple):
if len(data[1][0].shape) > 2:
self.data_is_coeffs = True
else:
23 changes: 23 additions & 0 deletions data_dir/datasets.py
@@ -233,6 +233,11 @@ def dataset_generator(
)


def _scale_to_minus_one_one(x, data_min, data_max, eps=1e-8):
"""Affine-map x from [data_min, data_max] to [-1, 1] with broadcasting."""
return 2.0 * (x - data_min) / (data_max - data_min + eps) - 1.0


def create_uea_dataset(
data_dir,
name,
@@ -242,6 +247,7 @@
depth,
include_time,
T,
scale=False,
*,
key,
):
@@ -294,6 +300,21 @@
)
data = jnp.concatenate([ts[:, :, None], data], axis=2)

if scale:
if use_presplit:
# stack (N,L,C) arrays along N to get all samples
all_data = jnp.concatenate([train_data, val_data, test_data], axis=0)
data_min = all_data.min(axis=(0, 1), keepdims=True)
data_max = all_data.max(axis=(0, 1), keepdims=True)

train_data = _scale_to_minus_one_one(train_data, data_min, data_max)
val_data = _scale_to_minus_one_one(val_data, data_min, data_max)
test_data = _scale_to_minus_one_one(test_data, data_min, data_max)
else:
data_min = data.min(axis=(0, 1), keepdims=True)
data_max = data.max(axis=(0, 1), keepdims=True)
data = _scale_to_minus_one_one(data, data_min, data_max)

return dataset_generator(
name,
data,
@@ -396,6 +417,7 @@ def create_dataset(
depth,
include_time,
T,
scale=False,
*,
key,
):
@@ -416,6 +438,7 @@
depth,
include_time,
T,
scale=scale,
key=key,
)
elif name[:-1] in toy_subfolders:
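The `scale=True` path added to `create_uea_dataset` normalises each channel to $[-1, 1]$ using per-channel statistics over all samples and time steps. A standalone sketch of that behaviour, mirroring `_scale_to_minus_one_one` (the array values here are illustrative):

```python
import numpy as np

def scale_to_minus_one_one(x, data_min, data_max, eps=1e-8):
    # Affine map from [data_min, data_max] to [-1, 1]; eps guards constant channels.
    return 2.0 * (x - data_min) / (data_max - data_min + eps) - 1.0

# (N, L, C) batch: per-channel min/max over samples and time, as in create_uea_dataset.
data = np.array([[[0.0, 10.0], [4.0, 20.0]]])   # one sample, two steps, two channels
lo = data.min(axis=(0, 1), keepdims=True)       # shape (1, 1, C), broadcasts over (N, L)
hi = data.max(axis=(0, 1), keepdims=True)
scaled = scale_to_minus_one_one(data, lo, hi)   # each channel now spans [-1, 1]
```

Keeping `keepdims=True` is what lets the `(1, 1, C)` statistics broadcast against the `(N, L, C)` data without any explicit reshaping.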
19 changes: 11 additions & 8 deletions data_dir/process_uea.py
@@ -83,16 +83,19 @@ def convert_all_files(data_dir):
train_file, test_file
)
data = jnp.concatenate([train_data, test_data])
orig_data_len = data.shape[0]
labels = jnp.concatenate([train_labels, test_labels])

unique_rows, indices, inverse_indices = np.unique(
data, axis=0, return_index=True, return_inverse=True
)
data = data[indices]
labels = labels[indices]
print(
f"Deleting {len(inverse_indices) - len(indices)} repeated samples in {ds_name}"
)
# keep first occurrence of each unique row
_, first_idx = np.unique(data, axis=0, return_index=True)

# restore original ordering of those first occurrences
keep_idx = np.sort(first_idx)

data = data[keep_idx]
labels = labels[keep_idx]

print(f"Deleting {orig_data_len - len(data)} repeated samples in {ds_name}")

original_idxs = (
jnp.arange(0, train_data.shape[0]),
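The rewritten de-duplication in `process_uea.py` keeps the first occurrence of each repeated sample while preserving the original sample ordering (the old code reordered samples, because `np.unique` sorts its output). The same logic, sketched standalone:

```python
import numpy as np

def drop_duplicate_rows(data, labels):
    # np.unique sorts the unique rows, so take the index of each row's first
    # occurrence and re-sort those indices to restore the original ordering.
    _, first_idx = np.unique(data, axis=0, return_index=True)
    keep_idx = np.sort(first_idx)
    return data[keep_idx], labels[keep_idx]
```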
30 changes: 30 additions & 0 deletions experiment_configs/repeats/bd_linear_ncde/EigenWorms.json
@@ -0,0 +1,30 @@
{
"seeds": [
2345,
3456,
4567,
5678,
6789
],
"data_dir": "data_dir",
"output_parent_dir": "",
"lr_scheduler": "lambda lr: lr",
"num_steps": 100000,
"print_steps": 1000,
"early_stopping_steps": 10,
"batch_size": 32,
"model_name": "linear_ncde",
"metric": "accuracy",
"classification": true,
"dataset_name": "EigenWorms",
"use_presplit": false,
"T": 1,
"scale": 1,
"time": "True",
"lr": "0.001",
"hidden_dim": "128",
"lambd": 0.001,
"block_size": 4,
"stepsize": 12,
"depth": 2
}
@@ -0,0 +1,30 @@
{
"seeds": [
2345,
3456,
4567,
5678,
6789
],
"data_dir": "data_dir",
"output_parent_dir": "",
"lr_scheduler": "lambda lr: lr",
"num_steps": 100000,
"print_steps": 1000,
"early_stopping_steps": 10,
"batch_size": 32,
"model_name": "linear_ncde",
"metric": "accuracy",
"classification": true,
"dataset_name": "EthanolConcentration",
"use_presplit": false,
"T": 1,
"scale": 1,
"time": "True",
"lr": "0.0001",
"hidden_dim": "64",
"block_size": 4,
"depth": 1,
"stepsize": 1,
"lambd": 0.000001
}
30 changes: 30 additions & 0 deletions experiment_configs/repeats/bd_linear_ncde/Heartbeat.json
@@ -0,0 +1,30 @@
{
"seeds": [
2345,
3456,
4567,
5678,
6789
],
"data_dir": "data_dir",
"output_parent_dir": "",
"lr_scheduler": "lambda lr: lr",
"num_steps": 100000,
"print_steps": 1000,
"early_stopping_steps": 10,
"batch_size": 32,
"model_name": "linear_ncde",
"metric": "accuracy",
"classification": true,
"dataset_name": "Heartbeat",
"use_presplit": false,
"T": 1,
"scale": 1,
"time": "True",
"lr": "0.001",
"hidden_dim": "16",
"block_size": 4,
"depth": 2,
"stepsize": 2,
"lambd": 0.000001
}
30 changes: 30 additions & 0 deletions experiment_configs/repeats/bd_linear_ncde/MotorImagery.json
@@ -0,0 +1,30 @@
{
"seeds": [
2345,
3456,
4567,
5678,
6789
],
"data_dir": "data_dir",
"output_parent_dir": "",
"lr_scheduler": "lambda lr: lr",
"num_steps": 100000,
"print_steps": 1000,
"early_stopping_steps": 10,
"batch_size": 32,
"model_name": "linear_ncde",
"metric": "accuracy",
"classification": true,
"dataset_name": "MotorImagery",
"use_presplit": false,
"T": 1,
"scale": 1,
"time": "False",
"lr": "0.001",
"hidden_dim": "16",
"block_size": 4,
"depth": 2,
"stepsize": 16,
"lambd": 0.001
}
30 changes: 30 additions & 0 deletions experiment_configs/repeats/bd_linear_ncde/SelfRegulationSCP1.json
@@ -0,0 +1,30 @@
{
"seeds": [
2345,
3456,
4567,
5678,
6789
],
"data_dir": "data_dir",
"output_parent_dir": "",
"lr_scheduler": "lambda lr: lr",
"num_steps": 100000,
"print_steps": 1000,
"early_stopping_steps": 10,
"batch_size": 32,
"model_name": "linear_ncde",
"metric": "accuracy",
"classification": true,
"dataset_name": "SelfRegulationSCP1",
"use_presplit": false,
"T": 1,
"scale": 1,
"time": "False",
"lr": "0.0001",
"hidden_dim": "64",
"block_size": 4,
"stepsize": 16,
"depth": 2,
"lambd": 0.0
}
30 changes: 30 additions & 0 deletions experiment_configs/repeats/bd_linear_ncde/SelfRegulationSCP2.json
@@ -0,0 +1,30 @@
{
"seeds": [
2345,
3456,
4567,
5678,
6789
],
"data_dir": "data_dir",
"output_parent_dir": "",
"lr_scheduler": "lambda lr: lr",
"num_steps": 100000,
"print_steps": 1000,
"early_stopping_steps": 10,
"batch_size": 32,
"model_name": "linear_ncde",
"metric": "accuracy",
"classification": true,
"dataset_name": "SelfRegulationSCP2",
"use_presplit": false,
"T": 1,
"scale": 1,
"time": "False",
"lr": "0.0001",
"hidden_dim": "128",
"block_size": 4,
"stepsize": 4,
"depth": 2,
"lambd": 0.001
}