Skip to content

Commit ea3396d

Browse files
author
Alexander Ororbia
committed
revised tutorials to reflect new sim-lib config/syntax
1 parent e9e314d commit ea3396d

File tree

15 files changed

+165
-138
lines changed

15 files changed

+165
-138
lines changed

docs/tutorials/model_basics/evolving_synapses.md

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ We do this specifically as follows:
1818

1919
```python
2020
from jax import numpy as jnp, random, jit
21-
from ngcsimlib.compilers import compile_command, wrap_command
2221
from ngcsimlib.context import Context
22+
from ngcsimlib.compilers.process import Process
2323
from ngclearn.components import HebbianSynapse, RateCell
2424
import ngclearn.utils.weight_distribution as dist
2525

@@ -48,17 +48,18 @@ with Context("Circuit") as circuit:
4848
# wire output compartment (rate-coded output zF) of RateCell `b` to postsynaptic compartment of HebbianSynapse `Wab`
4949
Wab.post << b.zF
5050

51-
## create and compile core simulation commands
52-
reset_cmd, reset_args = circuit.compile_by_key(a, Wab, b, compile_key="reset")
53-
circuit.add_command(wrap_command(jit(circuit.reset)), name="reset")
54-
55-
advance_cmd, advance_args = circuit.compile_by_key(a, Wab, b,
56-
compile_key="advance_state")
57-
circuit.add_command(wrap_command(jit(circuit.advance_state)), name="advance")
58-
59-
evolve_cmd, evolve_args = circuit.compile_by_key(Wab, compile_key="evolve")
60-
circuit.add_command(wrap_command(jit(circuit.evolve)), name="evolve")
61-
51+
## create and compile core simulation commands
52+
evolve_process = (Process()
53+
>> a.evolve)
54+
circuit.wrap_and_add_command(jit(evolve_process.pure), name="evolve")
55+
56+
advance_process = (Process()
57+
>> a.advance_state)
58+
circuit.wrap_and_add_command(jit(advance_process.pure), name="advance")
59+
60+
reset_process = (Process()
61+
>> a.reset)
62+
circuit.wrap_and_add_command(jit(reset_process.pure), name="reset")
6263

6364
## set up non-compiled utility commands
6465
@Context.dynamicCommand

docs/tutorials/model_basics/model_building.md

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ While building our dynamical system we will set up a Context and then add the th
1010
```python
1111
from jax import numpy as jnp, random
1212
from ngclearn import Context
13-
from ngclearn.commands import Reset, Clamp, AdvanceState
13+
from ngcsimlib.compilers.process import Process
1414
from ngclearn.components import RateCell, HebbianSynapse
1515
import ngclearn.utils.weight_distribution as dist
1616

@@ -51,8 +51,9 @@ nodes `a` and `b` since a basic synapse component like `Wab` does not have a
5151
base/resting value), an `advance` (which moves all the nodes one step
5252
forward in time according to their compartments' ODEs), and `clamp` (which will
5353
allow us to insert data into particular nodes).
54-
This is simply done with the following few lines:
54+
This is simply done with the use of the following convenience function calls:
5555

56+
<!--
5657
```python
5758
## configure desired commands for simulation object
5859
Reset(command_name="reset",
@@ -65,6 +66,28 @@ This is simply done with the following few lines:
6566
compartment="j",
6667
clamp_name="x")
6768
```
69+
-->
70+
71+
72+
```python
73+
## configure desired commands for simulation object
74+
reset_process = (Process()
75+
>> a.reset
76+
>> Wab.reset
77+
>> b.reset)
78+
model.wrap_and_add_command(jit(reset_process.pure), name="reset")
79+
80+
advance_process = (Process()
81+
>> a.advance_state
82+
>> Wab.advance_state
83+
>> b.advance_state)
84+
model.wrap_and_add_command(jit(advance_process.pure), name="advance")
85+
86+
## set up clamp as a non-compiled utility commands
87+
@Context.dynamicCommand
88+
def clamp(x):
89+
a.j.set(x)
90+
```
6891

6992
## Running the Dynamical System's Controller
7093

docs/tutorials/neurocog/error_cell.md

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,8 @@ The code you would write amounts to the below:
5353

