Skip to content

Commit e055d95

Browse files
author
Alexander Ororbia
committed
cleaned up gauss/laplace error cells
1 parent 7f3e7c8 commit e055d95

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

ngclearn/components/neurons/graded/gaussianErrorCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def _advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian
8181
dmu = dmu * modulator * mask ## not sure how mask will apply to a full covariance...
8282
dtarget = dtarget * modulator * mask
8383
mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
84-
return dmu, dtarget, dSigma, L[0, 0], mask
84+
return dmu, dtarget, dSigma, jnp.squeeze(L), mask
8585

8686
@resolver(_advance_state)
8787
def advance_state(self, dmu, dtarget, dSigma, L, mask):

ngclearn/components/neurons/graded/laplacianErrorCell.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(self, name, n_units, batch_size=1, scale=1., shape=None, **kwargs):
4343
else:
4444
_shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor
4545
scale_shape = (1, 1)
46-
if not isinstance(scale, float):
46+
if not isinstance(scale, float) and not isinstance(sigma, int):
4747
scale_shape = jnp.array(scale).shape
4848
self.scale_shape = scale_shape
4949
## Layer Size setup
@@ -83,7 +83,7 @@ def _advance_state(dt, shift, target, Scale, modulator, mask): ## compute Laplac
8383
dshift = dshift * modulator * mask
8484
dtarget = dtarget * modulator * mask
8585
mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
86-
return dshift, dtarget, dScale, L[0, 0], mask
86+
return dshift, dtarget, dScale, jnp.squeeze(L), mask
8787

8888
@resolver(_advance_state)
8989
def advance_state(self, dshift, dtarget, dScale, L, mask):

0 commit comments

Comments
 (0)