11from jax import random , numpy as jnp , jit
2- from ngcsimlib .compilers .process import transition
3- from ngcsimlib .component import Component
4- from ngcsimlib .compartment import Compartment
5-
2+ from ngclearn .components .jaxComponent import JaxComponent
3+ from ngclearn .utils import tensorstats
64from ngclearn .utils .weight_distribution import initialize_params
75from ngcsimlib .logger import info
6+
87from ngclearn .components .synapses import DenseSynapse
9- from ngclearn .utils import tensorstats
8+ from ngcsimlib .compartment import Compartment
9+ from ngcsimlib .parser import compilable
1010
1111class AlphaSynapse (DenseSynapse ): ## dynamic alpha synapse cable
1212 """
@@ -64,8 +64,8 @@ class AlphaSynapse(DenseSynapse): ## dynamic alpha synapse cable
6464
6565 # Define Functions
6666 def __init__ (
67- self , name , shape , tau_decay , g_syn_bar , syn_rest , weight_init = None , bias_init = None , resist_scale = 1. , p_conn = 1. ,
68- is_nonplastic = True , ** kwargs
67+ self , name , shape , tau_decay , g_syn_bar , syn_rest , weight_init = None , bias_init = None , resist_scale = 1. ,
68+ p_conn = 1. , is_nonplastic = True , ** kwargs
6969 ):
7070 super ().__init__ (name , shape , weight_init , bias_init , resist_scale , p_conn , ** kwargs )
7171 ## dynamic synapse meta-parameters
@@ -82,55 +82,55 @@ def __init__(
8282 self .g_syn = Compartment (postVals ) ## conductance variable
8383 self .h_syn = Compartment (postVals ) ## intermediate conductance variable
8484 if is_nonplastic :
85- self .weights .set (self .weights .value * 0 + 1. )
85+ self .weights .set (self .weights .get () * 0 + 1. )
8686
87- @transition (output_compartments = ["outputs" , "i_syn" , "g_syn" , "h_syn" ])
88- @staticmethod
89- def advance_state (
90- dt , tau_decay , g_syn_bar , syn_rest , Rscale , inputs , weights , i_syn , g_syn , h_syn , v
91- ):
92- s = inputs
87+ @compilable
88+ def advance_state (self , t , dt ):
89+ s = self .inputs .get ()
9390 ## advance conductance variable(s)
94- _out = jnp .matmul (s , weights ) ## sum all pre-syn spikes at t going into post-neuron)
95- dhsyn_dt = - h_syn / tau_decay + (_out * g_syn_bar ) * (1. / dt )
96- h_syn = h_syn + dhsyn_dt * dt ## run Euler step to move intermediate conductance h
91+ _out = jnp .matmul (s , self . weights . get () ) ## sum all pre-syn spikes at t going into post-neuron)
92+ dhsyn_dt = - self . h_syn . get () / self . tau_decay + (_out * self . g_syn_bar ) * (1. / dt )
93+ h_syn = self . h_syn . get () + dhsyn_dt * dt ## run Euler step to move intermediate conductance h
9794
98- dgsyn_dt = - g_syn / tau_decay + h_syn * (1. / dt ) # or -g_syn/tau_decay + h_syn/tau_decay
99- g_syn = g_syn + dgsyn_dt * dt ## run Euler step to move conductance g
95+ dgsyn_dt = - self . g_syn . get () / self . tau_decay + h_syn * (1. / dt ) # or -g_syn/tau_decay + h_syn/tau_decay
96+ g_syn = self . g_syn . get () + dgsyn_dt * dt ## run Euler step to move conductance g
10097
10198 ## compute derive electrical current variable
102- i_syn = - g_syn * Rscale
103- if syn_rest is not None :
104- i_syn = - (g_syn * Rscale ) * (v - syn_rest )
105- outputs = i_syn #jnp.matmul(inputs, Wdyn * Rscale) + biases
106- return outputs , i_syn , g_syn , h_syn
107-
108- @transition (output_compartments = ["inputs" , "outputs" , "i_syn" , "g_syn" , "h_syn" , "v" ])
109- @staticmethod
110- def reset (batch_size , shape ):
111- preVals = jnp .zeros ((batch_size , shape [0 ]))
112- postVals = jnp .zeros ((batch_size , shape [1 ]))
113- inputs = preVals
114- outputs = postVals
115- i_syn = postVals
116- g_syn = postVals
117- h_syn = postVals
118- v = postVals
119- return inputs , outputs , i_syn , g_syn , h_syn , v
120-
121- def save (self , directory , ** kwargs ):
122- file_name = directory + "/" + self .name + ".npz"
123- if self .bias_init != None :
124- jnp .savez (file_name , weights = self .weights .value , biases = self .biases .value )
125- else :
126- jnp .savez (file_name , weights = self .weights .value )
127-
128- def load (self , directory , ** kwargs ):
129- file_name = directory + "/" + self .name + ".npz"
130- data = jnp .load (file_name )
131- self .weights .set (data ['weights' ])
132- if "biases" in data .keys ():
133- self .biases .set (data ['biases' ])
99+ i_syn = - g_syn * self .resist_scale
100+ if self .syn_rest is not None :
101+ i_syn = - (g_syn * self .resist_scale ) * (self .v .get () - self .syn_rest )
102+ outputs = i_syn #jnp.matmul(inputs, Wdyn * self.resist_scale) + biases
103+
104+ self .outputs .set (outputs )
105+ self .i_syn .set (i_syn )
106+ self .g_syn .set (g_syn )
107+ self .h_syn .set (h_syn )
108+
109+ @compilable
110+ def reset (self ):
111+ preVals = jnp .zeros ((self .batch_size .get (), self .shape .get ()[0 ]))
112+ postVals = jnp .zeros ((self .batch_size .get (), self .shape .get ()[1 ]))
113+ if not self .inputs .targeted :
114+ self .inputs .set (preVals )
115+ self .outputs .set (postVals )
116+ self .i_syn .set (postVals )
117+ self .g_syn .set (postVals )
118+ self .h_syn .set (postVals )
119+ self .v .set (postVals )
120+
121+ # def save(self, directory, **kwargs):
122+ # file_name = directory + "/" + self.name + ".npz"
123+ # if self.bias_init != None:
124+ # jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value)
125+ # else:
126+ # jnp.savez(file_name, weights=self.weights.value)
127+ #
128+ # def load(self, directory, **kwargs):
129+ # file_name = directory + "/" + self.name + ".npz"
130+ # data = jnp.load(file_name)
131+ # self.weights.set(data['weights'])
132+ # if "biases" in data.keys():
133+ # self.biases.set(data['biases'])
134134
135135 @classmethod
136136 def help (cls ): ## component help function
@@ -170,17 +170,3 @@ def help(cls): ## component help function
170170 "dgsyn_dt = -g_syn/tau_decay + h_syn" ,
171171 "hyperparameters" : hyperparams }
172172 return info
173-
174- def __repr__ (self ):
175- comps = [varname for varname in dir (self ) if Compartment .is_compartment (getattr (self , varname ))]
176- maxlen = max (len (c ) for c in comps ) + 5
177- lines = f"[{ self .__class__ .__name__ } ] PATH: { self .name } \n "
178- for c in comps :
179- stats = tensorstats (getattr (self , c ).value )
180- if stats is not None :
181- line = [f"{ k } : { v } " for k , v in stats .items ()]
182- line = ", " .join (line )
183- else :
184- line = "None"
185- lines += f" { f'({ c } )' .ljust (maxlen )} { line } \n "
186- return lines
0 commit comments