5454
```python
5555
from jax import numpy as jnp, jit
56-
import time
57-
5856
from ngcsimlib.context import Context
59-
from ngcsimlib.commands import Command
60-
from ngcsimlib.compilers import compile_command, wrap_command
61-
from ngclearn.utils.viz.raster import create_raster_plot
57+
from ngcsimlib.compilers.process import Process, transition
6258
## import model-specific mechanisms
6359
from ngclearn.components.neurons.graded.gaussianErrorCell import GaussianErrorCell
6460

@@ -68,11 +64,13 @@ T = 5 ## number time steps to simulate
6864
with Context("Model") as model:
6965
cell = GaussianErrorCell("z0", n_units=3)
7066

71-
reset_cmd, reset_args = model.compile_by_key(cell, compile_key="reset")
72-
advance_cmd, advance_args = model.compile_by_key(cell, compile_key="advance_state")
67+
advance_process = (Process()
68+
>> cell.advance_state)
69+
model.wrap_and_add_command(jit(advance_process.pure), name="advance")
7370

74-
model.add_command(wrap_command(jit(model.reset)), name="reset")
75-
model.add_command(wrap_command(jit(model.advance_state)), name="advance")
71+
reset_process = (Process()
72+
>> cell.reset)
73+
model.wrap_and_add_command(jit(reset_process.pure), name="reset")
7674

7775

7876
@Context.dynamicCommand

docs/tutorials/neurocog/fitzhugh_nagumo_cell.md

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,9 @@ single component system made up of the Fitzhugh-Nagumo (`F-N`) cell.
1717
from jax import numpy as jnp, random, jit
1818
import numpy as np
1919

20-
from ngclearn.utils.model_utils import scanner
21-
from ngcsimlib.compilers import compile_command, wrap_command
2220
from ngcsimlib.context import Context
23-
from ngcsimlib.commands import Command
21+
from ngcsimlib.compilers.process import Process
2422
## import model-specific mechanisms
25-
from ngclearn.operations import summation
2623
from ngclearn.components.neurons.spiking.fitzhughNagumoCell import FitzhughNagumoCell
2724

2825
## create seeding keys (JAX-style)
@@ -43,11 +40,13 @@ with Context("Model") as model:
4340
gamma=gamma, v0=v0, w0=w0, integration_type="euler")
4441

4542
## create and compile core simulation commands
46-
reset_cmd, reset_args = model.compile_by_key(cell, compile_key="reset")
47-
model.add_command(wrap_command(jit(model.reset)), name="reset")
48-
advance_cmd, advance_args = model.compile_by_key(cell, compile_key="advance_state")
49-
model.add_command(wrap_command(jit(model.advance_state)), name="advance")
43+
advance_process = (Process()
44+
>> cell.advance_state)
45+
model.wrap_and_add_command(jit(advance_process.pure), name="advance")
5046

47+
reset_process = (Process()
48+
>> cell.reset)
49+
model.wrap_and_add_command(jit(reset_process.pure), name="reset")
5150

5251
## set up non-compiled utility commands
5352
@Context.dynamicCommand

docs/tutorials/neurocog/hebbian.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,13 @@ Wab.post << b.zF
3838
as well as (a bit later in the model construction code):
3939

4040
```python
41-
advance_cmd, advance_args = circuit.compile_by_key(a, Wab, b, compile_key="advance_state")
42-
circuit.add_command(wrap_command(jit(circuit.advance_state)), name="advance")
41+
evolve_process = (Process()
42+
>> a.evolve)
43+
circuit.wrap_and_add_command(jit(evolve_process.pure), name="evolve")
4344

