33import jax
44from typing import Union
55
6+ from ngcsimlib .logger import info , warn
67from ngcsimlib .compartment import Compartment
78from ngcsimlib .parser import compilable
89
9-
1010class PhasorCell (JaxComponent ):
1111 """
1212 A phasor cell that emits a pulse at a regular interval.
@@ -33,25 +33,19 @@ class PhasorCell(JaxComponent):
3333
3434 # Define Functions
3535 def __init__ (
36- self , name : str , n_units : int , target_freq : float = 63.75 ,
37- batch_size : int = 1 , key : Union [jax .Array , None ] = None ):
38- super ().__init__ (name = name , key = key )
39-
40- _key , subkey = random .split (self .key .get (), 2 )
41- self .key .set (_key )
36+ self , name , n_units , target_freq = 63.75 , batch_size = 1 , disable_phasor = False , ** kwargs ):
37+ super ().__init__ (name , ** kwargs )
4238
4339 ## Phasor meta-parameters
44- self .target_freq = Compartment (target_freq , fixed = True ) ## maximum frequency (in Hertz/Hz)
45- self .base_scale = Compartment (random .poisson (subkey [0 ], lam = target_freq , shape = (batch_size , n_units )) / target_freq , fixed = True )
40+ self .target_freq = target_freq ## maximum frequency (in Hertz/Hz)
4641
4742 ## Layer Size Setup
48- self .batch_size = Compartment (batch_size , fixed = True )
49- self .n_units = Compartment (n_units , fixed = True )
50-
51-
52-
43+ self .batch_size = batch_size
44+ self .n_units = n_units
45+ _key , * subkey = random .split (self .key .get (), 3 )
46+ self .key .set (_key )
5347 ## Compartment setup
54- restVals = jnp .zeros ((batch_size , n_units ))
48+ restVals = jnp .zeros ((self . batch_size , self . n_units ))
5549 self .inputs = Compartment (restVals ,
5650 display_name = "Input Stimulus" ) # input
5751 # compartment
@@ -60,44 +54,100 @@ def __init__(
6054 self .tols = Compartment (initial_value = restVals ,
6155 display_name = "Time-of-Last-Spike" , units = "ms" ) # time of last spike
6256 self .angles = Compartment (restVals , display_name = "Angles" , units = "deg" )
63-
57+ # self.base_scale = random.uniform(subkey, self.angles.value.shape,
58+ # minval=0.75, maxval=1.25)
59+ # self.base_scale = ((random.normal(subkey, self.angles.value.shape) * 0.15) + 1)
60+ # alpha = ((random.normal(subkey, self.angles.value.shape) * (jnp.sqrt(target_freq) / target_freq)) + 1)
61+ # beta = random.poisson(subkey, lam=target_freq, shape=self.angles.value.shape) / target_freq
62+
63+ self .base_scale = random .poisson (subkey [0 ], lam = target_freq , shape = self .angles .get ().shape ) / target_freq
64+ self .disable_phasor = disable_phasor
65+
66+ def validate (self , dt = None , ** validation_kwargs ):
67+ valid = super ().validate (** validation_kwargs )
68+ if dt is None :
69+ warn (f"{ self .name } requires a validation kwarg of `dt`" )
70+ return False
71+ ## check for unstable combinations of dt and target-frequency
72+ # meta-params
73+ events_per_timestep = (dt / 1000. ) * self .target_freq ##
74+ # compute scaled probability
75+ if events_per_timestep > 1. :
76+ valid = False
77+ warn (
78+ f"{ self .name } will be unable to make as many temporal events "
79+ f"as "
80+ f"requested! ({ events_per_timestep } events/timestep) Unstable "
81+ f"combination of dt = { dt } and target_freq = "
82+ f"{ self .target_freq } "
83+ f"being used!"
84+ )
85+ return valid
86+
87+ # @transition(output_compartments=["outputs", "tols", "key", "angles"])
88+ # @staticmethod
6489 @compilable
65- def advance_state (self , t , dt ):
90+ def advance_state (self , t , dt , ):
91+
92+ inputs = self .inputs .get ()
93+ angles = self .angles .get ()
94+ tols = self .tols .get ()
95+
6696 ms_per_second = 1000 # ms/s
67- events_per_ms = self .target_freq . get () / ms_per_second # e/s s/ms -> e/ms
97+ events_per_ms = self .target_freq / ms_per_second # e/s s/ms -> e/ms
6898 ms_per_event = 1 / events_per_ms # ms/e
6999 time_step_per_event = ms_per_event / dt # ms/e * ts/ms -> ts / e
70100 angle_per_event = 2 * jnp .pi # rad / e
71101 angle_per_timestep = angle_per_event / time_step_per_event # rad / e
72102 # * e/ts -> rad / ts
73103 key , * subkey = random .split (self .key .get (), 3 )
104+ # scatter = random.uniform(subkey, angles.shape, minval=0.5,
105+ # maxval=1.5) * base_scale
74106
75- scatter = ((random .normal (subkey [0 ], self . angles .get (). shape ) * 0.2 ) + 1 ) * self .base_scale . get ()
107+ scatter = ((random .normal (subkey [0 ], angles .shape ) * 0.2 ) + 1 ) * self .base_scale
76108 scattered_update = angle_per_timestep * scatter
77- scaled_scattered_update = scattered_update * self .inputs .get ()
78-
79- updated_angles = self .angles .get () + scaled_scattered_update
80- self .outputs .set (jnp .where (updated_angles > angle_per_event , 1. , 0. ))
109+ scaled_scattered_update = scattered_update * inputs
81110
82- self .angles .set (jnp .where (updated_angles > angle_per_event ,
111+ updated_angles = angles + scaled_scattered_update
112+ outputs = jnp .where (updated_angles > angle_per_event , 1. , 0. )
113+ updated_angles = jnp .where (updated_angles > angle_per_event ,
83114 updated_angles - angle_per_event ,
84- updated_angles ))
115+ updated_angles )
116+ if self .disable_phasor :
117+ outputs = inputs + 0
118+ tols = tols * (1. - outputs ) + t * outputs
119+
120+ self .outputs .set (outputs )
121+ self .tols .set (tols )
122+ self .key .set (key )
123+ self .angles .set (updated_angles )
85124
86- self .tols .set (self .tols .get () * (1. - self .outputs .get ()) + t * self .outputs .get ())
87125
126+ # @transition(output_compartments=["inputs", "outputs", "tols", "angles", "key"])
127+ # @staticmethod
88128 @compilable
89129 def reset (self ):
90- restVals = jnp .zeros ((self .batch_size .get (), self .n_units .get ()))
130+ restVals = jnp .zeros ((self .batch_size , self .n_units ))
131+ key , * subkey = random .split (self .key .get (), 3 )
132+
91133 # BUG: the self.inputs here does not have the targeted field
92134 # NOTE: Quick workaround is to check if targeted is in the input or not
93135 hasattr (self .inputs , "targeted" ) and not self .inputs .targeted and self .inputs .set (restVals )
94136 self .outputs .set (restVals )
95137 self .tols .set (restVals )
96138 self .angles .set (restVals )
97- key , _ = random .split (self .key .get (), 2 )
98139 self .key .set (key )
99140
100141
142+ def save (self , directory , ** kwargs ):
143+ file_name = directory + "/" + self .name + ".npz"
144+ jnp .savez (file_name , key = self .key .value )
145+
146+ def load (self , directory , ** kwargs ):
147+ file_name = directory + "/" + self .name + ".npz"
148+ data = jnp .load (file_name )
149+ self .key .set (data ['key' ])
150+
101151 @classmethod
102152 def help (cls ): ## component help function
103153 properties = {
0 commit comments