
Commit 289d119

Commit message: readme
1 parent 5447765

File tree: 2 files changed, +46 -37 lines


README.md

Lines changed: 32 additions & 37 deletions
@@ -1,10 +1,12 @@

# MARBLE - MAnifold Representation Basis LEarning

- MARBLE is an unsupervised geometric deep-learning method that can
+ MARBLE is a geometric deep learning method for finding latent representations of dynamical systems. This repo includes diverse examples spanning non-linear dynamical systems, recurrent neural networks (RNNs), and neural recordings.

- 1. Find interpretable latent representations of neural dynamics. It also applies to non-linear dynamical systems in other domains or, more generally, vector fields over manifolds.
- 2. Perform unbiased comparisons across conditions within the same animal (or dynamical system).
- 3. Compare dynamics across animals or artificial neural networks.
+ Use MARBLE for:
+ 1. **Deriving interpretable latent representations.** Useful for interpreting single-trial neural population recordings in terms of task variables. More generally, MARBLE can infer latent variables from time series observables of non-linear dynamical systems.
+ 2. **Downstream tasks.** MARBLE representations tend to be more 'unfolded', which makes them amenable to downstream tasks, e.g., decoding.
+ 3. **Dynamical comparisons.** MARBLE is an *intrinsic* method, which makes the latent representations less sensitive to the choice of observables, e.g., the recorded neurons. This allows cross-animal decoding and comparisons.

The code is built around [PyG (PyTorch Geometric)](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html).

@@ -13,13 +15,12 @@ The code is built around [PyG (PyTorch Geometric)](https://pytorch-geometric.rea

If you find this package useful or inspirational, please cite our work as follows:

```
- @misc{gosztolai2023interpretable,
-   title={Interpretable statistical representations of neural population dynamics and geometry},
+ @misc{gosztolaipeach_MARBLE_2025,
+   title={MARBLE: interpretable representations of neural population dynamics using geometric deep learning},
    author={Adam Gosztolai and Robert L. Peach and Alexis Arnaudon and Mauricio Barahona and Pierre Vandergheynst},
-   year={2023},
-   eprint={2304.03376},
-   archivePrefix={arXiv},
-   primaryClass={cs.LG}
+   year={2025},
+   doi={10.1038/s41592-024-02582-2},
+   journal={Nature Methods},
  }
```

@@ -80,17 +81,21 @@ We suggest you study at least the example of a [simple vector fields over flat s

Briefly, MARBLE takes two inputs:

- 1. `anchor` - a list of `nxd` arrays, each defining a cloud of anchor points describing the manifold
- 2. `vector` - a list of `nxD` arrays, defining a vector signal at the anchor points over the respective manifolds. For dynamical systems, D=d, but our code can also handle signals of other dimensions. Read more about [inputs](#inputs) and [different conditions](#conditions).
+ 1. `anchor` - a list of `n_c x d` arrays, where `n_c` is the number of time points and `d` is the number of features (e.g., neurons or PCA loadings) in condition `c`. Each set of `n_c` points defines a manifold through a connected graph. For example, if you record time series observables, then `n_c = len(ts_1) + len(ts_2) + ... + len(ts_k)`, where `k` is the number of time series under a given condition.
+ 2. `vector` - a list of `n_c x D` arrays, where the `n_c` rows correspond to the same time points as in `anchor` and `D` is the number of vector features defining the dynamics over the manifold. For dynamical systems, `D = d`, but our code can also handle signals of other dimensions. For time series observables, it is convenient to take each vector to be the difference between consecutive time points.
+ 
+ Read more about [inputs](#inputs) and [different conditions](#conditions).
+ 
+ **MARBLE principle: divide and conquer.** The basic principle behind MARBLE is that dynamics vary continuously with external conditions and inputs. Thus, if you have a large dataset with `c` conditions, it is wise to split the data into a list of `n_c x d` arrays, e.g., `anchor = [neural_states_condition_1, neural_states_condition_2, ...]` and `vector = [neural_state_changes_condition_1, neural_state_changes_condition_2, ...]`, rather than passing them all at once (see the sketch below). This ensures that the manifold features are correctly extracted. It yields a set of `c` submanifolds, which are combined into a joint manifold if they belong together, i.e., if they are continuations of the same system. This is possible because the learning algorithm that uses the features is unsupervised (unaware of these 'condition labels') and will find dynamical correspondences between conditions.
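
For concreteness, here is a minimal sketch of this list-based input format, using hypothetical random arrays in place of real recordings:

```
import numpy as np

# Hypothetical recordings for two conditions: n_c time points, d = 3 observables
neural_states_condition_1 = np.random.randn(200, 3)
neural_states_condition_2 = np.random.randn(150, 3)

# One entry per condition; state changes are finite differences, and the last
# anchor point of each condition is dropped so anchors and vectors align
anchor = [s[:-1] for s in (neural_states_condition_1, neural_states_condition_2)]
vector = [np.diff(s, axis=0) for s in (neural_states_condition_1, neural_states_condition_2)]
```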

- Using these inputs, you can construct a dataset for MARBLE.
+ Using these inputs, you can construct a PyTorch Geometric data object for MARBLE.

```
import MARBLE
data = MARBLE.construct_dataset(anchor=pos, vector=x)
```

- The main attributes are `data.pos` - manifold positions concatenated, `data.x` - manifold signals concatenated and `data.y` - identifiers that tell you which manifold the point belongs to. Read more about [other useful data attributes](#construct).
+ The attributes `data.pos`, `data.x`, `data.y` and `data.edge_index` will hold the anchors, vector signals, condition labels and adjacencies, respectively. See [other useful data attributes](#construct) for different preprocessing options.
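
A quick way to sanity-check the constructed object (attribute names as described above; assuming they are PyTorch tensors, as in PyG, and shapes are illustrative):

```
print(data.pos.shape)         # (total points, d): anchor positions, concatenated
print(data.x.shape)           # (total points, D): vector signals, concatenated
print(data.y.unique())        # one label per condition/manifold
print(data.edge_index.shape)  # (2, num_edges): proximity-graph adjacencies
```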

Now, you can initialise and train a MARBLE model. Read more about [training parameters](#training).

@@ -147,46 +152,34 @@ If you measure time series observables, such as neural firing rates, you can sta

If you do not directly have access to the velocities, you can approximate them as `x = np.vstack([np.diff(ts_1, axis=0), np.diff(ts_2, axis=0)])` and take `pos = np.vstack([ts_1[:-1,:], ts_2[:-1,:]])` to ensure `pos` and `x` have the same length.
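
A runnable version of this recipe, with hypothetical random arrays standing in for the measured time series:

```
import numpy as np

# Two hypothetical trials of a time series observable (e.g., firing rates)
ts_1 = np.random.randn(100, 5)
ts_2 = np.random.randn(120, 5)

# Velocities approximated by finite differences; drop each trial's last
# sample from pos so that pos and x have the same length
x = np.vstack([np.diff(ts_1, axis=0), np.diff(ts_2, axis=0)])
pos = np.vstack([ts_1[:-1, :], ts_2[:-1, :]])
assert pos.shape == x.shape
```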

- If you just want to play around with dynamical systems, why not try our (experimental) sister package [DE_library](https://github.com/agosztolai/DE_library).

<a name="conditions"></a>
### More on different conditions

Comparing dynamics in a data-driven way is equivalent to comparing the corresponding vector fields based on their respective sample sets. The dynamics to be compared might correspond to different experimental conditions (stimulation conditions, genetic perturbations, etc.) or different dynamical systems (different tasks, different brain regions).

- Suppose we have the data pairs `pos1, pos2` and `x1, x2`. Then we may concatenate them as a list to ensure that our pipeline handles them independently (on different manifolds), but embeds them jointly in the same space.
+ Suppose we have the data pairs `pos1, pos2` and `x1, x2` for two conditions. Then we may concatenate them as a list to ensure that our pipeline handles them independently (on different manifolds) but then embeds them jointly in the same space.

```
pos_list, x_list = [pos1, pos2], [x1, x2]
```

- It is sometimes useful to consider that two vector fields lie on independent manifolds (providing them as a list) even when we want to *discover* the contrary. However, when we know that two vector fields lie on the same manifold, then it can be advantageous to stack their corresponding samples (stacking them) as this will enforce geometric relationships between them through the proximity graph.

<a name="construct"></a>
### More on constructing data object

- Our pipeline is built around a Pytorch Geometric data object, which we can obtain by running the following constructor.
+ The dataset constructor can take various parameters.

```
import MARBLE
- data = MARBLE.construct_dataset(anchor=pos, vector=x, spacing=0.03, graph_type='cknn', k=15, local_gauge=False)
+ data = MARBLE.construct_dataset(anchor=pos, vector=x, spacing=0.03, delta=1.2, local_gauge=True)
```


This command will do several things.

- 1. Subsample the point cloud using farthest point sampling to achieve even sampling density. Using `spacing=0.03` means the average distance between the subsampled points will equal 3% of the manifold diameter.
- 2. Fit a nearest neighbour graph to each point cloud using the `graph_type=cknn` method using `k=15` nearest neighbours. We implemented other graph algorithms, but cknn typically works. Note that `k` should be large enough to approximate the tangent space but small enough not to connect (geodesically) distant points of the manifold. The more data you have, the higher `k` you can use.
- 3. Perform operations in local (manifold) gauges or global coordinates. Note that `local_gauge=False` should be used whenever the manifold has negligible curvature on the scale of the local feature. Setting `local_gauge=True` means that the code performs tangent space alignments before computing gradients. However, this will increase the cost of the computations $m^2$-fold, where $m$ is the manifold dimension because points will be treated as vector spaces. See the example of a [simple vector fields over curved surfaces](https://github.com/agosztolai/MARBLE/blob/main/examples/toy_examples/ex_vector_field_curved_surface.py) for illustration.
+ 1. `spacing = 0.03` means the points will be subsampled using farthest point sampling to ensure that features are not overrepresented. The average distance between the subsampled points will equal 3% of the manifold diameter.
+ 2. `number_of_resamples = 2` resamples the dataset twice, which is particularly useful when subsampling the data using `spacing`. This will effectively double the training data because a new adjacency graph is fit to each resampling.
+ 3. `delta = 1.2` is a continuous parameter that adapts the density of the graph edges to the sample density. It is the single most useful parameter for tuning MARBLE representations: a higher `delta` achieves more 'unfolded' representations, at the cost of breaking the manifold apart if `delta` is too high. It has a similar effect to the minimum distance parameter in UMAP (see the sketch after this list).
+ 4. `local_gauge=True` means that operations will be performed in local (manifold) gauges. The code will perform tangent space alignments before computing gradients. However, this increases the cost of the computations $m^2$-fold, where $m$ is the manifold dimension, because points are treated as vector spaces. See the example of [simple vector fields over curved surfaces](https://github.com/agosztolai/MARBLE/blob/main/examples/toy_examples/ex_vector_field_curved_surface.py) for an illustration.
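
Since `delta` is the main tuning knob, a small hypothetical sweep can help pick a value (a sketch reusing the constructor call above, with `pos` and `x` as defined earlier):

```
import MARBLE

# Higher delta tends to 'unfold' the representation; too high a value
# can break the manifold apart, so compare a few settings
for delta in (1.0, 1.2, 1.5):
    data = MARBLE.construct_dataset(anchor=pos, vector=x, spacing=0.03, delta=delta)
```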

- The final data object contains the following attributes (among others):
- 
- ```
- data.pos: positions `pos` concatenated across manifolds
- data.x: vectors `x` concatenated across manifolds
- data.y: labels for each point denoting which manifold it belongs to
- data.edge_index: edge list of proximity graph (each manifold gets its graph, disconnected from others)
- ```

<a name="training"></a>
### Training
@@ -198,12 +191,14 @@ You first specify the hyperparameters. The key ones are the following, which wil

```
params = {'epochs': 50,  # optimisation epochs
          'hidden_channels': 32,  # number of internal dimensions in MLP
-         'out_channels': 5,
+         'out_channels': 3,
          'inner_product_features': True,
          }
```

+ **Note:** You will want to increase `out_channels` gradually from a small number in order to ensure information compression. If you want a CEBRA-like spherical layout, set `emb_norm = True`.
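
Putting the note into practice, a hypothetical configuration (this assumes `emb_norm` is passed alongside the other training parameters, as the note suggests):

```
params = {'epochs': 50,               # optimisation epochs
          'hidden_channels': 32,      # number of internal dimensions in MLP
          'out_channels': 3,          # start small and increase gradually
          'inner_product_features': True,
          'emb_norm': True}           # CEBRA-like spherical embedding layout
```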

Then, proceed by constructing a network object

```
@@ -232,7 +227,7 @@ One of the main features of our method is the ability to run in two different mo

1. Embedding-aware mode - learn manifold embedding and dynamics
2. Embedding-agnostic mode - learn dynamics only

- To enable embedding-agnostic mode, set `inner_product_features=True` in training `params`. This engages an additional layer in the network after the computation of gradients, which makes them rotation invariant.
+ To enable embedding-agnostic mode, set `inner_product_features = True` in training `params`. This engages an additional layer in the network after the computation of gradients, which makes them rotation invariant.

At a slight cost in expressivity, this feature enables orientation- and embedding-independent representations of dynamics over manifolds. Among other things, this allows one to recognise similar dynamics across different manifolds. See the [RNN example](https://github.com/agosztolai/MARBLE/blob/main/examples/RNN/RNN.ipynb) for an illustration.

@@ -247,9 +242,9 @@ Training is successful when features are recognised to be similar across distinc

Problems with the above are possible signs that your solution will be suboptimal and will likely not generalise well. In this case, try the following (see the sketch below):
* increase training time (increase `epochs`)
- * increase your data (e.g., decrease `spacing` and construct the dataset again)
+ * increase your data (e.g., decrease `spacing` and increase `number_of_resamples`)
* decrease the number of parameters (decrease `hidden_channels`, or decrease the order, e.g., try `order=1`)
- * improve the gradient approximation (increase `k` or `delta`)
+ * improve the gradient approximation (increase `delta`)

If your data is very noisy, try enabling diffusion (`diffusion=True` in training `params`).
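
A hypothetical sketch combining these remedies (values are illustrative; parameter names as used above):

```
# Rebuild the dataset with more data and a better gradient approximation
data = MARBLE.construct_dataset(anchor=pos, vector=x,
                                spacing=0.02,           # decrease spacing: more points
                                number_of_resamples=4,  # more resamples: more training data
                                delta=1.5)              # improve gradient approximation

# Retrain with the adjusted hyperparameters
params.update({'epochs': 100,       # longer training
               'order': 1,          # fewer parameters
               'diffusion': True})  # smooth very noisy data
```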

examples/macaque_reaching/iframe_figures/figure_54.html

Lines changed: 14 additions & 0 deletions
Large diffs are not rendered by default.