44-
evolve_cmd, evolve_args = circuit.compile_by_key(Wab, compile_key="evolve")
45-
circuit.add_command(wrap_command(jit(circuit.evolve)), name="evolve")
45+
advance_process = (Process()
46+
>> a.advance_state)
47+
circuit.wrap_and_add_command(jit(advance_process.pure), name="advance")
4648
```
4749

4850
Notice that beyond wiring component `a`'s values into the synapse `Wab`'s input compartment

docs/tutorials/neurocog/input_cells.md

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,11 @@ spike train over $100$ steps in time as follows:
3939

4040
```python
4141
from jax import numpy as jnp, random, jit
42-
import time
43-
4442
from ngcsimlib.context import Context
45-
from ngcsimlib.commands import Command
46-
from ngcsimlib.compilers import compile_command, wrap_command
43+
from ngcsimlib.compilers.process import Process
44+
4745
from ngclearn.utils.viz.raster import create_raster_plot
4846
## import model-specific mechanisms
49-
from ngclearn.operations import summation
5047
from ngclearn.components.input_encoders.bernoulliCell import BernoulliCell
5148

5249
## create seeding keys (JAX-style)
@@ -59,11 +56,13 @@ T = 100 ## number time steps to simulate
5956
with Context("Model") as model:
6057
cell = BernoulliCell("z0", n_units=10, key=subkeys[0])
6158

62-
reset_cmd, reset_args = model.compile_by_key(cell, compile_key="reset")
63-
advance_cmd, advance_args = model.compile_by_key(cell, compile_key="advance_state")
59+
advance_process = (Process()
60+
>> cell.advance_state)
61+
model.wrap_and_add_command(jit(advance_process.pure), name="advance")
6462

65-
model.add_command(wrap_command(jit(model.reset)), name="reset")
66-
model.add_command(wrap_command(jit(model.advance_state)), name="advance")
63+
reset_process = (Process()
64+
>> cell.reset)
65+
model.wrap_and_add_command(jit(reset_process.pure), name="reset")
6766

6867

6968
@Context.dynamicCommand

docs/tutorials/neurocog/integration.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ $\frac{\partial y(t)}{\partial t} = -2 t^3 + 12 t^2 - 20 t + 8.5$ which
121121
has the analytic solution $y(t) = -(1/2) t^4 + 4 t^3 - 10 t^2 + 8.5 t + C$ (
122122
where we will set $C = 1$). You can write code like below, importing from
123123
`ngclearn.utils.diffeq.ode_utils` the Euler routine (`step_euler`),
124-
the RK-2 routine (`step_rk2`), RK-4 routine (`step_rk4`), and Heun's method (`step_heun`), and compare
124+
the RK-2 routine (`step_rk2`), the RK-4 routine (`step_rk4`), and Heun's method (`step_heun`), and compare
125125
how these methods approximate the nonlinear dynamics inherent to our
126126
constructed $\frac{\partial y(t)}{\partial t}$ ODE below:
127127

docs/tutorials/neurocog/izhikevich_cell.md

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,9 @@ single component system made up of the Izhikevich (`IZH`) cell.
1919
from jax import numpy as jnp, random, jit
2020
import numpy as np
2121

22-
from ngclearn.utils.model_utils import scanner
23-
from ngcsimlib.compilers import compile_command, wrap_command
2422
from ngcsimlib.context import Context
25-
from ngcsimlib.commands import Command
23+
from ngcsimlib.compilers.process import Process
2624
## import model-specific mechanisms
27-
from ngclearn.operations import summation
2825
from ngclearn.components.neurons.spiking.izhikevichCell import IzhikevichCell
2926

3027
## create seeding keys (JAX-style)
@@ -47,11 +44,13 @@ with Context("Model") as model:
4744
integration_type="euler", v0=v0, w0=w0, key=subkeys[0])
4845

4946
## create and compile core simulation commands
50-
reset_cmd, reset_args = model.compile_by_key(cell, compile_key="reset")
51-
model.add_command(wrap_command(jit(model.reset)), name="reset")
52-
advance_cmd, advance_args = model.compile_by_key(cell, compile_key="advance_state")
53-
model.add_command(wrap_command(jit(model.advance_state)), name="advance")
47+
advance_process = (Process()
48+
>> cell.advance_state)
49+
model.wrap_and_add_command(jit(advance_process.pure), name="advance")
5450

51+
reset_process = (Process()
52+
>> cell.reset)
53+
model.wrap_and_add_command(jit(reset_process.pure), name="reset")
5554

5655
## set up non-compiled utility commands
5756
@Context.dynamicCommand

docs/tutorials/neurocog/lif.md

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,10 @@ cell, you would write code akin to the following:
2323

