11from jax import numpy as jnp , random , jit , nn
2- from ngclearn import resolver , Component , Compartment
32from ngclearn .components .jaxComponent import JaxComponent
3+ from jax import numpy as jnp , random , jit , nn
4+ from functools import partial
45from ngclearn .utils import tensorstats
5- from ngclearn .utils .model_utils import softmax
6-
7- @jit
8- def _update_times (t , s , tols ):
9- """
10- Updates time-of-last-spike (tols) variable.
11-
12- Args:
13- t: current time (a scalar/int value)
14-
15- s: binary spike vector
16-
17- tols: current time-of-last-spike variable
18-
19- Returns:
20- updated tols variable
21- """
22- _tols = (1. - s ) * tols + (s * t )
23- return _tols
24- @jit
25- def _run_cell (dt , j , v , rfr , v_thr , tau_m , R_m , thr_gain = 0.002 , refract_T = 0. ):
26- """
27- Runs leaky integrator neuronal dynamics
28-
29- Args:
30- dt: integration time constant (milliseconds, or ms)
31-
32- j: electrical current value
33-
34- v: membrane potential (voltage, in milliVolts or mV) value (at t)
35-
36- rfr: refractory variable vector (one per neuronal cell)
37-
38- v_thr: base voltage threshold value (in mV)
39-
40- tau_m: cell membrane time constant
6+ from ngcsimlib .deprecators import deprecate_args
7+ from ngcsimlib .logger import info , warn
418
42- R_m: cell membrane resistance
43-
44- thr_gain: increment to be applied to threshold upon spike occurrence
45-
46- refract_T: (relative) refractory time period (in ms; Default
47- value is 1 ms)
9+ from ngcsimlib .compilers .process import transition
10+ #from ngcsimlib.component import Component
11+ from ngcsimlib .compartment import Compartment
12+ from ngclearn .utils .model_utils import softmax
4813
49- Returns:
50- voltage(t+dt), spikes, updated voltage thresholds, updated refactory variables
51- """
52- mask = (rfr >= refract_T ).astype (jnp .float32 ) ## check refractory period
53- v = (j * R_m ) * mask
54- vp = softmax (v ) # convert to Categorical (spike) probabilities
55- #s = nn.one_hot(jnp.argmax(vp, axis=1), j.shape[1]) ## hard-max spike
56- s = (vp > v_thr ).astype (jnp .float32 ) ## calculate action potential
57- q = 1. ## Note: thr_gain ==> "rho_b"
58- dthr = jnp .sum (s , axis = 1 , keepdims = True ) - q
59- v_thr = jnp .maximum (v_thr + dthr * thr_gain , 0.025 ) ## calc new threshold
60- rfr = (rfr + dt ) * (1. - s ) + s * dt # set refract to dt
61- return v , s , v_thr , rfr
6214
6315class WTASCell (JaxComponent ): ## winner-take-all spiking cell
6416 """
@@ -136,22 +88,26 @@ def __init__(
13688 self .rfr = Compartment (restVals + self .refract_T )
13789 self .tols = Compartment (restVals ) ## time-of-last-spike
13890
91+ @transition (output_compartments = ["v" , "s" , "thr" , "rfr" , "tols" ])
13992 @staticmethod
140- def _advance_state (t , dt , tau_m , R_m , thr_gain , refract_T , j , v , thr , rfr , tols ):
141- v , s , thr , rfr = _run_cell (dt , j , v , rfr , thr , tau_m , R_m , thr_gain , refract_T )
142- tols = _update_times (t , s , tols ) ## update tols
93+ def advance_state (t , dt , tau_m , R_m , thr_gain , refract_T , j , v , thr , rfr , tols ):
94+ mask = (rfr >= refract_T ) * 1. ## check refractory period
95+ v = (j * R_m ) * mask
96+ vp = softmax (v ) # convert to Categorical (spike) probabilities
97+ # s = nn.one_hot(jnp.argmax(vp, axis=1), j.shape[1]) ## hard-max spike
98+ s = (vp > thr ) * 1. ## calculate action potential
99+ q = 1. ## Note: thr_gain ==> "rho_b"
100+ ## increment threshold upon spike(s) occurrence
101+ dthr = jnp .sum (s , axis = 1 , keepdims = True ) - q
102+ thr = jnp .maximum (thr + dthr * thr_gain , 0.025 ) ## calc new threshold
103+ rfr = (rfr + dt ) * (1. - s ) + s * dt # set refract to dt
104+
105+ tols = (1. - s ) * tols + (s * t ) ## update tols
143106 return v , s , thr , rfr , tols
144107
145- @resolver (_advance_state )
146- def advance_state (self , v , s , thr , rfr , tols ):
147- self .v .set (v )
148- self .s .set (s )
149- self .thr .set (thr )
150- self .rfr .set (rfr )
151- self .tols .set (tols )
152-
108+ @transition (output_compartments = ["j" , "v" , "s" , "rfr" , "tols" ])
153109 @staticmethod
154- def _reset (batch_size , n_units , refract_T ):
110+ def reset (batch_size , n_units , refract_T ):
155111 restVals = jnp .zeros ((batch_size , n_units ))
156112 j = restVals #+ 0
157113 v = restVals #+ 0
@@ -160,14 +116,6 @@ def _reset(batch_size, n_units, refract_T):
160116 tols = restVals #+ 0
161117 return j , v , s , rfr , tols
162118
163- @resolver (_reset )
164- def reset (self , j , v , s , rfr , tols ):
165- self .j .set (j )
166- self .v .set (v )
167- self .s .set (s )
168- self .rfr .set (rfr )
169- self .tols .set (tols )
170-
171119 def save (self , directory , ** kwargs ):
172120 file_name = directory + "/" + self .name + ".npz"
173121 jnp .savez (file_name , threshold = self .thr .value )
0 commit comments