Skip to content

Commit cf7ee73

Browse files
author
Alexander Ororbia
committed
minor cleanup/patches to rate-cell/lif/hebb-syn/trace-stdp and dim_reduce method
1 parent a09a3c4 commit cf7ee73

File tree

7 files changed

+61
-29
lines changed

7 files changed

+61
-29
lines changed

ngclearn/components/neurons/graded/rateCell.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ class RateCell(JaxComponent): ## Rate-coded/real-valued cell
145145
146146
act_fx: string name of activation function/nonlinearity to use
147147
148+
output_scale: factor to multiply output of nonlinearity of this cell by (Default: 1.)
149+
148150
integration_type: type of integration to use for this cell's dynamics;
149151
current supported forms include "euler" (Euler/RK-1 integration)
150152
and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")
@@ -157,12 +159,13 @@ class RateCell(JaxComponent): ## Rate-coded/real-valued cell
157159
"""
158160

159161
# Define Functions
160-
def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identity",
161-
threshold=("none", 0.), integration_type="euler",
162-
batch_size=1, resist_scale=1., shape=None, is_stateful=True, **kwargs):
162+
def __init__(
163+
self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identity", output_scale=1., threshold=("none", 0.),
164+
integration_type="euler", batch_size=1, resist_scale=1., shape=None, is_stateful=True, **kwargs):
163165
super().__init__(name, **kwargs)
164166

165167
## membrane parameter setup (affects ODE integration)
168+
self.output_scale = output_scale
166169
self.tau_m = tau_m ## membrane time constant -- setting to 0 triggers "stateless" mode
167170
self.is_stateful = is_stateful
168171
if isinstance(tau_m, float):
@@ -211,8 +214,9 @@ def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identit
211214

212215
@transition(output_compartments=["j", "j_td", "z", "zF"])
213216
@staticmethod
214-
def advance_state(dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType,
215-
resist_scale, thresholdType, thr_lmbda, is_stateful, j, j_td, z):
217+
def advance_state(
218+
dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType, resist_scale, thresholdType, thr_lmbda, is_stateful,
219+
output_scale, j, j_td, z):
216220
#if tau_m > 0.:
217221
if is_stateful:
218222
### run a step of integration over neuronal dynamics
@@ -231,12 +235,12 @@ def advance_state(dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType,
231235
elif thresholdType == "cauchy_threshold":
232236
tmp_z = threshold_cauchy(tmp_z, thr_lmbda)
233237
z = tmp_z ## pre-activation function value(s)
234-
zF = fx(z) ## post-activation function value(s)
238+
zF = fx(z) * output_scale ## post-activation function value(s)
235239
else:
236240
## run in "stateless" mode (when no membrane time constant provided)
237241
j_total = j + j_td
238242
z = _run_cell_stateless(j_total)
239-
zF = fx(z)
243+
zF = fx(z) * output_scale
240244
return j, j_td, z, zF
241245

242246
@transition(output_compartments=["j", "j_td", "z", "zF"])

ngclearn/components/neurons/spiking/LIFCell.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -123,15 +123,14 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell
123123
"arctan" (arc-tangent estimator), and "secant_lif" (the
124124
LIF-specialized secant estimator)
125125
126-
lower_clamp_voltage: if True, this will ensure voltage never is below
127-
the value of `v_rest` (default: True)
126+
v_min: minimum voltage to clamp dynamics to (Default: None)
128127
""" ## batch_size arg?
129128

130129
@deprecate_args(thr_jitter=None, v_decay="conduct_leak")
131130
def __init__(
132131
self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., v_reset=-60., conduct_leak=1., tau_theta=1e7,
133132
theta_plus=0.05, refract_time=5., one_spike=False, integration_type="euler", surrogate_type="straight_through",
134-
lower_clamp_voltage=True, **kwargs
133+
v_min=None, max_one_spike=False, **kwargs
135134
):
136135
super().__init__(name, **kwargs)
137136

