Skip to content

Commit 0ffb3d1

Browse files
author
Alexander Ororbia
committed
made some corrections to bern err-cell and heb syn
1 parent e408d9b commit 0ffb3d1

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

ngclearn/components/synapses/hebbian/hebbianSynapse.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,23 +40,24 @@ def _calc_update(pre, post, W, w_bound, is_nonnegative=True, signVal=1.,
4040
"""
4141
_pre = pre * pre_wght
4242
_post = post * post_wght
43-
dW = jnp.matmul(_pre.T, _post)
44-
db = jnp.sum(_post, axis=0, keepdims=True)
45-
dW_reg = 0.
43+
dW = jnp.matmul(_pre.T, _post) ## calc Hebbian adjustment
44+
db = jnp.sum(_post, axis=0, keepdims=True) ## calc Hebbian adjustment to bias/base-rates
45+
dW_reg = 0. ## synaptic decay term
4646

47-
if w_bound > 0.:
47+
if w_bound > 0.: ## induce any synaptic value bounding
4848
dW = dW * (w_bound - jnp.abs(W))
49-
49+
## apply synaptic priors
5050
if prior_type == "l2" or prior_type == "ridge":
51-
dW_reg = -W
51+
dW_reg = -W * prior_lmbda
5252
if prior_type == "l1" or prior_type == "lasso":
53-
dW_reg = -jnp.sign(W)
53+
dW_reg = -jnp.sign(W) * prior_lmbda
5454
if prior_type == "l1l2" or prior_type == "elastic_net":
5555
l1_ratio = prior_lmbda[1]
56-
prior_lmbda = prior_lmbda[0]
56+
prior_scale = prior_lmbda[0]
5757
dW_reg = -jnp.sign(W) * l1_ratio - W * (1-l1_ratio)/2
58-
59-
dW = dW + prior_lmbda * dW_reg
58+
dW_reg = dW_reg * prior_scale
59+
## produce final update/adjustment
60+
dW = dW + dW_reg
6061
return dW * signVal, db * signVal
6162

6263
@partial(jit, static_argnums=[1,2])

0 commit comments

Comments
 (0)