2424
```python
2525
from jax import numpy as jnp, random, jit
26-
import numpy as np
2726

28-
from ngclearn.utils.model_utils import scanner
29-
from ngcsimlib.compilers import compile_command, wrap_command
3027
from ngcsimlib.context import Context
31-
from ngcsimlib.commands import Command
28+
from ngcsimlib.compilers.process import Process
3229
## import model-specific mechanisms
33-
from ngclearn.operations import summation
3430
from ngclearn.components.neurons.spiking.LIFCell import LIFCell
3531
from ngclearn.utils.viz.spike_plot import plot_spiking_neuron
3632

@@ -51,11 +47,13 @@ with Context("Model") as model:
5147
refract_time=2., key=subkeys[0])
5248

5349
## create and compile core simulation commands
54-
reset_cmd, reset_args = model.compile_by_key(cell, compile_key="reset")
55-
model.add_command(wrap_command(jit(model.reset)), name="reset")
56-
advance_cmd, advance_args = model.compile_by_key(cell,
57-
compile_key="advance_state")
58-
model.add_command(wrap_command(jit(model.advance_state)), name="advance")
50+
advance_process = (Process()
51+
>> cell.advance_state)
52+
model.wrap_and_add_command(jit(advance_process.pure), name="advance")
53+
54+
reset_process = (Process()
55+
>> cell.reset)
56+
model.wrap_and_add_command(jit(reset_process.pure), name="reset")
5957

6058

6159
## set up non-compiled utility commands

docs/tutorials/neurocog/mod_stdp.md

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ and the required compiled simulation and dynamic commands, can be done as follow
4242
```python
4343
from jax import numpy as jnp, random, jit
4444
from ngcsimlib.context import Context
45-
from ngcsimlib.compilers import compile_command, wrap_command
45+
from ngcsimlib.compilers.process import Process
4646
## import model-specific mechanisms
4747
from ngclearn.components import (TraceSTDPSynapse, MSTDPETSynapse,
4848
RewardErrorCell, VarTrace)
@@ -75,16 +75,30 @@ with Context("Model") as model:
7575
tr1 = VarTrace("tr1", n_units=1, tau_tr=tau_post, a_delta=Aminus)
7676
rpe = RewardErrorCell("r", n_units=1, alpha=0.)
7777

78-
reset_cmd, reset_args = model.compile_by_key(
79-
W_stdp, W_mstdp, W_mstdpet, rpe, tr0, tr1, compile_key="reset")
80-
adv_tr_cmd, _ = model.compile_by_key(
81-
tr0, tr1, rpe, W_stdp, W_mstdp, W_mstdpet, compile_key="advance_state")
82-
evolve_cmd, _ = model.compile_by_key(
83-
W_stdp, W_mstdp, W_mstdpet, compile_key="evolve")
84-
85-
model.add_command(wrap_command(jit(model.reset)), name="reset")
86-
model.add_command(wrap_command(jit(model.advance_state)), name="advance")
87-
model.add_command(wrap_command(jit(model.evolve)), name="evolve")
78+
evolve_process = (Process()
79+
>> W_stdp.evolve
80+
>> W_mstdp.evolve
81+
>> W_mstdpet.evolve)
82+
model.wrap_and_add_command(jit(evolve_process.pure), name="evolve")
83+
84+
advance_process = (Process()
85+
>> tr0.advance_state
86+
>> tr1.advance_state
87+
>> rpe.advance_state
88+
>> W_stdp.advance_state
89+
>> W_mstdp.advance_state
90+
>> W_mstdpet.advance_state)
91+
model.wrap_and_add_command(jit(advance_process.pure), name="advance")
92+
93+
reset_process = (Process()
94+
>> W_stdp.reset
95+
>> W_mstdp.reset
96+
>> W_mstdpet.reset
97+
>> rpe.reset
98+
>> tr0.reset
99+
>> tr1.reset
100+
)
101+
model.wrap_and_add_command(jit(reset_process.pure), name="reset")
88102

89103
@Context.dynamicCommand
90104
def clamp_spikes(f_j, f_i):

0 commit comments

Comments
 (0)