1+ from ngclearn .components .jaxComponent import JaxComponent
12from jax import numpy as jnp , random , jit
23from functools import partial
3- from ngclearn import resolver , Component , Compartment
4- from ngclearn .components .jaxComponent import JaxComponent
54from ngclearn .utils import tensorstats
5+ from ngcsimlib .deprecators import deprecate_args
6+ from ngcsimlib .logger import info , warn
7+
8+ from ngcsimlib .compilers .process import transition
9+ #from ngcsimlib.component import Component
10+ from ngcsimlib .compartment import Compartment
611
712@partial (jit , static_argnums = [4 ])
813def _run_varfilter (dt , x , x_tr , decayFactor , gamma_tr , a_delta = 0. ):
@@ -54,13 +59,17 @@ class VarTrace(JaxComponent): ## low-pass filter
5459 a_delta: value to increment a trace by in presence of a spike; note if set
5560 to a value <= 0, then a piecewise gated trace will be used instead
5661
62+ gamma_tr: an extra multiplier in front of the leak of the trace (Default: 1)
63+
5764 decay_type: string indicating the decay type to be applied to ODE
5865 integration; low-pass filter configuration
5966
6067 :Note: string values that this can be (Default: "exp") are:
6168 1) `'lin'` = linear trace filter, i.e., decay = x_tr + (-x_tr) * (dt/tau_tr);
6269 2) `'exp'` = exponential trace filter, i.e., decay = exp(-dt/tau_tr) * x_tr;
6370 3) `'step'` = step trace, i.e., decay = 0 (a pulse applied upon input value)
71+
72+ batch_size: batch size dimension of this cell (Default: 1)
6473 """
6574
6675 # Define Functions
@@ -83,38 +92,28 @@ def __init__(self, name, n_units, tau_tr, a_delta, gamma_tr=1, decay_type="exp",
8392 self .outputs = Compartment (restVals ) # output compartment
8493 self .trace = Compartment (restVals )
8594
95+ @transition (output_compartments = ["outputs" , "trace" ])
8696 @staticmethod
87- def _advance_state (dt , decay_type , tau_tr , a_delta , gamma_tr , inputs , trace ):
97+ def advance_state (dt , decay_type , tau_tr , a_delta , gamma_tr , inputs , trace ):
8898 decayFactor = 0.
8999 if "exp" in decay_type :
90100 decayFactor = jnp .exp (- dt / tau_tr )
91101 elif "lin" in decay_type :
92102 decayFactor = (1. - dt / tau_tr )
93-
94103 _x_tr = gamma_tr * trace * decayFactor
95104 if a_delta > 0. :
96105 _x_tr = _x_tr + inputs * a_delta
97106 else :
98107 _x_tr = _x_tr * (1. - inputs ) + inputs
99-
108+ trace = _x_tr
100109 return trace , trace
101110
102- @resolver (_advance_state )
103- def advance_state (self , outputs , trace ):
104- self .outputs .set (outputs )
105- self .trace .set (trace )
106-
111+ @transition (output_compartments = ["inputs" , "outputs" , "trace" ])
107112 @staticmethod
108- def _reset (batch_size , n_units ):
113+ def reset (batch_size , n_units ):
109114 restVals = jnp .zeros ((batch_size , n_units ))
110115 return restVals , restVals , restVals
111116
112- @resolver (_reset )
113- def reset (self , inputs , outputs , trace ):
114- self .inputs .set (inputs )
115- self .outputs .set (outputs )
116- self .trace .set (trace )
117-
118117 @classmethod
119118 def help (cls ): ## component help function
120119 properties = {
0 commit comments