1+ from ngclearn .components .jaxComponent import JaxComponent
12from jax import numpy as jnp , random , jit , nn
3+ from functools import partial
24from ngclearn .utils import tensorstats
35from ngcsimlib .deprecators import deprecate_args
4- from ngclearn import resolver , Component , Compartment
5- from ngclearn .components .jaxComponent import JaxComponent
6+ from ngcsimlib .logger import info , warn
67from ngclearn .utils .diffeq .ode_utils import get_integrator_code , \
78 step_euler , step_rk2
8- from ngclearn .utils .surrogate_fx import (arctan_estimator ,
9+ from ngclearn .utils .surrogate_fx import (secant_lif_estimator , arctan_estimator ,
910 triangular_estimator ,
1011 straight_through_estimator )
1112
12- @jit
13- def _update_times (t , s , tols ):
14- """
15- Updates time-of-last-spike (tols) variable.
16-
17- Args:
18- t: current time (a scalar/int value)
13+ from ngcsimlib .compilers .process import transition
14+ #from ngcsimlib.component import Component
15+ from ngcsimlib .compartment import Compartment
1916
20- s: binary spike vector
21-
22- tols: current time-of-last-spike variable
23-
24- Returns:
25- updated tols variable
26- """
27- _tols = (1. - s ) * tols + (s * t )
28- return _tols
2917
3018@jit
3119def _dfv_internal (j , v , rfr , tau_m , refract_T ): ## raw voltage dynamics
@@ -166,32 +154,26 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
166154 units = "ms" ) ## time-of-last-spike
167155 self .surrogate = Compartment (restVals + 1. , display_name = "Surrogate State Value" )
168156
157+ @transition (output_compartments = ["v" , "s" , "rfr" , "tols" , "key" , "surrogate" ])
169158 @staticmethod
170- def _advance_state (t , dt , tau_m , resist_m , v_rest , v_reset , refract_T ,
171- thr , lower_clamp_voltage , intgFlag , d_spike_fx , key ,
172- j , v , rfr , tols ):
159+ def advance_state (
160+ t , dt , tau_m , resist_m , v_rest , v_reset , refract_T , thr , lower_clamp_voltage , intgFlag , d_spike_fx , key ,
161+ j , v , rfr , tols
162+ ):
173163 ## run one integration step for neuronal dynamics
174164 j = j * resist_m
175165 v , s , rfr = _run_cell (dt , j , v , thr , rfr , tau_m , v_rest , v_reset ,
176166 refract_T , intgFlag )
177167 surrogate = d_spike_fx (v , thr )
178168 ## update tols
179- tols = _update_times ( t , s , tols )
169+ tols = ( 1. - s ) * tols + ( s * t )
180170 if lower_clamp_voltage : ## ensure voltage never < v_rest
181171 v = jnp .maximum (v , v_rest )
182172 return v , s , rfr , tols , key , surrogate
183173
184- @resolver (_advance_state )
185- def advance_state (self , v , s , rfr , tols , key , surrogate ):
186- self .v .set (v )
187- self .s .set (s )
188- self .rfr .set (rfr )
189- self .tols .set (tols )
190- self .key .set (key )
191- self .surrogate .set (surrogate )
192-
174+ @transition (output_compartments = ["j" , "v" , "s" , "rfr" , "tols" , "surrogate" ])
193175 @staticmethod
194- def _reset (batch_size , n_units , v_rest , refract_T ):
176+ def reset (batch_size , n_units , v_rest , refract_T ):
195177 restVals = jnp .zeros ((batch_size , n_units ))
196178 j = restVals #+ 0
197179 v = restVals + v_rest
@@ -201,15 +183,6 @@ def _reset(batch_size, n_units, v_rest, refract_T):
201183 surrogate = restVals + 1.
202184 return j , v , s , rfr , tols , surrogate
203185
204- @resolver (_reset )
205- def reset (self , j , v , s , rfr , tols , surrogate ):
206- self .j .set (j )
207- self .v .set (v )
208- self .s .set (s )
209- self .rfr .set (rfr )
210- self .tols .set (tols )
211- self .surrogate .set (surrogate )
212-
213186 def save (self , directory , ** kwargs ):
214187 ## do a protected save of constants, depending on whether they are floats or arrays
215188 tau_m = (self .tau_m if isinstance (self .tau_m , float )
0 commit comments