1- from jax import numpy as jnp , jit
2- from functools import partial
3- from ngclearn import resolver , Component , Compartment
41from ngclearn .components .jaxComponent import JaxComponent
2+ from jax import numpy as jnp , random , jit
3+ from functools import partial
54from ngclearn .utils import tensorstats
5+ from ngcsimlib .deprecators import deprecate_args
6+ from ngcsimlib .logger import info , warn
7+
8+ from ngcsimlib .compilers .process import transition
9+ #from ngcsimlib.component import Component
10+ from ngcsimlib .compartment import Compartment
611
712@partial (jit , static_argnums = [5 ,6 ])
813def _apply_kernel (tf_curr , s , t , tau_w , win_len , krn_start , krn_end ):
@@ -40,6 +45,8 @@ class ExpKernel(JaxComponent): ## exponential kernel
4045 nu: (ms, spike time interval for window)
4146
4247 tau_w: spike window time constant (in micro-secs, or nano-s)
48+
49+ batch_size: batch size dimension of this cell (Default: 1)
4350 """
4451
4552 # Define Functions
@@ -60,31 +67,22 @@ def __init__(self, name, n_units, dt, tau_w=500., nu=4., batch_size=1, **kwargs)
6067 ## window of spike times
6168 self .tf = Compartment (jnp .zeros ((self .win_len , self .batch_size , self .n_units )))
6269
70+ @transition (output_compartments = ["epsp" , "tf" ])
6371 @staticmethod
64- def _advance_state (t , tau_w , win_len , inputs , tf ):
72+ def advance_state (t , tau_w , win_len , inputs , tf ):
6573 s = inputs
6674 ## update spike time window and corresponding window volume
6775 tf , epsp = _apply_kernel (tf , s , t , tau_w , win_len , krn_start = 0 ,
6876 krn_end = win_len - 1 ) #0:win_len-1)
6977 return epsp , tf
7078
71- @resolver (_advance_state )
72- def advance_state (self , epsp , tf ):
73- self .epsp .set (epsp )
74- self .tf .set (tf )
75-
79+ @transition (output_compartments = ["inputs" , "epsp" , "tf" ])
7680 @staticmethod
77- def _reset (batch_size , n_units , win_len ):
81+ def reset (batch_size , n_units , win_len ):
7882 restVals = jnp .zeros ((batch_size , n_units ))
7983 restTensor = jnp .zeros ([win_len , batch_size , n_units ], jnp .float32 ) # tf
8084 return restVals , restVals , restTensor # inputs, epsp, tf
8185
82- @resolver (_reset )
83- def reset (self , inputs , epsp , tf ):
84- self .inputs .set (inputs )
85- self .epsp .set (epsp )
86- self .tf .set (tf )
87-
8886 @classmethod
8987 def help (cls ): ## component help function
9088 properties = {
0 commit comments