Skip to content

Commit 55e9fc7

Browse files
committed
update refactoring for gaussian error cell
1 parent 3193b72 commit 55e9fc7

File tree

1 file changed

+5
-21
lines changed

1 file changed

+5
-21
lines changed

ngclearn/components/neurons/graded/gaussianErrorCell.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from ngclearn.components.jaxComponent import JaxComponent
33
from jax import numpy as jnp, jit
44
from ngclearn.utils import tensorstats
5+
from ngcsimlib.compilers.process import transition
56

67
class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
78
"""
@@ -64,8 +65,9 @@ def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs):
6465
self.modulator = Compartment(restVals + 1.0) # to be set/consumed
6566
self.mask = Compartment(restVals + 1.0)
6667

68+
@transition(output_compartments=["dmu", "dtarget", "dSigma", "L", "mask"])
6769
@staticmethod
68-
def _advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian error cell output
70+
def advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian error cell output
6971
# Moves Gaussian cell dynamics one step forward. Specifically, this routine emulates the error unit
7072
# behavior of the local cost functional:
7173
# FIXME: Currently, below does: L(targ, mu) = -(1/(2*sigma)) * ||targ - mu||^2_2
@@ -83,16 +85,9 @@ def _advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian
8385
mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
8486
return dmu, dtarget, dSigma, jnp.squeeze(L), mask
8587

86-
@resolver(_advance_state)
87-
def advance_state(self, dmu, dtarget, dSigma, L, mask):
88-
self.dmu.set(dmu)
89-
self.dtarget.set(dtarget)
90-
self.dSigma.set(dSigma)
91-
self.L.set(L)
92-
self.mask.set(mask)
93-
88+
@transition(output_compartments=["dmu", "dtarget", "dSigma", "target", "mu", "modulator", "L", "mask"])
9489
@staticmethod
95-
def _reset(batch_size, shape, sigma_shape): ## reset core components/statistics
90+
def reset(batch_size, shape, sigma_shape): ## reset core components/statistics
9691
_shape = (batch_size, shape[0])
9792
if len(shape) > 1:
9893
_shape = (batch_size, shape[0], shape[1], shape[2])
@@ -107,17 +102,6 @@ def _reset(batch_size, shape, sigma_shape): ## reset core components/statistics
107102
mask = jnp.ones(_shape)
108103
return dmu, dtarget, dSigma, target, mu, modulator, L, mask
109104

110-
@resolver(_reset)
111-
def reset(self, dmu, dtarget, dSigma, target, mu, modulator, L, mask):
112-
self.dmu.set(dmu)
113-
self.dtarget.set(dtarget)
114-
self.dSigma.set(dSigma)
115-
self.target.set(target)
116-
self.mu.set(mu)
117-
self.modulator.set(modulator)
118-
self.L.set(L)
119-
self.mask.set(mask)
120-
121105
@classmethod
122106
def help(cls): ## component help function
123107
properties = {

0 commit comments

Comments
 (0)