1- from ngclearn import resolver , Component , Compartment
21from ngclearn .components .jaxComponent import JaxComponent
3- from ngclearn .utils import tensorstats
4- from ngclearn .utils .model_utils import clamp_min , clamp_max
52from jax import numpy as jnp , random , jit
63from functools import partial
7- from ngcsimlib .logger import info
8-
9- @jit
10- def _update_times (t , s , tols ):
11- """
12- Updates time-of-last-spike (tols) variable.
13-
14- Args:
15- t: current time (a scalar/int value)
4+ from ngclearn .utils import tensorstats
5+ from ngcsimlib .deprecators import deprecate_args
6+ from ngcsimlib .logger import info , warn
167
17- s: binary spike vector
8+ from ngcsimlib .compilers .process import transition
9+ #from ngcsimlib.component import Component
10+ from ngcsimlib .compartment import Compartment
1811
19- tols: current time-of-last-spike variable
12+ from ngclearn . utils . model_utils import clamp_min , clamp_max
2013
21- Returns:
22- updated tols variable
23- """
24- _tols = (1. - s ) * tols + (s * t )
25- return _tols
2614
2715@partial (jit , static_argnums = [5 ])
2816def _calc_spike_times_linear (data , tau , thr , first_spk_t , num_steps = 1. ,
@@ -157,9 +145,10 @@ class LatencyCell(JaxComponent):
157145 """
158146
159147 # Define Functions
160- def __init__ (self , name , n_units , tau = 1. , threshold = 0.01 , first_spike_time = 0. ,
161- linearize = False , normalize = False , clip_spikes = False , num_steps = 1. ,
162- batch_size = 1 , ** kwargs ):
148+ def __init__ (
149+ self , name , n_units , tau = 1. , threshold = 0.01 , first_spike_time = 0. , linearize = False , normalize = False ,
150+ clip_spikes = False , num_steps = 1. , batch_size = 1 , ** kwargs
151+ ):
163152 super ().__init__ (name , ** kwargs )
164153
165154 ## latency meta-parameters
@@ -186,9 +175,11 @@ def __init__(self, name, n_units, tau=1., threshold=0.01, first_spike_time=0.,
186175 self .targ_sp_times = Compartment (restVals , display_name = "Target Spike Time" , units = "ms" )
187176 #self.reset()
188177
178+ @transition (output_compartments = ["targ_sp_times" , "clip_mask" ])
189179 @staticmethod
190- def _calc_spike_times (linearize , tau , threshold , first_spike_time , num_steps ,
191- normalize , clip_spikes , inputs ):
180+ def calc_spike_times (
181+ linearize , tau , threshold , first_spike_time , num_steps , normalize , clip_spikes , inputs
182+ ):
192183 ## would call this function before processing a spike train (at start)
193184 data = inputs
194185 if clip_spikes :
@@ -208,42 +199,27 @@ def _calc_spike_times(linearize, tau, threshold, first_spike_time, num_steps,
208199 targ_sp_times = stimes #* calcEvent + targ_sp_times * (1. - calcEvent)
209200 return targ_sp_times , clip_mask
210201
211- @resolver (_calc_spike_times )
212- def calc_spike_times (self , targ_sp_times , clip_mask ):
213- self .targ_sp_times .set (targ_sp_times )
214- self .clip_mask .set (clip_mask )
215-
202+ @transition (output_compartments = ["outputs" , "tols" , "mask" , "targ_sp_times" , "key" ])
216203 @staticmethod
217- def _advance_state (t , dt , key , inputs , mask , clip_mask , targ_sp_times , tols ):
204+ def advance_state (t , dt , key , inputs , mask , clip_mask , targ_sp_times , tols ):
218205 key , * subkeys = random .split (key , 2 )
219- data = inputs ## get sensory pattern data / features
220- spikes , spk_mask = _extract_spike (targ_sp_times , t , mask ) ## get spikes at t
221- tols = _update_times (t , spikes , tols )
206+ data = inputs ## get sensory pattern data / features
207+ spikes , spk_mask = _extract_spike (targ_sp_times , t , mask ) ## get spikes at t
208+
209+ # Updates time-of-last-spike (tols) variable:
210+ # output = s = binary spike vector
211+ # tols = current time-of-last-spike variable
212+ tols = (1. - spikes ) * tols + (spikes * t )
213+
222214 spikes = spikes * (1. - clip_mask )
223215 return spikes , tols , spk_mask , targ_sp_times , key
224216
225- @resolver (_advance_state )
226- def advance_state (self , outputs , tols , mask , targ_sp_times , key ):
227- self .outputs .set (outputs )
228- self .tols .set (tols )
229- self .mask .set (mask )
230- self .targ_sp_times .set (targ_sp_times )
231- self .key .set (key )
232-
217+ @transition (output_compartments = ["inputs" , "outputs" , "tols" , "mask" , "clip_mask" , "targ_sp_times" ])
233218 @staticmethod
234- def _reset (batch_size , n_units ):
219+ def reset (batch_size , n_units ):
235220 restVals = jnp .zeros ((batch_size , n_units ))
236221 return (restVals , restVals , restVals , restVals , restVals , restVals )
237222
238- @resolver (_reset )
239- def reset (self , inputs , outputs , tols , mask , clip_mask , targ_sp_times ):
240- self .inputs .set (inputs )
241- self .outputs .set (outputs )
242- self .tols .set (tols )
243- self .mask .set (mask )
244- self .clip_mask .set (clip_mask )
245- self .targ_sp_times .set (targ_sp_times )
246-
247223 def save (self , directory , ** kwargs ):
248224 file_name = directory + "/" + self .name + ".npz"
249225 jnp .savez (file_name , key = self .key .value )
0 commit comments