@@ -143,7 +142,8 @@ def __init__(
143142
self.tau_m = tau_m ## membrane time constant
144143
self.resist_m = resist_m ## resistance value
145144
self.one_spike = one_spike ## True => constrains system to simulate 1 spike per time step
146-
self.lower_clamp_voltage = lower_clamp_voltage ## True ==> ensures voltage is never < v_rest
145+
self.max_one_spike = max_one_spike
146+
self.v_min = v_min ## ensures voltage is never < v_min
147147

148148
self.v_rest = v_rest #-65. # mV
149149
self.v_reset = v_reset # -60. # -65. # mV (milli-volts)
@@ -189,11 +189,11 @@ def __init__(
189189
@transition(output_compartments=["v", "s", "s_raw", "rfr", "thr_theta", "tols", "key", "surrogate"])
190190
@staticmethod
191191
def advance_state(
192-
t, dt, tau_m, resist_m, v_rest, v_reset, g_L, refract_T, thr, tau_theta, theta_plus,
193-
one_spike, lower_clamp_voltage, intgFlag, d_spike_fx, key, j, v, rfr, thr_theta, tols
192+
t, dt, tau_m, resist_m, v_rest, v_reset, g_L, refract_T, thr, tau_theta, theta_plus, one_spike, max_one_spike,
193+
v_min, intgFlag, d_spike_fx, key, j, v, rfr, thr_theta, tols
194194
):
195195
skey = None ## this is an empty dkey if single_spike mode turned off
196-
if one_spike:
196+
if one_spike and not max_one_spike:
197197
key, skey = random.split(key, 2)
198198
## run one integration step for neuronal dynamics
199199
j = j * resist_m
@@ -209,6 +209,7 @@ def advance_state(
209209
_, _v = step_euler(0., v, _dfv, dt, v_params)
210210
## obtain action potentials/spikes/pulses
211211
s = (_v > _v_thr) * 1.
212+
v_prespike = v
212213
## update refractory variables
213214
_rfr = (rfr + dt) * (1. - s)
214215
## perform hyper-polarization of neuronal cells
@@ -223,6 +224,9 @@ def advance_state(
223224
rS = nn.one_hot(jnp.argmax(rS, axis=1), num_classes=s.shape[1],
224225
dtype=jnp.float32)
225226
s = s * (1. - m_switch) + rS * m_switch
227+
if max_one_spike:
228+
rS = nn.one_hot(jnp.argmax(v_prespike, axis=1), num_classes=s.shape[1], dtype=jnp.float32) ## get max-volt spike
229+
s = s * rS ## mask out non-max volt spikes
226230
############################################################################
227231
raw_spikes = raw_s
228232
v = _v
@@ -234,8 +238,8 @@ def advance_state(
234238
thr_theta = _update_theta(dt, thr_theta, raw_spikes, tau_theta, theta_plus)
235239
## update tols
236240
tols = (1. - s) * tols + (s * t)
237-
if lower_clamp_voltage: ## ensure voltage never < v_rest
238-
v = jnp.maximum(v, v_rest)
241+
if v_min is not None: ## ensures voltage never < v_rest
242+
v = jnp.maximum(v, v_min)
239243
return v, s, raw_spikes, rfr, thr_theta, tols, key, surrogate
240244

241245
@transition(output_compartments=["j", "v", "s", "s_raw", "rfr", "tols", "surrogate"])

ngclearn/components/synapses/denseSynapse.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@ class DenseSynapse(JaxComponent): ## base dense synaptic cable
4141
"""
4242

4343
# Define Functions
44-
def __init__(self, name, shape, weight_init=None, bias_init=None,
45-
resist_scale=1., p_conn=1., batch_size=1, **kwargs):
44+
def __init__(
45+
self, name, shape, weight_init=None, bias_init=None, resist_scale=1.,
46+
p_conn=1., batch_size=1, **kwargs
47+
):
4648
super().__init__(name, **kwargs)
4749

4850
self.batch_size = batch_size
@@ -63,7 +65,7 @@ def __init__(self, name, shape, weight_init=None, bias_init=None,
6365
mask = random.bernoulli(subkeys[1], p=p_conn, shape=shape)
6466
weights = weights * mask ## sparsify matrix
6567

66-
self.batch_size = 1
68+
self.batch_size = batch_size #1
6769
## Compartment setup
6870
preVals = jnp.zeros((self.batch_size, shape[0]))
6971
postVals = jnp.zeros((self.batch_size, shape[1]))

ngclearn/components/synapses/hebbian/hebbianSynapse.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,11 @@ class HebbianSynapse(DenseSynapse):
162162

163163
# Define Functions
164164
@deprecate_args(_rebind=False, w_decay='prior')
165-
def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None,
166-
w_bound=1., is_nonnegative=False, prior=("constant", 0.), w_decay=0., sign_value=1.,
167-
optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1.,
168-
resist_scale=1., batch_size=1, **kwargs):
165+
def __init__(
166+
self, name, shape, eta=0., weight_init=None, bias_init=None, w_bound=1., is_nonnegative=False,
167+
prior=("constant", 0.), w_decay=0., sign_value=1., optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1.,
168+
resist_scale=1., batch_size=1, **kwargs
169+
):
169170
super().__init__(name, shape, weight_init, bias_init, resist_scale,
170171
p_conn, batch_size=batch_size, **kwargs)
171172

ngclearn/components/synapses/hebbian/traceSTDPSynapse.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,20 +68,25 @@ class TraceSTDPSynapse(DenseSynapse): # power-law / trace-based STDP
6868
# Define Functions
6969
def __init__(
7070
self, name, shape, A_plus, A_minus, eta=1., mu=0., pretrace_target=0., weight_init=None, resist_scale=1.,
71-
p_conn=1., w_bound=1., batch_size=1, **kwargs
71+
p_conn=1., w_bound=1., tau_w=0., weight_mask=None, batch_size=1, **kwargs
7272
):
7373
super().__init__(name, shape, weight_init, None, resist_scale,
7474
p_conn, batch_size=batch_size, **kwargs)
7575

7676
## Synaptic hyper-parameters
7777
self.shape = shape ## shape of synaptic efficacy matrix
78+
self.tau_w = tau_w
7879
self.mu = mu ## controls power-scaling of STDP rule
7980
self.preTrace_target = pretrace_target ## target (pre-synaptic) trace activity value # 0.7
8081
self.Aplus = A_plus ## LTP strength
8182
self.Aminus = A_minus ## LTD strength
8283
self.Rscale = resist_scale ## post-transformation scale factor
8384
self.w_bound = w_bound #1. ## soft weight constraint
8485
self.w_eps = 0. ## w_eps = 0.01
86+
self.weight_mask = weight_mask
87+
if self.weight_mask is None:
88+
self.weight_mask = jnp.ones((1, 1))
89+
self.weights.set(self.weights.value * self.weight_mask)
8590

8691
## Compartment setup
8792
preVals = jnp.zeros((self.batch_size, shape[0]))
@@ -93,6 +98,12 @@ def __init__(
9398
self.dWeights = Compartment(self.weights.value * 0)
9499
self.eta = Compartment(jnp.ones((1, 1)) * eta) ## global learning rate
95100

101+
#@transition(output_compartments=["outputs"])
102+
#@staticmethod
103+
#def advance_state(Rscale, inputs, weights, biases, weight_mask):
104+
# outputs = (jnp.matmul(inputs, weights * weight_mask) * Rscale) + biases
105+
# return outputs
106+
96107
@staticmethod
97108
def _compute_update(
98109
dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
@@ -126,16 +137,23 @@ def _compute_update(
126137
@transition(output_compartments=["weights", "dWeights"])
127138
@staticmethod
128139
def evolve(
129-
dt, w_bound, w_eps, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights, eta
140+
dt, w_bound, w_eps, preTrace_target, mu, Aplus, Aminus, tau_w, preSpike, postSpike, preTrace,
141+
postTrace, weights, eta, weight_mask
130142
):
143+
#_wm = weight_mask #
144+
_wm = (weight_mask != 0.)
131145
dWeights = TraceSTDPSynapse._compute_update(
132146
dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
133147
)
134148
## do a gradient ascent update/shift
135-
weights = weights + dWeights * eta
149+
decayTerm = 0.
150+
if tau_w > 0.:
151+
decayTerm = weights / tau_w
152+
weights = weights + (dWeights * eta) - decayTerm #weight_mask * eta)
136153
## enforce non-negativity
137154
#w_eps = 0. # 0.01 # 0.001
138155
weights = jnp.clip(weights, w_eps, w_bound - w_eps) # jnp.abs(w_bound))
156+
weights = weights * _wm # weight_mask
139157
return weights, dWeights
140158

141159
@transition(output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace", "dWeights"])

ngclearn/components/synapses/staticSynapse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,4 @@ class StaticSynapse(DenseSynapse):
2727
this to < 1 and > 0. will result in a sparser synaptic structure
2828
(lower values yield sparse structure)
2929
"""
30-
pass
30+
pass

ngclearn/utils/viz/dim_reduce.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def extract_pca_latents(vectors): ## PCA mapping routine
2929
z_2D = vectors
3030
return z_2D
3131

32-
def extract_tsne_latents(vectors, perplexity=30, n_pca_comp=32): ## tSNE mapping routine
32+
def extract_tsne_latents(vectors, perplexity=30, n_pca_comp=32, batch_size=500): ## tSNE mapping routine
3333
"""
3434
Projects collection of K vectors (stored in a matrix) to a two-dimensional (2D)
3535
visualization space via the t-distributed stochastic neighbor embedding
@@ -42,10 +42,13 @@ def extract_tsne_latents(vectors, perplexity=30, n_pca_comp=32): ## tSNE mapping
4242
4343
perplexity: the perplexity control factor for t-SNE (Default: 30)
4444
45+
batch_size: number of sampled embedding vectors to use per iteration
46+
of online internal PCA
47+
4548
Returns:
4649
a matrix (K x 2) of projected vectors (to 2D space)
4750
"""
48-
batch_size = 500 #50
51+
#batch_size = 500 #50
4952
z_dim = vectors.shape[1]
5053
z_2D = None
5154
if z_dim != 2:

0 commit comments

Comments
 (0)