Skip to content

Commit 276fbbb

Browse files
author
Alexander Ororbia
committed
cleaned up trace
1 parent 0419997 commit 276fbbb

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

ngclearn/components/other/varTrace.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,15 @@ class VarTrace(JaxComponent): ## low-pass filter
7171
2) `'exp'` = exponential trace filter, i.e., decay = exp(-dt/tau_tr) * x_tr;
7272
3) `'step'` = step trace, i.e., decay = 0 (a pulse applied upon input value)
7373
74+
n_nearest_spikes: (k) if k > 0, this makes the trace act like a nearest-neighbor trace,
75+
i.e., k = 1 yields the 1-nearest (neighbor) trace (Default: 0)
76+
7477
batch_size: batch size dimension of this cell (Default: 1)
7578
"""
7679

7780
# Define Functions
7881
def __init__(self, name, n_units, tau_tr, a_delta, P_scale=1., gamma_tr=1, decay_type="exp",
79-
batch_size=1, **kwargs):
82+
n_nearest_spikes=0, batch_size=1, **kwargs):
8083
super().__init__(name, **kwargs)
8184

8285
## Trace control coefficients
@@ -85,6 +88,7 @@ def __init__(self, name, n_units, tau_tr, a_delta, P_scale=1., gamma_tr=1, decay
8588
self.P_scale = P_scale ## trace scale if non-additive trace to be used
8689
self.gamma_tr = gamma_tr
8790
self.decay_type = decay_type ## lin --> linear decay; exp --> exponential decay
91+
self.n_nearest_spikes = n_nearest_spikes
8892

8993
## Layer Size Setup
9094
self.batch_size = batch_size
@@ -97,17 +101,22 @@ def __init__(self, name, n_units, tau_tr, a_delta, P_scale=1., gamma_tr=1, decay
97101

98102
@transition(output_compartments=["outputs", "trace"])
99103
@staticmethod
100-
def advance_state(dt, decay_type, tau_tr, a_delta, P_scale, gamma_tr, inputs, trace):
104+
def advance_state(
105+
dt, decay_type, tau_tr, a_delta, P_scale, gamma_tr, inputs, trace, n_nearest_spikes
106+
):
101107
decayFactor = 0.
102108
if "exp" in decay_type:
103109
decayFactor = jnp.exp(-dt/tau_tr)
104110
elif "lin" in decay_type:
105111
decayFactor = (1. - dt/tau_tr)
106112
_x_tr = gamma_tr * trace * decayFactor
107-
if a_delta > 0.:
108-
_x_tr = _x_tr + inputs * a_delta
113+
if n_nearest_spikes > 0: ## run k-nearest neighbor trace
114+
_x_tr = _x_tr + inputs * (a_delta - (trace/n_nearest_spikes))
109115
else:
110-
_x_tr = _x_tr * (1. - inputs) + inputs * P_scale
116+
if a_delta > 0.: ## run full convolution trace
117+
_x_tr = _x_tr + inputs * a_delta
118+
else: ## run simple max-clamped trace
119+
_x_tr = _x_tr * (1. - inputs) + inputs * P_scale
111120
trace = _x_tr
112121
return trace, trace
113122

@@ -138,12 +147,15 @@ def help(cls): ## component help function
138147
"tau_tr": "Trace/filter time constant",
139148
"a_delta": "Increment to apply to trace (if not set to 0); "
140149
"otherwise, traces clamp to 1 and then decay",
150+
"P_scale": "Max value to snap trace to if a max-clamp trace is triggered/configured",
141151
"decay_type": "Indicator of what type of decay dynamics to use "
142-
"as filter is updated at time t"
152+
"as filter is updated at time t",
153+
"n_nearest_neighbors": "Number of nearest pulses to affect/increment trace (if > 0)"
143154
}
144155
info = {cls.__name__: properties,
145156
"compartments": compartment_props,
146-
"dynamics": "tau_tr * dz/dt ~ -z + inputs",
157+
"dynamics": "tau_tr * dz/dt ~ -z + inputs * a_delta (full convolution trace); "
158+
"tau_tr * dz/dt ~ -z + inputs * (a_delta - z/n_nearest_neighbors) (near trace)",
147159
"hyperparameters": hyperparams}
148160
return info
149161

0 commit comments

Comments
 (0)