@@ -71,12 +71,15 @@ class VarTrace(JaxComponent): ## low-pass filter
7171 2) `'exp'` = exponential trace filter, i.e., decay = exp(-dt/tau_tr) * x_tr;
7272 3) `'step'` = step trace, i.e., decay = 0 (a pulse applied upon input value)
7373
74+ n_nearest_spikes: (k) if k > 0, this makes the trace act like a nearest-neighbor trace,
75+ i.e., k = 1 yields the 1-nearest (neighbor) trace (Default: 0)
76+
7477 batch_size: batch size dimension of this cell (Default: 1)
7578 """
7679
7780 # Define Functions
7881 def __init__ (self , name , n_units , tau_tr , a_delta , P_scale = 1. , gamma_tr = 1 , decay_type = "exp" ,
79- batch_size = 1 , ** kwargs ):
82+ n_nearest_spikes = 0 , batch_size = 1 , ** kwargs ):
8083 super ().__init__ (name , ** kwargs )
8184
8285 ## Trace control coefficients
@@ -85,6 +88,7 @@ def __init__(self, name, n_units, tau_tr, a_delta, P_scale=1., gamma_tr=1, decay
8588 self .P_scale = P_scale ## trace scale if non-additive trace to be used
8689 self .gamma_tr = gamma_tr
8790 self .decay_type = decay_type ## lin --> linear decay; exp --> exponential decay
91+ self .n_nearest_spikes = n_nearest_spikes
8892
8993 ## Layer Size Setup
9094 self .batch_size = batch_size
@@ -97,17 +101,22 @@ def __init__(self, name, n_units, tau_tr, a_delta, P_scale=1., gamma_tr=1, decay
97101
98102 @transition (output_compartments = ["outputs" , "trace" ])
99103 @staticmethod
100- def advance_state (dt , decay_type , tau_tr , a_delta , P_scale , gamma_tr , inputs , trace ):
104+ def advance_state (
105+ dt , decay_type , tau_tr , a_delta , P_scale , gamma_tr , inputs , trace , n_nearest_spikes
106+ ):
101107 decayFactor = 0.
102108 if "exp" in decay_type :
103109 decayFactor = jnp .exp (- dt / tau_tr )
104110 elif "lin" in decay_type :
105111 decayFactor = (1. - dt / tau_tr )
106112 _x_tr = gamma_tr * trace * decayFactor
107- if a_delta > 0. :
108- _x_tr = _x_tr + inputs * a_delta
113+ if n_nearest_spikes > 0 : ## run k-nearest neighbor trace
114+ _x_tr = _x_tr + inputs * ( a_delta - ( trace / n_nearest_spikes ))
109115 else :
110- _x_tr = _x_tr * (1. - inputs ) + inputs * P_scale
116+ if a_delta > 0. : ## run full convolution trace
117+ _x_tr = _x_tr + inputs * a_delta
118+ else : ## run simple max-clamped trace
119+ _x_tr = _x_tr * (1. - inputs ) + inputs * P_scale
111120 trace = _x_tr
112121 return trace , trace
113122
@@ -138,12 +147,15 @@ def help(cls): ## component help function
138147 "tau_tr" : "Trace/filter time constant" ,
139148 "a_delta" : "Increment to apply to trace (if not set to 0); "
140149 "otherwise, traces clamp to 1 and then decay" ,
150+ "P_scale" : "Max value to snap trace to if a max-clamp trace is triggered/configured" ,
141151 "decay_type" : "Indicator of what type of decay dynamics to use "
142- "as filter is updated at time t"
152+ "as filter is updated at time t" ,
153+ "n_nearest_neighbors" : "Number of nearest pulses to affect/increment trace (if > 0)"
143154 }
144155 info = {cls .__name__ : properties ,
145156 "compartments" : compartment_props ,
146- "dynamics" : "tau_tr * dz/dt ~ -z + inputs" ,
157+ "dynamics" : "tau_tr * dz/dt ~ -z + inputs * a_delta (full convolution trace); "
158+ "tau_tr * dz/dt ~ -z + inputs * (a_delta - z/n_nearest_neighbors) (near trace)" ,
147159 "hyperparameters" : hyperparams }
148160 return info
149161
0 commit comments