11from jax import random , numpy as jnp , jit
2- from ngcsimlib .compilers .process import transition
3- from ngcsimlib .component import Component
42from ngcsimlib .compartment import Compartment
3+ from ngcsimlib .parser import compilable
54
6- from ngclearn .components .synapses import DenseSynapse
7- from ngclearn .utils import tensorstats
5+ from ngclearn .components .synapses .denseSynapse import DenseSynapse
86
97class BCMSynapse (DenseSynapse ): # BCM-adjusted synaptic cable
108 """
@@ -71,8 +69,7 @@ def __init__(
7169 self , name , shape , tau_w , tau_theta , theta0 = - 1. , w_bound = 0. , w_decay = 0. , weight_init = None , resist_scale = 1. ,
7270 p_conn = 1. , batch_size = 1 , ** kwargs
7371 ):
74- super ().__init__ (name , shape , weight_init , None , resist_scale , p_conn ,
75- batch_size = batch_size , ** kwargs )
72+ super ().__init__ (name , shape , weight_init , None , resist_scale , p_conn , batch_size = batch_size , ** kwargs )
7673
7774 ## Synapse and BCM hyper-parameters
7875 self .shape = shape ## shape of synaptic efficacy matrix
@@ -90,48 +87,51 @@ def __init__(
9087 self .post = Compartment (postVals ) ## post-synaptic statistic
9188 self .post_term = Compartment (postVals )
9289 self .theta = Compartment (postVals + self .theta0 ) ## synaptic modification thresholds
93- self .dWeights = Compartment (self .weights .value * 0 )
90+ self .dWeights = Compartment (self .weights .get () * 0 )
9491
95- @transition (output_compartments = ["weights" , "theta" , "dWeights" , "post_term" ])
96- @staticmethod
97- def evolve (t , dt , tau_w , tau_theta , w_bound , w_decay , pre , post , theta , weights ):
92+ @compilable
93+ def evolve (self , t , dt ): #t, dt, tau_w, tau_theta, w_bound, w_decay, pre, post, theta, weights):
9894 eps = 1e-7
99- post_term = post * (post - theta ) # post - theta
100- post_term = post_term * (1. / (theta + eps ))
101- dWeights = jnp .matmul (pre .T , post_term )
102- if w_bound > 0. :
103- dWeights = dWeights * (w_bound - jnp .abs (weights ))
95+ post_term = self . post . get () * (self . post . get () - self . theta . get () ) # post - theta
96+ post_term = post_term * (1. / (self . theta . get () + eps ))
97+ dWeights = jnp .matmul (self . pre . get () .T , post_term )
98+ if self . w_bound > 0. :
99+ dWeights = dWeights * (self . w_bound - jnp .abs (self . weights . get () ))
104100 ## update synaptic efficacies according to a leaky ODE
105- dWeights = - weights * w_decay + dWeights
106- _W = weights + dWeights * dt / tau_w
101+ dWeights = - self . weights . get () * self . w_decay + dWeights
102+ _W = self . weights . get () + dWeights * dt / self . tau_w
107103 ## update synaptic modification threshold as a leaky ODE
108- dtheta = jnp .mean (jnp .square (post ), axis = 0 , keepdims = True ) ## batch avg
109- theta = theta + (- theta + dtheta ) * dt / tau_theta
110- return weights , theta , dWeights , post_term
111-
112- @transition (output_compartments = ["inputs" , "outputs" , "pre" , "post" , "dWeights" , "post_term" ])
113- @staticmethod
114- def reset (batch_size , shape ):
115- preVals = jnp .zeros ((batch_size , shape [0 ]))
116- postVals = jnp .zeros ((batch_size , shape [1 ]))
117- inputs = preVals
118- outputs = postVals
119- pre = preVals
120- post = postVals
121- dWeights = jnp .zeros (shape )
122- post_term = postVals
123- return inputs , outputs , pre , post , dWeights , post_term
124-
125- def save (self , directory , ** kwargs ):
126- file_name = directory + "/" + self .name + ".npz"
127- jnp .savez (file_name ,
128- weights = self .weights .value , theta = self .theta .value )
129-
130- def load (self , directory , ** kwargs ):
131- file_name = directory + "/" + self .name + ".npz"
132- data = jnp .load (file_name )
133- self .weights .set (data ['weights' ])
134- self .theta .set (data ['theta' ])
104+ dtheta = jnp .mean (jnp .square (self .post .get ()), axis = 0 , keepdims = True ) ## batch avg
105+ theta = self .theta .get () + (- self .theta .get () + dtheta ) * dt / self .tau_theta
106+
107+ #self.weights.set(weights)
108+ self .theta .set (theta )
109+ self .dWeights .set (dWeights )
110+ self .post_term .set (post_term )
111+
112+ @compilable
113+ def reset (self ):
114+ preVals = jnp .zeros ((self .batch_size .get (), self .shape .get ()[0 ]))
115+ postVals = jnp .zeros ((self .batch_size .get (), self .shape .get ()[1 ]))
116+
117+ if not self .inputs .targeted :
118+ self .inputs .set (preVals )
119+ self .outputs .set (postVals )
120+ self .pre .set (preVals )
121+ self .post .set (postVals )
122+ self .dWeights .set (jnp .zeros (self .shape .get ()))
123+ self .post_term .set (postVals )
124+
125+ # def save(self, directory, **kwargs):
126+ # file_name = directory + "/" + self.name + ".npz"
127+ # jnp.savez(file_name,
128+ # weights=self.weights.value, theta=self.theta.value)
129+ #
130+ # def load(self, directory, **kwargs):
131+ # file_name = directory + "/" + self.name + ".npz"
132+ # data = jnp.load(file_name)
133+ # self.weights.set(data['weights'])
134+ # self.theta.set(data['theta'])
135135
136136 @classmethod
137137 def help (cls ): ## component help function
@@ -175,21 +175,6 @@ def help(cls): ## component help function
175175 "hyperparameters" : hyperparams }
176176 return info
177177
178- def __repr__ (self ):
179- comps = [varname for varname in dir (self ) if Compartment .is_compartment (getattr (self , varname ))]
180- maxlen = max (len (c ) for c in comps ) + 5
181- lines = f"[{ self .__class__ .__name__ } ] PATH: { self .name } \n "
182- for c in comps :
183- stats = tensorstats (getattr (self , c ).value )
184- if stats is not None :
185- line = [f"{ k } : { v } " for k , v in stats .items ()]
186- line = ", " .join (line )
187- else :
188- line = "None"
189- lines += f" { f'({ c } )' .ljust (maxlen )} { line } \n "
190- return lines
191-
192-
193178if __name__ == '__main__' :
194179 from ngcsimlib .context import Context
195180 with Context ("Bar" ) as bar :
0 commit comments