
Commit 4af85dc

Author: Alexander Ororbia
cleaned up ode_utils, cleaned up gaussian/laplacian cell
1 parent b9227f0 commit 4af85dc

File tree

3 files changed: +107 -137 lines changed


ngclearn/components/neurons/graded/gaussianErrorCell.py

Lines changed: 42 additions & 65 deletions
@@ -3,78 +3,35 @@
 from jax import numpy as jnp, jit
 from ngclearn.utils import tensorstats
 
-def _run_cell(dt, targ, mu, sigma):
-    """
-    Moves cell dynamics one step forward.
-
-    Args:
-        dt: integration time constant
-
-        targ: target pattern value
-
-        mu: prediction value
-
-        sigma: prediction variance
-
-    Returns:
-        derivative w.r.t. mean "dmu", derivative w.r.t. target dtarg, local loss
-    """
-    return _run_gaussian_cell(dt, targ, mu, sigma)
-
-@jit
-def _run_gaussian_cell(dt, targ, mu, sigma):
-    """
-    Moves Gaussian cell dynamics one step forward. Specifically, this
-    routine emulates the error unit behavior of the local cost functional:
-
-    | L(targ, mu) = -(1/2) * ||targ - mu||^2_2
-    | or log likelihood of the multivariate Gaussian with identity covariance
-
-    Args:
-        dt: integration time constant
-
-        targ: target pattern value
-
-        mu: prediction value
-
-        sigma: prediction variance
-
-    Returns:
-        derivative w.r.t. mean "dmu", derivative w.r.t. target dtarg, loss
-    """
-    dmu = (targ - mu)/sigma # e (error unit)
-    dtarg = -dmu # reverse of e
-    dsigma = 1. # no derivative is calculated at this time for sigma
-    L = -jnp.sum(jnp.square(dmu)) * 0.5 / sigma
-    return dmu, dtarg, dsigma, L
-
 class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
     """
     A simple (non-spiking) Gaussian error cell - this is a fixed-point solution
     of a mismatch signal.
 
     | --- Cell Input Compartments: ---
     | mu - predicted value (takes in external signals)
+    | Sigma - predicted covariance (takes in external signals)
     | target - desired/goal value (takes in external signals)
     | modulator - modulation signal (takes in optional external signals)
     | mask - binary/gating mask to apply to error neuron calculations
     | --- Cell Output Compartments: ---
     | L - local loss function embodied by this cell
     | dmu - derivative of L w.r.t. mu
+    | dSigma - derivative of L w.r.t. Sigma
     | dtarget - derivative of L w.r.t. target
 
     Args:
         name: the string name of this cell
 
         n_units: number of cellular entities (neural population size)
 
-        tau_m: (Unused -- currently cell is a fixed-point model)
-
-        leakRate: (Unused -- currently cell is a fixed-point model)
+        batch_size: batch size dimension of this cell (Default: 1)
 
-        sigma: prediction covariance matrix (𝚺) in multivariate gaussian distribution
+        sigma: initial/fixed value for prediction covariance matrix (𝚺) in multivariate gaussian distribution;
+            Note that if the compartment `Sigma` is never used, then this cell assumes that the covariance collapses
+            to a constant/fixed `sigma`
     """
-    def __init__(self, name, n_units, batch_size=1, sigma=1, shape=None, **kwargs):
+    def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs):
         super().__init__(name, **kwargs)
 
         ## Layer Size Setup
@@ -83,60 +40,78 @@ def __init__(self, name, n_units, batch_size=1, sigma=1, shape=None, **kwargs):
             shape = (n_units,) ## we set shape to be equal to n_units if nothing provided
         else:
             _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor
+        sigma_shape = (1, 1)
+        if not isinstance(sigma, float):
+            sigma_shape = jnp.array(sigma).shape
+        self.sigma_shape = sigma_shape
         self.shape = shape
         self.n_units = n_units
         self.batch_size = batch_size
-        self.sigma = sigma
 
         ## Convolution shape setup
         self.width = self.height = n_units
 
         ## Compartment setup
         restVals = jnp.zeros(_shape)
-        self.L = Compartment(0.) # loss compartment
-        self.mu = Compartment(restVals) # mean/mean name. input wire
+        self.L = Compartment(0., display_name="Gaussian Log likelihood", units="nats") # loss compartment
+        self.mu = Compartment(restVals, display_name="Gaussian mean") # mean/mean name. input wire
         self.dmu = Compartment(restVals) # derivative mean
-        self.target = Compartment(restVals) # target. input wire
+        _Sigma = jnp.zeros(sigma_shape)
+        self.Sigma = Compartment(_Sigma + sigma, display_name="Gaussian variance/covariance")
+        self.dSigma = Compartment(_Sigma)
+        self.target = Compartment(restVals, display_name="Gaussian data/target variable") # target. input wire
         self.dtarget = Compartment(restVals) # derivative target
         self.modulator = Compartment(restVals + 1.0) # to be set/consumed
         self.mask = Compartment(restVals + 1.0)
 
     @staticmethod
-    def _advance_state(dt, mu, dmu, target, dtarget, sigma, modulator, mask):
-        ## compute Gaussian error cell output
-        dmu, dtarget, dsigma, L = _run_cell(dt, target * mask, mu * mask, sigma)
-        dmu = dmu * modulator * mask
+    def _advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian error cell output
+        # Moves Gaussian cell dynamics one step forward. Specifically, this routine emulates the error unit
+        # behavior of the local cost functional:
+        # FIXME: Currently, below does: L(targ, mu) = -(1/(2*sigma)) * ||targ - mu||^2_2
+        # but should support full log likelihood of the multivariate Gaussian with covariance of different types
+        # TODO: could introduce a variant of GaussianErrorCell that moves according to an ODE
+        # (using integration time constant dt)
+        _dmu = (target - mu) # e (error unit)
+        dmu = _dmu / Sigma
+        dtarget = -dmu # reverse of e
+        dSigma = Sigma * 0 + 1. # no derivative is calculated at this time for sigma
+        L = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma)
+
+        dmu = dmu * modulator * mask ## not sure how mask will apply to a full covariance...
         dtarget = dtarget * modulator * mask
-        dsigma = dsigma * 0 + 1. # no derivative is calculated at this time for sigma
         mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
-        return dmu, dtarget, L, mask
+        return dmu, dtarget, dSigma, L[0, 0], mask
 
     @resolver(_advance_state)
-    def advance_state(self, dmu, dtarget, L, mask):
+    def advance_state(self, dmu, dtarget, dSigma, L, mask):
         self.dmu.set(dmu)
         self.dtarget.set(dtarget)
+        self.dSigma.set(dSigma)
         self.L.set(L)
         self.mask.set(mask)
 
     @staticmethod
-    def _reset(batch_size, shape): #n_units
+    def _reset(batch_size, shape, sigma_shape): ## reset core components/statistics
         _shape = (batch_size, shape[0])
         if len(shape) > 1:
             _shape = (batch_size, shape[0], shape[1], shape[2])
         restVals = jnp.zeros(_shape)
         dmu = restVals
         dtarget = restVals
+        dSigma = jnp.zeros(sigma_shape)
         target = restVals
         mu = restVals
         modulator = mu + 1.
-        L = 0.
+        L = 0. #jnp.zeros((1, 1))
         mask = jnp.ones(_shape)
-        return dmu, dtarget, target, mu, modulator, L, mask
+        return dmu, dtarget, dSigma, target, mu, modulator, L, mask
 
     @resolver(_reset)
-    def reset(self, dmu, dtarget, target, mu, modulator, L, mask):
+    def reset(self, dmu, dtarget, dSigma, target, mu, modulator, L, mask):
         self.dmu.set(dmu)
         self.dtarget.set(dtarget)
+        self.dSigma.set(dSigma)
         self.target.set(target)
         self.mu.set(mu)
         self.modulator.set(modulator)
@@ -152,12 +127,14 @@ def help(cls): ## component help function
         compartment_props = {
             "inputs":
                 {"mu": "External input prediction value(s)",
+                 "Sigma": "External variance/covariance prediction value(s)",
                  "target": "External input target signal value(s)",
                  "modulator": "External input modulatory/scaling signal(s)",
                  "mask": "External binary/gating mask to apply to signals"},
             "outputs":
                 {"L": "Local loss value computed/embodied by this error-cell",
                  "dmu": "first derivative of loss w.r.t. prediction value(s)",
+                 "dSigma": "first derivative of loss w.r.t. variance/covariance value(s)",
                  "dtarget": "first derivative of loss w.r.t. target value(s)"},
         }
         hyperparams = {
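The mismatch math that this commit inlines into `_advance_state` is compact enough to read on its own. Below is a minimal standalone JAX sketch of that fixed-point computation, kept outside of ngclearn's Component/Compartment machinery; the helper name `gaussian_error` and the example inputs are illustrative only, not part of the library. It mirrors the diff's `L(targ, mu) = -(1/(2*sigma)) * ||targ - mu||^2_2` and its derivatives:

```python
from jax import numpy as jnp

def gaussian_error(target, mu, Sigma):
    # Fixed-point Gaussian mismatch, as in the updated _advance_state:
    #   L(target, mu) = -(1/(2 * Sigma)) * ||target - mu||^2_2
    e = target - mu            # raw mismatch signal
    dmu = e / Sigma            # derivative of L w.r.t. mu (the "error unit" value)
    dtarget = -dmu             # derivative of L w.r.t. target
    dSigma = Sigma * 0 + 1.    # placeholder; no Sigma derivative is computed yet
    L = -jnp.sum(jnp.square(e)) * (0.5 / Sigma)
    return dmu, dtarget, dSigma, L

# Illustrative call with a fixed scalar variance held in a (1, 1) array,
# matching how the cell stores its Sigma compartment.
target = jnp.asarray([[1.0, 0.0, 2.0]])
mu = jnp.asarray([[0.5, 0.0, 1.5]])
dmu, dtarget, dSigma, L = gaussian_error(target, mu, Sigma=jnp.ones((1, 1)))
print(L[0, 0])  # scalar loss, analogous to the L[0, 0] returned by _advance_state
```

Because `Sigma` enters only as an elementwise divisor here, a scalar value reproduces the previous fixed-variance behavior, while an array broadcasts a per-dimension variance; the FIXME in the diff notes that a full multivariate log likelihood with richer covariance types is still to come.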

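One constructor detail worth noting: the shape of the new `Sigma`/`dSigma` compartments is inferred from the `sigma` argument, which is presumably why the default changed from `sigma=1` to `sigma=1.` (only a Python float collapses to the broadcastable `(1, 1)` case). A small sketch of that logic follows; the wrapper name `infer_sigma_shape` is illustrative, not ngclearn API:

```python
from jax import numpy as jnp

def infer_sigma_shape(sigma):
    # Mirrors the constructor logic in the diff: a Python float means a single
    # fixed variance, while anything array-like keeps its own shape.
    sigma_shape = (1, 1)
    if not isinstance(sigma, float):
        sigma_shape = jnp.array(sigma).shape
    return sigma_shape

print(infer_sigma_shape(1.))                      # (1, 1): fixed scalar variance
print(infer_sigma_shape(jnp.ones((1, 3)) * 0.5))  # (1, 3): per-dimension variances
# The Sigma compartment is then seeded as jnp.zeros(sigma_shape) + sigma.
```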