Skip to content

Commit a8b156a

Browse files
author
Alexander Ororbia
committed
minor revisons/updates to hebb/dense syn, metric utils
1 parent 710d8b1 commit a8b156a

File tree

4 files changed

+26
-13
lines changed

4 files changed

+26
-13
lines changed

ngclearn/components/neurons/graded/rateCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def advance_state(self, dt):
226226
## self.pressure <-- "top-down" expectation / contextual pressure
227227
## self.current <-- "bottom-up" data-dependent signal
228228
dfx_val = self.dfx(z)
229-
j = _modulate(j, dfx_val)
229+
j = _modulate(j, dfx_val) ## TODO: make this optional (for NGC circuit dynamics)
230230
j = j * self.resist_scale
231231
tmp_z = _run_cell(
232232
dt, j, j_td, z, self.tau_m, leak_gamma=self.priorLeakRate, integType=self.intgFlag,

ngclearn/components/synapses/denseSynapse.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,20 @@ class DenseSynapse(JaxComponent): ## base dense synaptic cable
3636
p_conn: probability of a connection existing (default: 1.); setting
3737
this to < 1 and > 0. will result in a sparser synaptic structure
3838
(lower values yield sparse structure)
39+
40+
mask: if non-None, a (multiplicative) mask is applied to this synaptic weight matrix
3941
"""
4042

4143
def __init__(
42-
self, name, shape, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., batch_size=1, **kwargs
44+
self, name, shape, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., mask=None, batch_size=1,
45+
**kwargs
4346
):
4447
super().__init__(name, **kwargs)
4548

4649
self.batch_size = batch_size
50+
self.mask = 1.
51+
if mask is not None:
52+
self.mask = mask
4753

4854
## Synapse meta-parameters
4955
self.shape = shape
@@ -79,7 +85,9 @@ def __init__(
7985

8086
@compilable
8187
def advance_state(self):
82-
self.outputs.set((jnp.matmul(self.inputs.get(), self.weights.get()) * self.resist_scale) + self.biases.get())
88+
weights = self.weights.get()
89+
weights = weights * self.mask
90+
self.outputs.set((jnp.matmul(self.inputs.get(), weights) * self.resist_scale) + self.biases.get())
8391

8492
@compilable
8593
def reset(self):

ngclearn/components/synapses/hebbian/hebbianSynapse.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def _enforce_constraints(W, w_bound, is_nonnegative=True):
8686
"""
8787
_W = W
8888
if w_bound > 0.:
89-
if is_nonnegative == True:
89+
if is_nonnegative:
9090
_W = jnp.clip(_W, 0., w_bound)
9191
else:
9292
_W = jnp.clip(_W, -w_bound, w_bound)
@@ -173,7 +173,10 @@ def __init__(
173173
prior=("constant", 0.), w_decay=0., sign_value=1., optim_type="sgd", pre_wght=1., post_wght=1.,
174174
p_conn=1., resist_scale=1., batch_size=1, **kwargs
175175
):
176-
super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, batch_size=batch_size, **kwargs)
176+
super().__init__(
177+
name, shape=shape, weight_init=weight_init, bias_init=bias_init, resist_scale=resist_scale, p_conn=p_conn,
178+
batch_size=batch_size, **kwargs
179+
)
177180

178181
if w_decay > 0.:
179182
prior = ('l2', w_decay)
@@ -243,19 +246,20 @@ def calc_update(self):
243246
post = self.post.get()
244247
weights = self.weights.get()
245248
biases = self.biases.get()
246-
opt_params = self.opt_params.get()
249+
#opt_params = self.opt_params.get()
247250

248251
## calculate synaptic update values
249252
dWeights, dBiases = HebbianSynapse._compute_update(
250-
self.w_bound, self.is_nonnegative, self.sign_value, self.prior_type, self.prior_lmbda, self.pre_wght, self.post_wght,
251-
pre, post, weights
253+
self.w_bound, self.is_nonnegative, self.sign_value, self.prior_type, self.prior_lmbda, self.pre_wght,
254+
self.post_wght, pre, post, weights
252255
)
253256

254257
self.dWeights.set(dWeights)
255258
self.dBiases.set(dBiases)
259+
#self.opt_params.set(opt_params)
256260

257261
@compilable
258-
def evolve(self):
262+
def evolve(self, dt):
259263
# Get the variables
260264
pre = self.pre.get()
261265
post = self.post.get()
@@ -268,6 +272,7 @@ def evolve(self):
268272
self.w_bound, self.is_nonnegative, self.sign_value, self.prior_type, self.prior_lmbda, self.pre_wght, self.post_wght,
269273
pre, post, weights
270274
)
275+
271276
## conduct a step of optimization - get newly evolved synaptic weight value matrix
272277
if self.bias_init != None:
273278
opt_params, [weights, biases] = self.opt(opt_params, [weights, biases], [dWeights, dBiases])

ngclearn/utils/metric_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def measure_CatNLL(p, x, offset=1e-7, preserve_batch=False):
308308
nll = jnp.mean(nll)
309309
return nll #tf.reduce_mean(nll)
310310

311-
@jit
311+
@partial(jit, static_argnums=[2])
312312
def measure_RMSE(mu, x, preserve_batch=False):
313313
"""
314314
Measures root mean squared error (RMSE). Note: If batch is preserved, this returns a column vector where each
@@ -328,7 +328,7 @@ def measure_RMSE(mu, x, preserve_batch=False):
328328
mse = measure_MSE(mu, x, preserve_batch=preserve_batch)
329329
return jnp.sqrt(mse) ## sqrt(MSE) is the root-mean-squared-error
330330

331-
@jit
331+
@partial(jit, static_argnums=[2])
332332
def measure_MSE(mu, x, preserve_batch=False):
333333
"""
334334
Measures mean squared error (MSE), or the negative Gaussian log likelihood with variance of 1.0. Note: If batch
@@ -352,7 +352,7 @@ def measure_MSE(mu, x, preserve_batch=False):
352352
mse = jnp.mean(mse) # this is proper mse
353353
return mse
354354

355-
@jit
355+
@partial(jit, static_argnums=[2])
356356
def measure_MAE(shift, x, preserve_batch=False):
357357
"""
358358
Measures mean absolute error (MAE), or the negative Laplacian log likelihood with scale of 1.0. Note: If batch
@@ -376,7 +376,7 @@ def measure_MAE(shift, x, preserve_batch=False):
376376
mae = jnp.mean(mae) # this is proper mae
377377
return mae
378378

379-
@jit
379+
@partial(jit, static_argnums=[3])
380380
def measure_BCE(p, x, offset=1e-7, preserve_batch=False): #1e-10
381381
"""
382382
Calculates the negative Bernoulli log likelihood or binary cross entropy (BCE). Note: If batch is preserved,

0 commit comments

Comments
 (0)