
Commit b9227f0

add-sigma-to-gaussianErrorCell (#97)
* add sigma to GaussianErrorCell: introduce a fixed (non-updating) scalar variance for Gaussian errors
* Update gaussianErrorCell.py
1 parent 0d720e1 commit b9227f0

File tree: 1 file changed (+21, -11 lines)

ngclearn/components/neurons/graded/gaussianErrorCell.py

Lines changed: 21 additions & 11 deletions
@@ -3,7 +3,7 @@
 from jax import numpy as jnp, jit
 from ngclearn.utils import tensorstats
 
-def _run_cell(dt, targ, mu):
+def _run_cell(dt, targ, mu, sigma):
     """
     Moves cell dynamics one step forward.
 
@@ -14,13 +14,15 @@ def _run_cell(dt, targ, mu):
 
         mu: prediction value
 
+        sigma: prediction variance
+
     Returns:
         derivative w.r.t. mean "dmu", derivative w.r.t. target dtarg, local loss
     """
-    return _run_gaussian_cell(dt, targ, mu)
+    return _run_gaussian_cell(dt, targ, mu, sigma)
 
 @jit
-def _run_gaussian_cell(dt, targ, mu):
+def _run_gaussian_cell(dt, targ, mu, sigma):
     """
     Moves Gaussian cell dynamics one step forward. Specifically, this
     routine emulates the error unit behavior of the local cost functional:
@@ -35,13 +37,16 @@ def _run_gaussian_cell(dt, targ, mu):
 
         mu: prediction value
 
+        sigma: prediction variance
+
     Returns:
         derivative w.r.t. mean "dmu", derivative w.r.t. target dtarg, loss
     """
-    dmu = (targ - mu) # e (error unit)
+    dmu = (targ - mu)/sigma # e (error unit)
     dtarg = -dmu # reverse of e
-    L = -jnp.sum(jnp.square(dmu)) * 0.5
-    return dmu, dtarg, L
+    dsigma = 1. # no derivative is calculated at this time for sigma
+    L = -jnp.sum(jnp.square(dmu)) * 0.5 / sigma
+    return dmu, dtarg, dsigma, L
 
 class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
     """
@@ -66,8 +71,10 @@ class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
         tau_m: (Unused -- currently cell is a fixed-point model)
 
         leakRate: (Unused -- currently cell is a fixed-point model)
+
+        sigma: prediction covariance matrix (𝚺) in multivariate gaussian distribution
     """
-    def __init__(self, name, n_units, batch_size=1, shape=None, **kwargs):
+    def __init__(self, name, n_units, batch_size=1, sigma=1, shape=None, **kwargs):
         super().__init__(name, **kwargs)
 
         ## Layer Size Setup
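A minimal usage sketch of the new constructor argument follows (not part of this commit): only the parameters visible in the signature above are exercised, and depending on the ngclearn/ngcsimlib version the component may need to be built inside a simulation context before it can be advanced.

## Hedged sketch, not from this commit: instantiate the cell with a fixed,
## non-learnable scalar variance. Wiring it into a full ngclearn circuit
## (context, cables, advance/evolve commands) is omitted here.
from ngclearn.components.neurons.graded.gaussianErrorCell import GaussianErrorCell

err = GaussianErrorCell(name="err0", n_units=10, batch_size=1, sigma=2.0)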
@@ -79,6 +86,7 @@ def __init__(self, name, n_units, batch_size=1, shape=None, **kwargs):
         self.shape = shape
         self.n_units = n_units
         self.batch_size = batch_size
+        self.sigma = sigma
 
         ## Convolution shape setup
         self.width = self.height = n_units
@@ -94,11 +102,12 @@ def __init__(self, name, n_units, batch_size=1, shape=None, **kwargs):
         self.mask = Compartment(restVals + 1.0)
 
     @staticmethod
-    def _advance_state(dt, mu, dmu, target, dtarget, modulator, mask):
+    def _advance_state(dt, mu, dmu, target, dtarget, sigma, modulator, mask):
         ## compute Gaussian error cell output
-        dmu, dtarget, L = _run_cell(dt, target * mask, mu * mask)
+        dmu, dtarget, dsigma, L = _run_cell(dt, target * mask, mu * mask, sigma)
         dmu = dmu * modulator * mask
         dtarget = dtarget * modulator * mask
+        dsigma = dsigma * 0 + 1. # no derivative is calculated at this time for sigma
         mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
         return dmu, dtarget, L, mask

@@ -153,11 +162,12 @@ def help(cls): ## component help function
         }
         hyperparams = {
             "n_units": "Number of neuronal cells to model in this layer",
-            "batch_size": "Batch size dimension of this component"
+            "batch_size": "Batch size dimension of this component",
+            "sigma": "External input variance value (currently fixed and not learnable)"
         }
         info = {cls.__name__: properties,
                 "compartments": compartment_props,
-                "dynamics": "Gaussian(x=target; mu, sigma=1)",
+                "dynamics": "Gaussian(x=target; mu, sigma)",
                 "hyperparameters": hyperparams}
         return info
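To sanity-check the behavioral change end to end, here is a small, self-contained JAX sketch (illustrative only, not part of the commit) that re-implements the updated `_run_gaussian_cell` body from the diff and shows how `sigma` scales the error signal and the local loss:

## Standalone, hedged reproduction of the updated error-unit math; names mirror
## the diff, but this snippet is illustrative and lives outside ngclearn.
from jax import numpy as jnp, jit

@jit
def _run_gaussian_cell(dt, targ, mu, sigma):
    dmu = (targ - mu) / sigma          ## e (error unit), now precision-weighted
    dtarg = -dmu                       ## reverse of e
    dsigma = 1.                        ## no derivative is calculated for sigma
    L = -jnp.sum(jnp.square(dmu)) * 0.5 / sigma
    return dmu, dtarg, dsigma, L

targ = jnp.asarray([[1.0, 0.5, -0.25]])
mu = jnp.asarray([[0.8, 0.7, 0.0]])
for sigma in (1.0, 2.0):
    dmu, dtarg, dsigma, L = _run_gaussian_cell(1., targ, mu, sigma)
    print(sigma, dmu, float(L))        ## sigma = 1.0 reproduces the old output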
