1- from ngclearn import resolver , Compartment
21from ngclearn .components .jaxComponent import JaxComponent
3- from ngclearn .utils import tensorstats
42from jax import numpy as jnp , random
5- from ngcsimlib .logger import warn
3+ from ngclearn .utils import tensorstats
4+ from ngcsimlib .deprecators import deprecate_args
5+ from ngcsimlib .logger import info , warn
6+
7+ from ngcsimlib .compilers .process import transition
8+ #from ngcsimlib.component import Component
9+ from ngcsimlib .compartment import Compartment
10+
611
712class PhasorCell (JaxComponent ):
813 """
@@ -12,8 +17,9 @@ class PhasorCell(JaxComponent):
1217 | inputs - input (takes in external signals)
1318 | --- Cell State Compartments: ---
1419 | key - JAX PRNG key
20+ | angles - current angle of phasor
1521 | --- Cell Output Compartments: ---
16- | outputs - output
22+ | outputs - output of phasor cell
1723 | tols - time-of-last-spike
1824
1925 Args:
@@ -23,6 +29,8 @@ class PhasorCell(JaxComponent):
2329
2430 target_freq: maximum frequency (in Hertz) of this spike train
2531 (must be > 0.)
32+
33+ batch_size: batch size dimension of this cell (Default: 1)
2634 """
2735
2836 # Define Functions
@@ -63,8 +71,7 @@ def validate(self, dt=None, **validation_kwargs):
6371 return False
6472 ## check for unstable combinations of dt and target-frequency
6573 # meta-params
66- events_per_timestep = (
67- dt / 1000. ) * self .target_freq ##
74+ events_per_timestep = (dt / 1000. ) * self .target_freq ##
6875 # compute scaled probability
6976 if events_per_timestep > 1. :
7077 valid = False
@@ -78,9 +85,9 @@ def validate(self, dt=None, **validation_kwargs):
7885 )
7986 return valid
8087
88+ @transition (output_compartments = ["outputs" , "tols" , "key" , "angles" ])
8189 @staticmethod
82- def _advance_state (t , dt , target_freq , key ,
83- inputs , angles , tols , base_scale ):
90+ def advance_state (t , dt , target_freq , key , inputs , angles , tols , base_scale ):
8491 ms_per_second = 1000 # ms/s
8592 events_per_ms = target_freq / ms_per_second # e/s s/ms -> e/ms
8693 ms_per_event = 1 / events_per_ms # ms/e
@@ -105,27 +112,13 @@ def _advance_state(t, dt, target_freq, key,
105112
106113 return outputs , tols , key , updated_angles
107114
108- @resolver (_advance_state )
109- def advance_state (self , outputs , tols , key , angles ):
110- self .outputs .set (outputs )
111- self .tols .set (tols )
112- self .key .set (key )
113- self .angles .set (angles )
114-
115+ @transition (output_compartments = ["inputs" , "outputs" , "tols" , "angles" , "key" ])
115116 @staticmethod
116- def _reset (batch_size , n_units , key , target_freq ):
117+ def reset (batch_size , n_units , key , target_freq ):
117118 restVals = jnp .zeros ((batch_size , n_units ))
118119 key , subkey = random .split (key , 2 )
119120 return restVals , restVals , restVals , restVals , key
120121
121- @resolver (_reset )
122- def reset (self , inputs , outputs , tols , angles , key ):
123- self .inputs .set (inputs )
124- self .outputs .set (outputs )
125- self .tols .set (tols )
126- self .key .set (key )
127- self .angles .set (angles )
128-
129122 def save (self , directory , ** kwargs ):
130123 file_name = directory + "/" + self .name + ".npz"
131124 jnp .savez (file_name , key = self .key .value )
@@ -154,7 +147,7 @@ def help(cls): ## component help function
154147 hyperparams = {
155148 "n_units" : "Number of neuronal cells to model in this layer" ,
156149 "batch_size" : "Batch size dimension of this component" ,
157- "target_freq" : "Maximum spike frequency of the train produced" ,
150+ "target_freq" : "Maximum spike frequency of the (spike) train produced" ,
158151 }
159152 info = {cls .__name__ : properties ,
160153 "compartments" : compartment_props ,
0 commit comments