1- from ngclearn import resolver , Component , Compartment
1+ # from ngclearn import resolver, Component, Compartment
22from ngclearn .components .jaxComponent import JaxComponent
33from jax import numpy as jnp , random , jit
44from ngclearn .utils import tensorstats
55from functools import partial
66from ngcsimlib .deprecators import deprecate_args
77from ngcsimlib .logger import info , warn
88
9- @jit
10- def _update_times (t , s , tols ):
11- """
12- Updates time-of-last-spike (tols) variable.
13-
14- Args:
15- t: current time (a scalar/int value)
16-
17- s: binary spike vector
18-
19- tols: current time-of-last-spike variable
20-
21- Returns:
22- updated tols variable
23- """
24- _tols = (1. - s ) * tols + (s * t )
25- return _tols
9+ from ngcsimlib .compilers .process import Process , transition
10+ from ngcsimlib .component import Component
11+ from ngcsimlib .compartment import Compartment
2612
2713class BernoulliCell (JaxComponent ):
2814 """
2915 A Bernoulli cell that produces spikes by sampling a Bernoulli distribution
3016 on-the-fly (to produce data-scaled Bernoulli spike trains).
3117
3218 | --- Cell Input Compartments: ---
33- | inputs - input (takes in external signals)
19+ | inputs - input (takes in external signals -- should be probabilities w/ values in [0,1] )
3420 | --- Cell State Compartments: ---
3521 | key - JAX PRNG key
3622 | --- Cell Output Compartments: ---
@@ -41,10 +27,13 @@ class BernoulliCell(JaxComponent):
4127 name: the string name of this cell
4228
4329 n_units: number of cellular entities (neural population size)
30+
31+ batch_size: batch size dimension of this cell (Default: 1)
4432 """
4533
4634 def __init__ (self , name , n_units , batch_size = 1 , ** kwargs ):
47- super ().__init__ (name , ** kwargs )
35+ #super().__init__(name, **kwargs)
36+ super (JaxComponent , self ).__init__ (name , ** kwargs )
4837
4938 ## Layer Size Setup
5039 self .batch_size = batch_size
@@ -56,30 +45,24 @@ def __init__(self, name, n_units, batch_size=1, **kwargs):
5645 self .outputs = Compartment (restVals , display_name = "Spikes" ) # output compartment
5746 self .tols = Compartment (restVals , display_name = "Time-of-Last-Spike" , units = "ms" ) # time of last spike
5847
48+ @transition (output_compartments = ["outputs" , "tols" , "key" ])
5949 @staticmethod
60- def _advance_state (t , key , inputs , tols ):
50+ def advance_state (t , key , inputs , tols ):
51+ ## NOTE: should `inputs` be checked if bounded to [0,1]?
6152 key , * subkeys = random .split (key , 2 )
6253 outputs = random .bernoulli (subkeys [0 ], p = inputs ).astype (jnp .float32 )
63- tols = _update_times (t , outputs , tols )
54+ # Updates time-of-last-spike (tols) variable:
55+ # output = s = binary spike vector
56+ # tols = current time-of-last-spike variable
57+ tols = (1. - outputs ) * tols + (outputs * t )
6458 return outputs , tols , key
6559
66- @resolver (_advance_state )
67- def advance_state (self , outputs , tols , key ):
68- self .outputs .set (outputs )
69- self .tols .set (tols )
70- self .key .set (key )
71-
60+ @transition (output_compartments = ["inputs" , "outputs" , "tols" ])
7261 @staticmethod
73- def _reset (batch_size , n_units ):
62+ def reset (batch_size , n_units ):
7463 restVals = jnp .zeros ((batch_size , n_units ))
7564 return restVals , restVals , restVals
7665
77- @resolver (_reset )
78- def reset (self , inputs , outputs , tols ):
79- self .inputs .set (inputs )
80- self .outputs .set (outputs ) #None
81- self .tols .set (tols )
82-
8366 def save (self , directory , ** kwargs ):
8467 file_name = directory + "/" + self .name + ".npz"
8568 jnp .savez (file_name , key = self .key .value )
0 commit comments