Skip to content

Commit 1c5164e

Browse files
author
Alexander Ororbia
committed
cleaned up bern-cell, hebb-syn
1 parent 0ffb3d1 commit 1c5164e

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

ngclearn/components/neurons/graded/bernoulliErrorCell.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,11 @@ def _advance_state(dt, p, target, modulator, mask, input_logits): ## compute Ber
8080
dL_dp = x - _p ## d(Bern LL)/dp where _p = sigmoid(p)
8181
else:
8282
dL_dp = x/(_p) - one_min_x/one_min_p ## d(Bern LL)/dp
83-
dL_dx = log_p - log_one_min_p ## d(Bern LL)/dx
84-
dp = dL_dp #* d_sigmoid(p)
85-
if input_logits:
86-
dp = dp * d_sigmoid(p)
83+
dL_dp = dL_dp * d_sigmoid(p)
84+
dL_dx = (log_p - log_one_min_p) ## d(Bern LL)/dx
85+
dp = dL_dp
8786

88-
dp = dL_dp * modulator * mask ## NOTE: how does mask apply to a multivariate Bernoulli?
87+
dp = dp * modulator * mask ## NOTE: how does mask apply to a multivariate Bernoulli?
8988
dtarget = dL_dx * modulator * mask
9089
mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
9190
return dp, dtarget, jnp.squeeze(L), mask

ngclearn/components/synapses/hebbian/hebbianSynapse.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,10 @@ def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None,
178178
self.shape = shape
179179
self.Rscale = resist_scale
180180
self.prior_type = prior_type
181+
if self.prior_type.lower() == "gaussian":
182+
self.prior_type = "ridge"
183+
elif self.prior_type.lower() == "laplacian":
184+
self.prior_type = "lasso"
181185
self.prior_lmbda = prior_lmbda
182186
self.w_bound = w_bound
183187
self.pre_wght = pre_wght

0 commit comments

Comments
 (0)