Skip to content

Commit 09491ec

Browse files
authored
Merge pull request #32 from Dynamics-of-Neural-Systems-Lab/fix_rat_example
small fixes to rat example
2 parents cd3ac40 + 17056d8 commit 09491ec

File tree

2 files changed

+40
-26
lines changed

2 files changed

+40
-26
lines changed

examples/rat_hippocampus/rat_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import cebra
1515

1616
def convert_spikes_to_rates(spikes,
17-
labels,
17+
labels=None,
1818
pca=None,
1919
pca_n=10,
2020
spiking_rates=True,
@@ -42,7 +42,8 @@ def convert_spikes_to_rates(spikes,
4242

4343
rates_pca = rates_pca[:-1,:] # skip last
4444

45-
labels = labels[:rates_pca.shape[0]]
45+
if labels is not None:
46+
labels = labels[:rates_pca.shape[0]]
4647

4748
data = MARBLE.construct_dataset(
4849
anchor=rates_pca,

examples/rat_hippocampus/run_marble_and_cebra.ipynb

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,7 @@
4646
"!wget -nc https://dataverse.harvard.edu/api/access/datafile/7609512 -O data/rat_data.pkl\n",
4747
"\n",
4848
"with open('data/rat_data.pkl', 'rb') as handle:\n",
49-
" hippocampus_pos = pickle.load(handle)\n",
50-
" \n",
51-
"rat = 'achilles'\n",
52-
"hippocampus_pos = hippocampus_pos[rat]"
49+
" hippocampus_pos = pickle.load(handle)\n"
5350
]
5451
},
5552
{
@@ -79,7 +76,8 @@
7976
" \n",
8077
" return neural_train.numpy(), neural_test.numpy(), label_train.numpy(), label_test.numpy()\n",
8178
"\n",
82-
"neural_train, neural_test, label_train, label_test = split_data(hippocampus_pos, 0.2)"
79+
"\n",
80+
"neural_train, neural_test, label_train, label_test = split_data(hippocampus_pos['achilles'], 0.2)"
8381
]
8482
},
8583
{
@@ -113,7 +111,17 @@
113111
"outputs": [],
114112
"source": [
115113
"max_iterations = 10000\n",
116-
"output_dimension = 32 #set to 3 for embeddings and 32 for decoding\n",
114+
"output_dimension = 32 #set to 3 for embeddings and 32 for decoding\n"
115+
]
116+
},
117+
{
118+
"cell_type": "code",
119+
"execution_count": null,
120+
"metadata": {
121+
"scrolled": true
122+
},
123+
"outputs": [],
124+
"source": [
117125
"\n",
118126
"for rat in list(hippocampus_pos.keys()):\n",
119127
"\n",
@@ -167,7 +175,9 @@
167175
"source": [
168176
"for rat in list(hippocampus_pos.keys()):\n",
169177
" # build model \n",
170-
" data, labels, _ = convert_spikes_to_rates(hippocampus_pos[rat][\"neural\"], label_train, pca_n=10)\n",
178+
" data, labels, _ = convert_spikes_to_rates(hippocampus_pos[rat][\"neural\"],\n",
179+
" labels=hippocampus_pos[rat]['continuous_index'].numpy(),\n",
180+
" pca_n=10)\n",
171181
" pickle.dump([data, labels], open(f'data/{rat}_preprocessed_data.pkl','wb'))\n",
172182
" \n",
173183
" # build model\n",
@@ -182,7 +192,7 @@
182192
" \"include_positions\": True,\n",
183193
" }\n",
184194
" \n",
185-
" model = MARBLE.net(data, params=params) #define model\n",
195+
" model = MARBLE.net(data, params=params.copy()) #define model\n",
186196
" model.fit(data, outdir=f\"data/hippocampus_{rat}\") # train model\n",
187197
" data = model.transform(data) #evaluate model"
188198
]
@@ -212,6 +222,7 @@
212222
" \"order\": 1, # order of derivatives\n",
213223
" \"hidden_channels\": [64], # number of internal dimensions in MLP\n",
214224
" \"out_channels\": 32, \n",
225+
" \"inner_product_features\": False,\n",
215226
" \"emb_norm\": True, # spherical output embedding\n",
216227
" \"diffusion\": False,\n",
217228
" \"include_positions\": True,\n",
@@ -221,20 +232,21 @@
221232
{
222233
"cell_type": "code",
223234
"execution_count": null,
224-
"metadata": {},
235+
"metadata": {
236+
"scrolled": true
237+
},
225238
"outputs": [],
226239
"source": [
227240
"rat = 'achilles'\n",
228241
"kernel_width = 10\n",
229242
"\n",
230243
"for pca_n in [3,5,10,20,30]:\n",
231-
" data_train, label_train_marble, pca = prepare_marble(neural_train.T, \n",
232-
" label_train, \n",
233-
" pca_n=pca_n, \n",
234-
" kernel_width=kernel_width)\n",
235244
"\n",
236-
" model = MARBLE.net(data_train, params=params)\n",
237-
" model.fit(data_train, outdir=f\"data/hippocampus_{rat}_pca{pca_n}\")"
245+
" data, _, _ = convert_spikes_to_rates(hippocampus_pos[rat][\"neural\"],\n",
246+
" pca_n=pca_n,\n",
247+
" kernel_width=kernel_width)\n",
248+
" model = MARBLE.net(data, params=params.copy())\n",
249+
" model.fit(data, outdir=f\"data/hippocampus_{rat}_pca{pca_n}\")\n"
238250
]
239251
},
240252
{
@@ -247,20 +259,21 @@
247259
{
248260
"cell_type": "code",
249261
"execution_count": null,
250-
"metadata": {},
262+
"metadata": {
263+
"scrolled": true
264+
},
251265
"outputs": [],
252266
"source": [
253267
"pca_n = 20\n",
254268
"rat = 'achilles'\n",
255269
"\n",
256270
"for kernel_width in [3,5,10,20,30,50,100]:\n",
257-
" data_train, label_train_marble, pca = convert_spikes_to_rates(neural_train.T, \n",
258-
" label_train, \n",
259-
" pca_n=pca_n, \n",
260-
" kernel_width=kernel_width)\n",
261-
" \n",
262-
" model = MARBLE.net(data_train, params=params)\n",
263-
" model.fit(data_train, outdir=f\"data/hippocampus_achilles_kw{k_width}\")"
271+
" data, _, _ = convert_spikes_to_rates(hippocampus_pos[rat][\"neural\"], \n",
272+
" pca_n=pca_n,\n",
273+
" kernel_width=kernel_width)\n",
274+
" model = MARBLE.net(data, params=params.copy())\n",
275+
" model.fit(data, outdir=f\"data/hippocampus_{rat}_kw{kernel_width}\")\n",
276+
" "
264277
]
265278
}
266279
],
@@ -280,7 +293,7 @@
280293
"name": "python",
281294
"nbconvert_exporter": "python",
282295
"pygments_lexer": "ipython3",
283-
"version": "3.10.14"
296+
"version": "3.9.18"
284297
}
285298
},
286299
"nbformat": 4,

0 commit comments

Comments
 (0)