33from jax import numpy as jnp , jit
44from ngclearn .utils import tensorstats
55
6- def _run_cell (dt , targ , mu ):
6+ def _run_cell (dt , targ , mu , sigma ):
77 """
88 Moves cell dynamics one step forward.
99
@@ -14,13 +14,15 @@ def _run_cell(dt, targ, mu):
1414
1515 mu: prediction value
1616
17+ sigma: prediction variance
18+
1719 Returns:
1820 derivative w.r.t. mean "dmu", derivative w.r.t. target dtarg, local loss
1921 """
20- return _run_gaussian_cell (dt , targ , mu )
22+ return _run_gaussian_cell (dt , targ , mu , sigma )
2123
2224@jit
23- def _run_gaussian_cell (dt , targ , mu ):
25+ def _run_gaussian_cell (dt , targ , mu , sigma ):
2426 """
2527 Moves Gaussian cell dynamics one step forward. Specifically, this
2628 routine emulates the error unit behavior of the local cost functional:
@@ -35,13 +37,16 @@ def _run_gaussian_cell(dt, targ, mu):
3537
3638 mu: prediction value
3739
40+ sigma: prediction variance
41+
3842 Returns:
3943 derivative w.r.t. mean "dmu", derivative w.r.t. target dtarg, loss
4044 """
41- dmu = (targ - mu ) # e (error unit)
45+ dmu = (targ - mu )/ sigma # e (error unit)
4246 dtarg = - dmu # reverse of e
43- L = - jnp .sum (jnp .square (dmu )) * 0.5
44- return dmu , dtarg , L
47+ dsigma = 1. # no derivative is calculated at this time for sigma
48+ L = - jnp .sum (jnp .square (dmu )) * 0.5 / sigma
49+ return dmu , dtarg , dsigma , L
4550
4651class GaussianErrorCell (JaxComponent ): ## Rate-coded/real-valued error unit/cell
4752 """
@@ -66,8 +71,10 @@ class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
6671 tau_m: (Unused -- currently cell is a fixed-point model)
6772
6873 leakRate: (Unused -- currently cell is a fixed-point model)
74+
75+ sigma: prediction covariance matrix (𝚺) in multivariate gaussian distribution
6976 """
70- def __init__ (self , name , n_units , batch_size = 1 , shape = None , ** kwargs ):
77+ def __init__ (self , name , n_units , batch_size = 1 , sigma = 1 , shape = None , ** kwargs ):
7178 super ().__init__ (name , ** kwargs )
7279
7380 ## Layer Size Setup
@@ -79,6 +86,7 @@ def __init__(self, name, n_units, batch_size=1, shape=None, **kwargs):
7986 self .shape = shape
8087 self .n_units = n_units
8188 self .batch_size = batch_size
89+ self .sigma = sigma
8290
8391 ## Convolution shape setup
8492 self .width = self .height = n_units
@@ -94,11 +102,12 @@ def __init__(self, name, n_units, batch_size=1, shape=None, **kwargs):
94102 self .mask = Compartment (restVals + 1.0 )
95103
96104 @staticmethod
97- def _advance_state (dt , mu , dmu , target , dtarget , modulator , mask ):
105+ def _advance_state (dt , mu , dmu , target , dtarget , sigma , modulator , mask ):
98106 ## compute Gaussian error cell output
99- dmu , dtarget , L = _run_cell (dt , target * mask , mu * mask )
107+ dmu , dtarget , dsigma , L = _run_cell (dt , target * mask , mu * mask , sigma )
100108 dmu = dmu * modulator * mask
101109 dtarget = dtarget * modulator * mask
110+ dsigma = dsigma * 0 + 1. # no derivative is calculated at this time for sigma
102111 mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
103112 return dmu , dtarget , L , mask
104113
@@ -153,11 +162,12 @@ def help(cls): ## component help function
153162 }
154163 hyperparams = {
155164 "n_units" : "Number of neuronal cells to model in this layer" ,
156- "batch_size" : "Batch size dimension of this component"
165+ "batch_size" : "Batch size dimension of this component" ,
166+ "sigma" : "External input variance value (currently fixed and not learnable)"
157167 }
158168 info = {cls .__name__ : properties ,
159169 "compartments" : compartment_props ,
160- "dynamics" : "Gaussian(x=target; mu, sigma=1 )" ,
170+ "dynamics" : "Gaussian(x=target; mu, sigma)" ,
161171 "hyperparameters" : hyperparams }
162172 return info
163173
0 commit comments