33from jax import numpy as jnp , jit
44from ngclearn .utils import tensorstats
55
6- def _run_cell (dt , targ , mu , sigma ):
7- """
8- Moves cell dynamics one step forward.
9-
10- Args:
11- dt: integration time constant
12-
13- targ: target pattern value
14-
15- mu: prediction value
16-
17- sigma: prediction variance
18-
19- Returns:
20- derivative w.r.t. mean "dmu", derivative w.r.t. target dtarg, local loss
21- """
22- return _run_gaussian_cell (dt , targ , mu , sigma )
23-
24- @jit
25- def _run_gaussian_cell (dt , targ , mu , sigma ):
26- """
27- Moves Gaussian cell dynamics one step forward. Specifically, this
28- routine emulates the error unit behavior of the local cost functional:
29-
30- | L(targ, mu) = -(1/2) * ||targ - mu||^2_2
31- | or log likelihood of the multivariate Gaussian with identity covariance
32-
33- Args:
34- dt: integration time constant
35-
36- targ: target pattern value
37-
38- mu: prediction value
39-
40- sigma: prediction variance
41-
42- Returns:
43- derivative w.r.t. mean "dmu", derivative w.r.t. target dtarg, loss
44- """
45- dmu = (targ - mu )/ sigma # e (error unit)
46- dtarg = - dmu # reverse of e
47- dsigma = 1. # no derivative is calculated at this time for sigma
48- L = - jnp .sum (jnp .square (dmu )) * 0.5 / sigma
49- return dmu , dtarg , dsigma , L
50-
516class GaussianErrorCell (JaxComponent ): ## Rate-coded/real-valued error unit/cell
527 """
538 A simple (non-spiking) Gaussian error cell - this is a fixed-point solution
549 of a mismatch signal.
5510
5611 | --- Cell Input Compartments: ---
5712 | mu - predicted value (takes in external signals)
13+ | Sigma - predicted covariance (takes in external signals)
5814 | target - desired/goal value (takes in external signals)
5915 | modulator - modulation signal (takes in optional external signals)
6016 | mask - binary/gating mask to apply to error neuron calculations
6117 | --- Cell Output Compartments: ---
6218 | L - local loss function embodied by this cell
6319 | dmu - derivative of L w.r.t. mu
20+ | dSigma - derivative of L w.r.t. Sigma
6421 | dtarget - derivative of L w.r.t. target
6522
6623 Args:
6724 name: the string name of this cell
6825
6926 n_units: number of cellular entities (neural population size)
7027
71- tau_m: (Unused -- currently cell is a fixed-point model)
72-
73- leakRate: (Unused -- currently cell is a fixed-point model)
28+ batch_size: batch size dimension of this cell (Default: 1)
7429
75- sigma: prediction covariance matrix (𝚺) in multivariate gaussian distribution
30+ sigma: initial/fixed value for prediction covariance matrix (𝚺) in multivariate gaussian distribution;
31+ Note that if the compartment `Sigma` is never used, then this cell assumes that the covariance collapses
32+ to a constant/fixed `sigma`
7633 """
77- def __init__ (self , name , n_units , batch_size = 1 , sigma = 1 , shape = None , ** kwargs ):
34+ def __init__ (self , name , n_units , batch_size = 1 , sigma = 1. , shape = None , ** kwargs ):
7835 super ().__init__ (name , ** kwargs )
7936
8037 ## Layer Size Setup
@@ -83,60 +40,78 @@ def __init__(self, name, n_units, batch_size=1, sigma=1, shape=None, **kwargs):
8340 shape = (n_units ,) ## we set shape to be equal to n_units if nothing provided
8441 else :
8542 _shape = (batch_size , shape [0 ], shape [1 ], shape [2 ]) ## shape is 4D tensor
43+ sigma_shape = (1 ,1 )
44+ if not isinstance (sigma , float ):
45+ sigma_shape = jnp .array (sigma ).shape
46+ self .sigma_shape = sigma_shape
8647 self .shape = shape
8748 self .n_units = n_units
8849 self .batch_size = batch_size
89- self .sigma = sigma
9050
9151 ## Convolution shape setup
9252 self .width = self .height = n_units
9353
9454 ## Compartment setup
9555 restVals = jnp .zeros (_shape )
96- self .L = Compartment (0. ) # loss compartment
97- self .mu = Compartment (restVals ) # mean/mean name. input wire
56+ self .L = Compartment (0. , display_name = "Gaussian Log likelihood" , units = "nats" ) # loss compartment
57+ self .mu = Compartment (restVals , display_name = "Gaussian mean" ) # mean/mean name. input wire
9858 self .dmu = Compartment (restVals ) # derivative mean
99- self .target = Compartment (restVals ) # target. input wire
59+ _Sigma = jnp .zeros (sigma_shape )
60+ self .Sigma = Compartment (_Sigma + sigma , display_name = "Gaussian variance/covariance" )
61+ self .dSigma = Compartment (_Sigma )
62+ self .target = Compartment (restVals , display_name = "Gaussian data/target variable" ) # target. input wire
10063 self .dtarget = Compartment (restVals ) # derivative target
10164 self .modulator = Compartment (restVals + 1.0 ) # to be set/consumed
10265 self .mask = Compartment (restVals + 1.0 )
10366
10467 @staticmethod
105- def _advance_state (dt , mu , dmu , target , dtarget , sigma , modulator , mask ):
106- ## compute Gaussian error cell output
107- dmu , dtarget , dsigma , L = _run_cell (dt , target * mask , mu * mask , sigma )
108- dmu = dmu * modulator * mask
68+ def _advance_state (dt , mu , target , Sigma , modulator , mask ): ## compute Gaussian error cell output
69+ # Moves Gaussian cell dynamics one step forward. Specifically, this routine emulates the error unit
70+ # behavior of the local cost functional:
71+ # FIXME: Currently, below does: L(targ, mu) = -(1/(2*sigma)) * ||targ - mu||^2_2
72+ # but should support full log likelihood of the multivariate Gaussian with covariance of different types
73+ # TODO: could introduce a variant of GaussianErrorCell that moves according to an ODE
74+ # (using integration time constant dt)
75+ _dmu = (target - mu ) # e (error unit)
76+ dmu = _dmu / Sigma
77+ dtarget = - dmu # reverse of e
78+ dSigma = Sigma * 0 + 1. # no derivative is calculated at this time for sigma
79+ L = - jnp .sum (jnp .square (_dmu )) * (0.5 / Sigma )
80+
81+ dmu = dmu * modulator * mask ## not sure how mask will apply to a full covariance...
10982 dtarget = dtarget * modulator * mask
110- dsigma = dsigma * 0 + 1. # no derivative is calculated at this time for sigma
11183 mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
112- return dmu , dtarget , L , mask
84+ return dmu , dtarget , dSigma , L [ 0 , 0 ] , mask
11385
11486 @resolver (_advance_state )
115- def advance_state (self , dmu , dtarget , L , mask ):
87+ def advance_state (self , dmu , dtarget , dSigma , L , mask ):
11688 self .dmu .set (dmu )
11789 self .dtarget .set (dtarget )
90+ self .dSigma .set (dSigma )
11891 self .L .set (L )
11992 self .mask .set (mask )
12093
12194 @staticmethod
122- def _reset (batch_size , shape ): #n_units
95+ def _reset (batch_size , shape , sigma_shape ): ## reset core components/statistics
12396 _shape = (batch_size , shape [0 ])
12497 if len (shape ) > 1 :
12598 _shape = (batch_size , shape [0 ], shape [1 ], shape [2 ])
12699 restVals = jnp .zeros (_shape )
127100 dmu = restVals
128101 dtarget = restVals
102+ dSigma = jnp .zeros (sigma_shape )
129103 target = restVals
130104 mu = restVals
131105 modulator = mu + 1.
132- L = 0.
106+ L = 0. #jnp.zeros((1, 1))
133107 mask = jnp .ones (_shape )
134- return dmu , dtarget , target , mu , modulator , L , mask
108+ return dmu , dtarget , dSigma , target , mu , modulator , L , mask
135109
136110 @resolver (_reset )
137- def reset (self , dmu , dtarget , target , mu , modulator , L , mask ):
111+ def reset (self , dmu , dtarget , dSigma , target , mu , modulator , L , mask ):
138112 self .dmu .set (dmu )
139113 self .dtarget .set (dtarget )
114+ self .dSigma .set (dSigma )
140115 self .target .set (target )
141116 self .mu .set (mu )
142117 self .modulator .set (modulator )
@@ -152,12 +127,14 @@ def help(cls): ## component help function
152127 compartment_props = {
153128 "inputs" :
154129 {"mu" : "External input prediction value(s)" ,
130+ "Sigma" : "External variance/covariance prediction value(s)" ,
155131 "target" : "External input target signal value(s)" ,
156132 "modulator" : "External input modulatory/scaling signal(s)" ,
157133 "mask" : "External binary/gating mask to apply to signals" },
158134 "outputs" :
159135 {"L" : "Local loss value computed/embodied by this error-cell" ,
160136 "dmu" : "first derivative of loss w.r.t. prediction value(s)" ,
137+ "dSigma" : "first derivative of loss w.r.t. variance/covariance value(s)" ,
161138 "dtarget" : "first derivative of loss w.r.t. target value(s)" },
162139 }
163140 hyperparams = {
0 commit comments