|
46 | 46 | "!wget -nc https://dataverse.harvard.edu/api/access/datafile/7609512 -O data/rat_data.pkl\n", |
47 | 47 | "\n", |
48 | 48 | "with open('data/rat_data.pkl', 'rb') as handle:\n", |
49 | | - " hippocampus_pos = pickle.load(handle)\n", |
50 | | - " \n", |
51 | | - "rat = 'achilles'\n", |
52 | | - "hippocampus_pos = hippocampus_pos[rat]" |
| 49 | + " hippocampus_pos = pickle.load(handle)\n" |
53 | 50 | ] |
54 | 51 | }, |
55 | 52 | { |
|
79 | 76 | " \n", |
80 | 77 | " return neural_train.numpy(), neural_test.numpy(), label_train.numpy(), label_test.numpy()\n", |
81 | 78 | "\n", |
82 | | - "neural_train, neural_test, label_train, label_test = split_data(hippocampus_pos, 0.2)" |
| 79 | + "\n", |
| 80 | + "neural_train, neural_test, label_train, label_test = split_data(hippocampus_pos['achilles'], 0.2)" |
83 | 81 | ] |
84 | 82 | }, |
85 | 83 | { |
|
113 | 111 | "outputs": [], |
114 | 112 | "source": [ |
115 | 113 | "max_iterations = 10000\n", |
116 | | - "output_dimension = 32 #set to 3 for embeddings and 32 for decoding\n", |
| 114 | + "output_dimension = 32 #set to 3 for embeddings and 32 for decoding\n" |
| 115 | + ] |
| 116 | + }, |
| 117 | + { |
| 118 | + "cell_type": "code", |
| 119 | + "execution_count": null, |
| 120 | + "metadata": { |
| 121 | + "scrolled": true |
| 122 | + }, |
| 123 | + "outputs": [], |
| 124 | + "source": [ |
117 | 125 | "\n", |
118 | 126 | "for rat in list(hippocampus_pos.keys()):\n", |
119 | 127 | "\n", |
|
167 | 175 | "source": [ |
168 | 176 | "for rat in list(hippocampus_pos.keys()):\n", |
169 | 177 | " # build model \n", |
170 | | - " data, labels, _ = convert_spikes_to_rates(hippocampus_pos[rat][\"neural\"], label_train, pca_n=10)\n", |
| 178 | + " data, labels, _ = convert_spikes_to_rates(hippocampus_pos[rat][\"neural\"],\n", |
| 179 | + " labels=hippocampus_pos[rat]['continuous_index'].numpy(),\n", |
| 180 | + " pca_n=10)\n", |
171 | 181 | " pickle.dump([data, labels], open(f'data/{rat}_preprocessed_data.pkl','wb'))\n", |
172 | 182 | " \n", |
173 | 183 | " # build model\n", |
|
182 | 192 | " \"include_positions\": True,\n", |
183 | 193 | " }\n", |
184 | 194 | " \n", |
185 | | - " model = MARBLE.net(data, params=params) #define model\n", |
| 195 | + " model = MARBLE.net(data, params=params.copy()) #define model\n", |
186 | 196 | " model.fit(data, outdir=f\"data/hippocampus_{rat}\") # train model\n", |
187 | 197 | " data = model.transform(data) #evaluate model" |
188 | 198 | ] |
|
212 | 222 | " \"order\": 1, # order of derivatives\n", |
213 | 223 | " \"hidden_channels\": [64], # number of internal dimensions in MLP\n", |
214 | 224 | " \"out_channels\": 32, \n", |
| 225 | + " \"inner_product_features\": False,\n", |
215 | 226 | " \"emb_norm\": True, # spherical output embedding\n", |
216 | 227 | " \"diffusion\": False,\n", |
217 | 228 | " \"include_positions\": True,\n", |
|
221 | 232 | { |
222 | 233 | "cell_type": "code", |
223 | 234 | "execution_count": null, |
224 | | - "metadata": {}, |
| 235 | + "metadata": { |
| 236 | + "scrolled": true |
| 237 | + }, |
225 | 238 | "outputs": [], |
226 | 239 | "source": [ |
227 | 240 | "rat = 'achilles'\n", |
228 | 241 | "kernel_width = 10\n", |
229 | 242 | "\n", |
230 | 243 | "for pca_n in [3,5,10,20,30]:\n", |
231 | | - " data_train, label_train_marble, pca = prepare_marble(neural_train.T, \n", |
232 | | - " label_train, \n", |
233 | | - " pca_n=pca_n, \n", |
234 | | - " kernel_width=kernel_width)\n", |
235 | 244 | "\n", |
236 | | - " model = MARBLE.net(data_train, params=params)\n", |
237 | | - " model.fit(data_train, outdir=f\"data/hippocampus_{rat}_pca{pca_n}\")" |
| 245 | + " data, _, _ = convert_spikes_to_rates(hippocampus_pos[rat][\"neural\"],\n", |
| 246 | + " pca_n=pca_n,\n", |
| 247 | + " kernel_width=kernel_width)\n", |
| 248 | + " model = MARBLE.net(data, params=params.copy())\n", |
| 249 | + " model.fit(data, outdir=f\"data/hippocampus_{rat}_pca{pca_n}\")\n" |
238 | 250 | ] |
239 | 251 | }, |
240 | 252 | { |
|
247 | 259 | { |
248 | 260 | "cell_type": "code", |
249 | 261 | "execution_count": null, |
250 | | - "metadata": {}, |
| 262 | + "metadata": { |
| 263 | + "scrolled": true |
| 264 | + }, |
251 | 265 | "outputs": [], |
252 | 266 | "source": [ |
253 | 267 | "pca_n = 20\n", |
254 | 268 | "rat = 'achilles'\n", |
255 | 269 | "\n", |
256 | 270 | "for kernel_width in [3,5,10,20,30,50,100]:\n", |
257 | | - " data_train, label_train_marble, pca = convert_spikes_to_rates(neural_train.T, \n", |
258 | | - " label_train, \n", |
259 | | - " pca_n=pca_n, \n", |
260 | | - " kernel_width=kernel_width)\n", |
261 | | - " \n", |
262 | | - " model = MARBLE.net(data_train, params=params)\n", |
263 | | - " model.fit(data_train, outdir=f\"data/hippocampus_achilles_kw{k_width}\")" |
| 271 | + " data, _, _ = convert_spikes_to_rates(hippocampus_pos[rat][\"neural\"], \n", |
| 272 | + " pca_n=pca_n,\n", |
| 273 | + " kernel_width=kernel_width)\n", |
| 274 | + " model = MARBLE.net(data, params=params.copy())\n", |
| 275 | + " model.fit(data, outdir=f\"data/hippocampus_{rat}_kw{kernel_width}\")\n", |
| 276 | + " " |
264 | 277 | ] |
265 | 278 | } |
266 | 279 | ], |
|
280 | 293 | "name": "python", |
281 | 294 | "nbconvert_exporter": "python", |
282 | 295 | "pygments_lexer": "ipython3", |
283 | | - "version": "3.10.14" |
| 296 | + "version": "3.9.18" |
284 | 297 | } |
285 | 298 | }, |
286 | 299 | "nbformat": 4, |
|