@@ -42,7 +42,7 @@ and the required compiled simulation and dynamic commands, can be done as follow
4242``` python
4343from jax import numpy as jnp, random, jit
4444from ngcsimlib.context import Context
45- from ngcsimlib.compilers import compile_command, wrap_command
45+ from ngcsimlib.compilers.process import Process
4646# # import model-specific mechanisms
4747from ngclearn.components import (TraceSTDPSynapse, MSTDPETSynapse,
4848 RewardErrorCell, VarTrace)
@@ -75,16 +75,30 @@ with Context("Model") as model:
7575 tr1 = VarTrace(" tr1" , n_units = 1 , tau_tr = tau_post, a_delta = Aminus)
7676 rpe = RewardErrorCell(" r" , n_units = 1 , alpha = 0 .)
7777
78- reset_cmd, reset_args = model.compile_by_key(
79- W_stdp, W_mstdp, W_mstdpet, rpe, tr0, tr1, compile_key = " reset" )
80- adv_tr_cmd, _ = model.compile_by_key(
81- tr0, tr1, rpe, W_stdp, W_mstdp, W_mstdpet, compile_key = " advance_state" )
82- evolve_cmd, _ = model.compile_by_key(
83- W_stdp, W_mstdp, W_mstdpet, compile_key = " evolve" )
84-
85- model.add_command(wrap_command(jit(model.reset)), name = " reset" )
86- model.add_command(wrap_command(jit(model.advance_state)), name = " advance" )
87- model.add_command(wrap_command(jit(model.evolve)), name = " evolve" )
78+ evolve_process = (Process()
79+ >> W_stdp.evolve
80+ >> W_mstdp.evolve
81+ >> W_mstdpet.evolve)
82+ model.wrap_and_add_command(jit(evolve_process.pure), name = " evolve" )
83+
84+ advance_process = (Process()
85+ >> tr0.advance_state
86+ >> tr1.advance_state
87+ >> rpe.advance_state
88+ >> W_stdp.advance_state
89+ >> W_mstdp.advance_state
90+ >> W_mstdpet.advance_state)
91+ model.wrap_and_add_command(jit(advance_process.pure), name = " advance" )
92+
93+ reset_process = (Process()
94+ >> W_stdp.reset
95+ >> W_mstdp.reset
96+ >> W_mstdpet.reset
97+ >> rpe.reset
98+ >> tr0.reset
99+ >> tr1.reset
100+ )
101+ model.wrap_and_add_command(jit(reset_process.pure), name = " reset" )
88102
89103 @Context.dynamicCommand
90104 def clamp_spikes (f_j , f_i ):
0 commit comments