22from ngclearn .components .jaxComponent import JaxComponent
33from jax import numpy as jnp , jit
44from ngclearn .utils import tensorstats
5+ from ngcsimlib .compilers .process import transition
56
67class GaussianErrorCell (JaxComponent ): ## Rate-coded/real-valued error unit/cell
78 """
@@ -64,8 +65,9 @@ def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs):
6465 self .modulator = Compartment (restVals + 1.0 ) # to be set/consumed
6566 self .mask = Compartment (restVals + 1.0 )
6667
68+ @transition (output_compartments = ["dmu" , "dtarget" , "dSigma" , "L" , "mask" ])
6769 @staticmethod
68- def _advance_state (dt , mu , target , Sigma , modulator , mask ): ## compute Gaussian error cell output
70+ def advance_state (dt , mu , target , Sigma , modulator , mask ): ## compute Gaussian error cell output
6971 # Moves Gaussian cell dynamics one step forward. Specifically, this routine emulates the error unit
7072 # behavior of the local cost functional:
7173 # FIXME: Currently, below does: L(targ, mu) = -(1/(2*sigma)) * ||targ - mu||^2_2
@@ -83,16 +85,9 @@ def _advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian
8385 mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
8486 return dmu , dtarget , dSigma , jnp .squeeze (L ), mask
8587
86- @resolver (_advance_state )
87- def advance_state (self , dmu , dtarget , dSigma , L , mask ):
88- self .dmu .set (dmu )
89- self .dtarget .set (dtarget )
90- self .dSigma .set (dSigma )
91- self .L .set (L )
92- self .mask .set (mask )
93-
88+ @transition (output_compartments = ["dmu" , "dtarget" , "dSigma" , "target" , "mu" , "modulator" , "L" , "mask" ])
9489 @staticmethod
95- def _reset (batch_size , shape , sigma_shape ): ## reset core components/statistics
90+ def reset (batch_size , shape , sigma_shape ): ## reset core components/statistics
9691 _shape = (batch_size , shape [0 ])
9792 if len (shape ) > 1 :
9893 _shape = (batch_size , shape [0 ], shape [1 ], shape [2 ])
@@ -107,17 +102,6 @@ def _reset(batch_size, shape, sigma_shape): ## reset core components/statistics
107102 mask = jnp .ones (_shape )
108103 return dmu , dtarget , dSigma , target , mu , modulator , L , mask
109104
110- @resolver (_reset )
111- def reset (self , dmu , dtarget , dSigma , target , mu , modulator , L , mask ):
112- self .dmu .set (dmu )
113- self .dtarget .set (dtarget )
114- self .dSigma .set (dSigma )
115- self .target .set (target )
116- self .mu .set (mu )
117- self .modulator .set (modulator )
118- self .L .set (L )
119- self .mask .set (mask )
120-
121105 @classmethod
122106 def help (cls ): ## component help function
123107 properties = {
0 commit comments