# MARBLE - MAnifold Representation Basis LEarning

MARBLE is a geometric deep-learning method for finding latent representations of dynamical systems. This repository includes diverse examples, including non-linear dynamical systems, recurrent neural networks (RNNs), and neural recordings.

Use MARBLE for:

1. **Deriving interpretable latent representations.** Useful for interpreting single-trial neural population recordings in terms of task variables. More generally, MARBLE can infer latent variables from time series observables in non-linear dynamical systems.
2. **Downstream tasks.** MARBLE representations tend to be more 'unfolded', which makes them amenable to downstream tasks, e.g., decoding.
3. **Dynamical comparisons.** MARBLE is an *intrinsic* method, which makes the latent representations less sensitive to the choice of observables, e.g., recorded neurons. This allows cross-animal decoding and comparisons.

The code is built around [PyG (PyTorch Geometric)](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html).

If you find this package useful or inspirational, please cite our work as follows:

```
@article{gosztolaipeach_MARBLE_2025,
  title={MARBLE: interpretable representations of neural population dynamics using geometric deep learning},
  author={Adam Gosztolai and Robert L. Peach and Alexis Arnaudon and Mauricio Barahona and Pierre Vandergheynst},
  journal={Nature Methods},
  year={2025},
  doi={10.1038/s41592-024-02582-2},
}
```

Briefly, MARBLE takes two inputs:

1. `anchor` - a list of `n_c x d` arrays, where `n_c` is the number of time points and `d` is the number of features (e.g., neurons or PCA loadings) in condition `c`. Each set of `n_c` points defines a manifold through a connected graph. For example, if you record `m` time series under condition `c`, then `n_c = len(ts_1) + len(ts_2) + ... + len(ts_m)`.
2. `vector` - a list of `n_c x D` arrays over the same time points as `anchor`, whose `D` columns are vector features defining the dynamics over the manifold. For dynamical systems, `D = d`, but our code can also handle signals of other dimensions. For time series observables, it is convenient to take each vector to be the difference between consecutive time points.

Read more about [inputs](#inputs) and [different conditions](#conditions).

**MARBLE principle: divide and conquer.** The basic principle behind MARBLE is that dynamics vary continuously with respect to external conditions and inputs. Thus, if you have a large dataset with `c` conditions, it is wise to split the data into a list of `n x d` arrays, e.g., `anchor = [neural_states_condition_1, neural_states_condition_2, ...]` and `vector = [neural_state_changes_condition_1, neural_state_changes_condition_2, ...]`, rather than passing everything at once (see the sketch below). This ensures that the manifold features are correctly extracted and yields a set of `c` submanifolds, which are combined into a joint manifold if they belong together, i.e., if they are continuations of the same system. This is possible because the learning algorithm is unsupervised (unaware of the condition labels) and will find dynamical correspondences between conditions.
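
As a minimal sketch (with hypothetical array names and dimensions), suppose each condition was recorded as a few time series of shape `T_i x d`:

```
import numpy as np

# hypothetical recordings: two conditions, each a list of (T_i x d) time series
condition_1 = [np.random.randn(100, 10), np.random.randn(80, 10)]
condition_2 = [np.random.randn(120, 10)]

anchor, vector = [], []
for trials in [condition_1, condition_2]:
    # anchor points: all time points except the last of each series
    anchor.append(np.vstack([ts[:-1] for ts in trials]))
    # vectors: differences between consecutive time points approximate the dynamics
    vector.append(np.vstack([np.diff(ts, axis=0) for ts in trials]))
```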

Using these inputs, you can construct a PyTorch Geometric data object for MARBLE.

```
import MARBLE
data = MARBLE.construct_dataset(anchor=pos, vector=x)
```

The attributes `data.pos`, `data.x`, `data.y`, and `data.edge_index` will hold the anchors, vector signals, condition labels, and adjacencies, respectively. See [other useful data attributes](#construct) for different preprocessing options.
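
To sanity-check the constructed object, you can inspect these attributes directly (a quick sketch; the shapes depend on your data):

```
print(data.pos.shape) # [N, d] anchor points, concatenated across conditions
print(data.x.shape) # [N, D] vector signals at the anchor points
print(data.y.unique()) # one identifier per condition/manifold
print(data.edge_index.shape) # [2, E] proximity-graph edges in PyG format
```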

Now, you can initialise and train a MARBLE model. Read more about [training parameters](#training).

If you do not directly have access to the velocities, you can approximate them as `x = np.vstack([np.diff(ts_1, axis=0), np.diff(ts_2, axis=0)])` and take `pos = np.vstack([ts_1[:-1,:], ts_2[:-1,:]])` to ensure `pos` and `x` have the same length.

<a name="conditions"></a>
### More on different conditions

Comparing dynamics in a data-driven way is equivalent to comparing the corresponding vector fields based on their respective sample sets. The dynamics to be compared might correspond to different experimental conditions (stimulation conditions, genetic perturbations, etc.) or different dynamical systems (different tasks, different brain regions).

Suppose we have the data pairs `pos1, pos2` and `x1, x2` for two conditions. Then we may concatenate them as a list to ensure that our pipeline handles them independently (on different manifolds) but embeds them jointly in the same space.

```
pos_list, x_list = [pos1, pos2], [x1, x2]
```
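
Passing the list to the constructor shown earlier, e.g., `data = MARBLE.construct_dataset(anchor=pos_list, vector=x_list)`, then yields one submanifold per condition within a single data object.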

<a name="construct"></a>
### More on constructing data object

The dataset constructor accepts various parameters.

```
import MARBLE
data = MARBLE.construct_dataset(anchor=pos, vector=x, spacing=0.03, number_of_resamples=2, delta=1.2, local_gauge=True)
```

This command will do several things:

1. `spacing = 0.03` means the points will be subsampled using farthest point sampling to ensure that features are not overrepresented. The average distance between the subsampled points will equal 3% of the manifold diameter.
2. `number_of_resamples = 2` resamples the dataset twice, which is particularly useful when subsampling the data using `spacing`. This effectively doubles the training data because a new proximity graph is fitted for each resample.
3. `delta = 1.2` is a continuous parameter that adapts the density of the graph edges to the sample density. It is the single most useful parameter for tuning MARBLE representations: a higher `delta` yields more 'unfolded' representations, at the cost of breaking the manifold apart when `delta` is too high. It has a similar effect to the minimum distance parameter in UMAP.
4. `local_gauge=True` means that operations will be performed in local (manifold) gauges. The code will perform tangent space alignments before computing gradients. However, this will increase the cost of the computations $m^2$-fold, where $m$ is the manifold dimension, because points will be treated as vector spaces. See the example of [simple vector fields over curved surfaces](https://github.com/agosztolai/MARBLE/blob/main/examples/toy_examples/ex_vector_field_curved_surface.py) for an illustration.

<a name="training"></a>
### Training

You first specify the hyperparameters. The key ones are the following:

```
params = {'epochs': 50, # optimisation epochs
          'hidden_channels': 32, # number of internal dimensions in MLP
          'out_channels': 3,
          'inner_product_features': True,
          }
```

**Note:** You will want to gradually increase `out_channels` from a small number in order to ensure information compression. If you want a CEBRA-like spherical layout, set `emb_norm=True` in training `params`.

Then, proceed by constructing a network object:

```
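# A sketch based on the usage in this repository's examples; check your MARBLE
# version for the exact constructor name and defaults.
model = MARBLE.net(data, params=params)
```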

One of the main features of our method is the ability to run in two different modes:

1. Embedding-aware mode - learn manifold embedding and dynamics
2. Embedding-agnostic mode - learn dynamics only

To enable embedding-agnostic mode, set `inner_product_features = True` in training `params`. This engages an additional layer in the network after the computation of gradients, which makes them rotation invariant.

At a slight cost in expressivity, this feature enables orientation- and embedding-independent representations of dynamics over the manifolds. Among other things, this allows one to recognise similar dynamics across different manifolds. See the [RNN example](https://github.com/agosztolai/MARBLE/blob/main/examples/RNN/RNN.ipynb) for an illustration.

Training is successful when features are recognised to be similar across distinct conditions. Problems with the above would be possible signs that your solution is suboptimal and will likely not generalise well. In this case, try the following:
* increase training time (increase `epochs`)
* increase your data (e.g., decrease `spacing` and increase `number_of_resamples`)
* decrease the number of parameters (decrease `hidden_channels`, or decrease the order; try `order=1`)
* improve the gradient approximation (increase `delta`)

If your data is very noisy, try enabling diffusion (`diffusion=True` in training `params`).
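
For example, a sketch reusing the `params` dictionary from above:

```
params = {'epochs': 50, # optimisation epochs
          'hidden_channels': 32, # number of internal dimensions in MLP
          'out_channels': 3,
          'inner_product_features': True,
          'diffusion': True, # enable diffusion to denoise very noisy data
          }
```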