|
441 | 441 | "Env = Environment(params={\"dimensionality\": \"1D\", \"boundary_conditions\": \"periodic\"})\n", |
442 | 442 | "\n", |
443 | 443 | "# Put agent (who will move randomly under the ratinabox Ornstein Uhlenbeck random motion policy) inside the environement\n", |
444 | | - "Ag = Agent(Env, params={'dt':0.01})\n", |
| 444 | + "Ag = Agent(Env, params={'dt':0.02})\n", |
445 | 445 | "Ag.speed_mean = 0\n", |
446 | 446 | "Ag.speed_std = 0.3\n", |
447 | 447 | "\n", |
|
461 | 461 | " params={\n", |
462 | 462 | " \"n\": n_cells,\n", |
463 | 463 | " \"name\": \"ConjunctiveCells_left\",\n", |
| 464 | + " # nb. this tutorial is now quite old so the way that FeedForwardLayer --- define in the main codebase --- activations are set (passing \"activation_function\" at initialisation) no longer matches the way DendriticCompartment --- defined above --- activations are set (setting \"activation_params\" after initialisation). Sorry about this! TODO: update DendriticCompartment to be FeedForwardLayer subclass\n", |
| 465 | + " \"activation_function\": {\n", |
| 466 | + " \"activation\": \"relu\",\n", |
| 467 | + " \"threshold\": 1,\n", |
| 468 | + " }\n", |
464 | 469 | " },\n", |
465 | 470 | ")\n", |
466 | 471 | "\n", |
|
469 | 474 | " params={\n", |
470 | 475 | " \"n\": n_cells,\n", |
471 | 476 | " \"name\": \"ConjunctiveCells_right\",\n", |
| 477 | + " \"activation_function\": {\n", |
| 478 | + " \"activation\": \"relu\",\n", |
| 479 | + " \"threshold\": 1,\n", |
| 480 | + " }\n", |
472 | 481 | " },\n", |
473 | 482 | ")\n", |
474 | 483 | "\n", |
|
497 | 506 | " [-1, 1]\n", |
498 | 507 | ") # thus right velocity excites these cells and rigleftht velocity shuts them off\n", |
499 | 508 | "ConjunctiveCells_left.inputs[\"RingAttractor\"][\"w\"] = np.identity(n_cells)\n", |
500 | | - "ConjunctiveCells_right.inputs[\"RingAttractor\"][\"w\"] = np.identity(n_cells)\n", |
501 | | - "ConjunctiveCells_left.activation_params = {\n", |
502 | | - " \"activation\": \"relu\",\n", |
503 | | - " \"threshold\": 1,\n", |
504 | | - " \"width_x\": 2,\n", |
505 | | - "}\n", |
506 | | - "ConjunctiveCells_right.activation_params = {\n", |
507 | | - " \"activation\": \"relu\",\n", |
508 | | - " \"threshold\": 1,\n", |
509 | | - " \"width_x\": 2,\n", |
510 | | - "}" |
| 509 | + "ConjunctiveCells_right.inputs[\"RingAttractor\"][\"w\"] = np.identity(n_cells)" |
511 | 510 | ] |
512 | 511 | }, |
513 | 512 | { |
|
516 | 515 | "source": [ |
517 | 516 | "### Train the network\n", |
518 | 517 | "\n", |
519 | | - "Train it for 20 minutes" |
| 518 | + "Train it for 60 minutes" |
520 | 519 | ] |
521 | 520 | }, |
522 | 521 | { |
|
525 | 524 | "metadata": {}, |
526 | 525 | "outputs": [], |
527 | 526 | "source": [ |
528 | | - "for i in tqdm(range(int(10 * 60 / Ag.dt))):\n", |
| 527 | + "for i in tqdm(range(int(60 * 60 / Ag.dt))):\n", |
529 | 528 | " # update agent\n", |
530 | 529 | " Ag.update()\n", |
531 | 530 | " # update firing rates of all the cell layers\n", |
|
556 | 555 | "source": [ |
557 | 556 | "fig, ax = RingAttractor.plot_loss()\n", |
558 | 557 | "\n", |
| 558 | + "save_plots = False\n", |
559 | 559 | "if save_plots == True:\n", |
560 | 560 | " tpl.saveFigure(fig, \"PI_loss\")" |
561 | 561 | ] |
|
0 commit comments