@@ -39,8 +39,7 @@ spike train over $100$ steps in time as follows:
3939
4040``` python
4141from jax import numpy as jnp, random, jit
42- from ngcsimlib.context import Context
43- from ngclearn.utils import JaxProcess
42+ from ngclearn import Context, MethodProcess
4443
4544from ngclearn.utils.viz.raster import create_raster_plot
4645# # import model-specific mechanisms
@@ -56,27 +55,24 @@ T = 100 ## number time steps to simulate
5655with Context(" Model" ) as model:
5756 cell = BernoulliCell(" z0" , n_units = 10 , key = subkeys[0 ])
5857
59- advance_process = (JaxProcess( )
58+ advance_process = (MethodProcess( " advance_proc " )
6059 >> cell.advance_state)
61- model.wrap_and_add_command(jit(advance_process.pure), name = " advance" )
6260
63- reset_process = (JaxProcess( )
61+ reset_process = (MethodProcess( " reset_proc " )
6462 >> cell.reset)
65- model.wrap_and_add_command(jit(reset_process.pure), name = " reset" )
6663
67-
68- @Context.dynamicCommand
69- def clamp (x ):
70- cell.inputs.set(x)
64+ def clamp (x ):
65+ cell.inputs.set(x)
66+
7167
7268probs = jnp.asarray([[0.8 , 0.2 , 0 ., 0.55 , 0.9 , 0 , 0.15 , 0 ., 0.6 , 0.77 ]], dtype = jnp.float32)
7369spikes = []
74- model.reset()
70+ reset_process.run()
7571for ts in range (T):
76- model. clamp(probs)
77- model.advance (t = ts * 1 ., dt = dt)
72+ clamp(probs)
73+ advance_process.run (t = ts * 1 ., dt = dt)
7874
79- s_t = cell.outputs.value
75+ s_t = cell.outputs.get()
8076 spikes.append(s_t)
8177spikes = jnp.concatenate(spikes, axis = 0 )
8278create_raster_plot(spikes, plot_fname = " input_cell_raster.jpg" )
@@ -121,7 +117,7 @@ and by replacing the line that has the `BernoulliCell` call with the
121117following line instead:
122118
123119``` python
124- cell = PoissonCell(" z0" , n_units = 10 , max_freq = 63.75 , key = subkeys[0 ])
120+ cell = PoissonCell(" z0" , n_units = 10 , target_freq = 63.75 , key = subkeys[0 ])
125121```
126122
127123Running the code with the two above small modifications will
@@ -149,12 +145,12 @@ mu = 0.
149145probs = jnp.asarray([[1 .]],dtype = jnp.float32)
150146for _ in range (n_trials):
151147 spikes = []
152- model.reset()
148+ reset_process.run()
153149 for ts in range (T):
154- model. clamp(probs)
155- model.advance (t = ts* 1 ., dt = dt)
150+ clamp(probs)
151+ advance_process.run (t = ts * 1 ., dt = dt)
156152
157- s_t = cell.outputs.value
153+ s_t = cell.outputs.get()
158154 spikes.append(s_t)
159155 count = jnp.sum(jnp.concatenate(spikes, axis = 0 ))
160156 mu += count
0 commit comments