1- from jax import numpy as jnp , jit
2- from ngclearn import resolver , Component , Compartment
31from ngclearn .components .jaxComponent import JaxComponent
2+ from jax import numpy as jnp , random , jit , nn
3+ from functools import partial
44from ngclearn .utils import tensorstats
55from ngcsimlib .deprecators import deprecate_args
6+ from ngcsimlib .logger import info , warn
67from ngclearn .utils .diffeq .ode_utils import get_integrator_code , \
78 step_euler , step_rk2
89
9- @jit
10- def _update_times (t , s , tols ):
11- """
12- Updates time-of-last-spike (tols) variable.
13-
14- Args:
15- t: current time (a scalar/int value)
16-
17- s: binary spike vector
18-
19- tols: current time-of-last-spike variable
20-
21- Returns:
22- updated tols variable
23- """
24- _tols = (1. - s ) * tols + (s * t )
25- return _tols
10+ from ngcsimlib .compilers .process import transition
11+ #from ngcsimlib.component import Component
12+ from ngcsimlib .compartment import Compartment
2613
2714@jit
2815def _dfv_internal (j , v , w , tau_m , omega , b ): ## "voltage" dynamics
@@ -48,11 +35,6 @@ def _dfw(t, w, params): ## angular driver dynamics wrapper
4835 dv_dt = _dfw_internal (j , v , w , tau_w , omega , b )
4936 return dv_dt
5037
51- @jit
52- def _emit_spike (v , v_thr ):
53- s = (v > v_thr ).astype (jnp .float32 )
54- return s
55-
5638class RAFCell (JaxComponent ):
5739 """
5840 The resonate-and-fire (RAF) neuronal cell
@@ -112,14 +94,15 @@ class RAFCell(JaxComponent):
11294 and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")
11395
11496 :Note: setting the integration type to the midpoint method will
115- increase the accuray of the estimate of the cell's evolution
97+ increase the accuracy of the estimate of the cell's evolution
11698 at an increase in computational cost (and simulation time)
11799 """
118100
119101 @deprecate_args (resist_m = "resist_v" , tau_m = "tau_v" )
120- def __init__ (self , name , n_units , tau_v = 1. , tau_w = 1. , thr = 1. , omega = 10. ,
121- b = - 1. , v_reset = 1. , w_reset = 0. , v0 = 0. , w0 = 0. , resist_v = 1. ,
122- integration_type = "euler" , batch_size = 1 , ** kwargs ):
102+ def __init__ (
103+ self , name , n_units , tau_v = 1. , tau_w = 1. , thr = 1. , omega = 10. , b = - 1. , v_reset = 0. , w_reset = 0. , v0 = 0. , w0 = 0. ,
104+ resist_v = 1. , integration_type = "euler" , batch_size = 1 , ** kwargs
105+ ):
123106 #v_rest=-72., v_reset=-75., w_reset=0., thr=5., v0=-70., w0=0., tau_w=400., thr=5., omega=10., b=-1.
124107 super ().__init__ (name , ** kwargs )
125108
@@ -150,11 +133,13 @@ def __init__(self, name, n_units, tau_v=1., tau_w=1., thr=1., omega=10.,
150133 self .v = Compartment (restVals + self .v0 , display_name = "Voltage" , units = "mV" )
151134 self .w = Compartment (restVals + self .w0 , display_name = "Angular-Driver" )
152135 self .s = Compartment (restVals , display_name = "Spikes" )
153- self .tols = Compartment (restVals , display_name = "Time-of-Last-Spike" ,
154- units = "ms" ) ## time-of-last-spike
136+ self .tols = Compartment (
137+ restVals , display_name = "Time-of-Last-Spike" , units = "ms"
138+ ) ## time-of-last-spike
155139
140+ @transition (output_compartments = ["j" , "v" , "w" , "s" , "tols" ])
156141 @staticmethod
157- def _advance_state (t , dt , tau_v , resist_v , tau_w , thr , omega , b ,
142+ def advance_state (t , dt , tau_v , resist_v , tau_w , thr , omega , b ,
158143 v_reset , w_reset , intgFlag , j , v , w , tols ):
159144 ## continue with centered dynamics
160145 j_ = j * resist_v
@@ -170,24 +155,17 @@ def _advance_state(t, dt, tau_v, resist_v, tau_w, thr, omega, b,
170155 _ , _w = step_euler (0. , w , _dfw , dt , w_params )
171156 v_params = (j_ , _w , tau_v , omega , b )
172157 _ , _v = step_euler (0. , v , _dfv , dt , v_params )
173- s = _emit_spike (_v , thr )
158+ s = (_v > thr ) * 1. ## emit spikes/pulses
174159 ## hyperpolarize/reset/snap variables
175160 w = _w * (1. - s ) + s * w_reset
176161 v = _v * (1. - s ) + s * v_reset
177162
178- tols = _update_times ( t , s , tols )
163+ tols = ( 1. - s ) * tols + ( s * t ) ## update times-of-last-spike(s )
179164 return j , v , w , s , tols
180165
181- @resolver (_advance_state )
182- def advance_state (self , j , v , w , s , tols ):
183- self .j .set (j )
184- self .w .set (w )
185- self .v .set (v )
186- self .s .set (s )
187- self .tols .set (tols )
188-
166+ @transition (output_compartments = ["j" , "v" , "w" , "s" , "tols" ])
189167 @staticmethod
190- def _reset (batch_size , n_units , v0 , w0 ):
168+ def reset (batch_size , n_units , v0 , w0 ):
191169 restVals = jnp .zeros ((batch_size , n_units ))
192170 j = restVals # None
193171 v = restVals + v0
@@ -196,14 +174,6 @@ def _reset(batch_size, n_units, v0, w0):
196174 tols = restVals #+ 0
197175 return j , v , w , s , tols
198176
199- @resolver (_reset )
200- def reset (self , j , v , w , s , tols ):
201- self .j .set (j )
202- self .v .set (v )
203- self .w .set (w )
204- self .s .set (s )
205- self .tols .set (tols )
206-
207177 @classmethod
208178 def help (cls ): ## component help function
209179 properties = {
0 commit comments