Skip to content

Commit f4d47d4

Browse files
author
Alexander Ororbia
committed
made some corrections to bern err-cell and heb syn
1 parent 8a5bc68 commit f4d47d4

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

ngclearn/components/neurons/graded/bernoulliErrorCell.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,14 @@ def _reset(batch_size, shape): ## reset core components/statistics
102102
_shape = (batch_size, shape[0])
103103
if len(shape) > 1:
104104
_shape = (batch_size, shape[0], shape[1], shape[2])
105-
restVals = jnp.zeros(_shape)
105+
restVals = jnp.zeros(_shape) ## "rest"/reset values
106106
dp = restVals
107107
dtarget = restVals
108108
target = restVals
109109
p = restVals
110-
modulator = mu + 1.
111-
L = 0. #jnp.zeros((1, 1))
112-
mask = jnp.ones(_shape)
110+
modulator = restVals + 1. ## reset modulator signal
111+
L = 0. #jnp.zeros((1, 1)) ## rest loss
112+
mask = jnp.ones(_shape) ## reset mask
113113
return dp, dtarget, target, p, modulator, L, mask
114114

115115
@resolver(_reset)

ngclearn/components/synapses/hebbian/hebbianSynapse.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,13 @@ def _calc_update(pre, post, W, w_bound, is_nonnegative=True, signVal=1.,
4848
dW = dW * (w_bound - jnp.abs(W))
4949

5050
if prior_type == "l2" or prior_type == "ridge":
51-
dW_reg = W
51+
dW_reg = -W
5252
if prior_type == "l1" or prior_type == "lasso":
53-
dW_reg = jnp.sign(W)
53+
dW_reg = -jnp.sign(W)
5454
if prior_type == "l1l2" or prior_type == "elastic_net":
5555
l1_ratio = prior_lmbda[1]
5656
prior_lmbda = prior_lmbda[0]
57-
dW_reg = jnp.sign(W) * l1_ratio + W * (1-l1_ratio)/2
57+
dW_reg = -jnp.sign(W) * l1_ratio - W * (1-l1_ratio)/2
5858

5959
dW = dW + prior_lmbda * dW_reg
6060
return dW * signVal, db * signVal

0 commit comments

Comments
 (0)