1- from jax import numpy as jnp , jit
2- from ngcsimlib .compilers .process import transition
3- from ngcsimlib .component import Component
1+ from jax import random , numpy as jnp , jit
42from ngcsimlib .compartment import Compartment
3+ from ngcsimlib .parser import compilable
54
6- from ngclearn .components .synapses import DenseSynapse
7- from ngclearn .utils import tensorstats
5+ from ngclearn .components .synapses .denseSynapse import DenseSynapse
86
97class EventSTDPSynapse (DenseSynapse ): # event-driven, post-synaptic STDP
108 """
@@ -57,11 +55,11 @@ class EventSTDPSynapse(DenseSynapse): # event-driven, post-synaptic STDP
5755 """
5856
5957 # Define Functions
60- def __init__ (self , name , shape , eta , lmbda = 0.01 , A_plus = 1. , A_minus = 1. ,
61- presyn_win_len = 2 . , w_bound = 1. , weight_init = None , resist_scale = 1. ,
62- p_conn = 1. , batch_size = 1 , ** kwargs ):
63- super (). __init__ ( name , shape , weight_init , None , resist_scale , p_conn ,
64- batch_size = batch_size , ** kwargs )
58+ def __init__ (
59+ self , name , shape , eta , lmbda = 0.01 , A_plus = 1 . , A_minus = 1. , presyn_win_len = 2. , w_bound = 1. , weight_init = None ,
60+ resist_scale = 1. , p_conn = 1. , batch_size = 1 , ** kwargs
61+ ):
62+ super (). __init__ ( name , shape , weight_init , None , resist_scale , p_conn , batch_size = batch_size , ** kwargs )
6563
6664 ## Synaptic hyper-parameters
6765 self .eta = eta ## global learning rate governing plasticity
@@ -78,53 +76,47 @@ def __init__(self, name, shape, eta, lmbda=0.01, A_plus=1., A_minus=1.,
7876 postVals = jnp .zeros ((self .batch_size , shape [1 ]))
7977 self .pre_tols = Compartment (preVals )
8078 self .postSpike = Compartment (postVals )
81- self .dWeights = Compartment (self .weights .value * 0 )
79+ self .dWeights = Compartment (self .weights .get () * 0 )
8280 self .eta = Compartment (jnp .ones ((1 , 1 )) * eta ) ## global learning rate governing plasticity
8381
84- @staticmethod
85- def _compute_update (
86- t , lmbda , presyn_win_len , Aminus , Aplus , w_bound , pre_tols , postSpike , weights
87- ): ## synaptic adjustment calculation co-routine
82+ def _compute_update (self , t , dt ): ## synaptic adjustment calculation co-routine
8883 ## check if a spike occurred in window of (t - presyn_win_len, t]
89- m = (pre_tols > 0. ) * 1. ## ignore default value of tols = 0 ms
90- if presyn_win_len > 0. :
91- lbound = ((t - presyn_win_len ) < pre_tols ) * 1.
84+ m = (self . pre_tols . get () > 0. ) * 1. ## ignore default value of tols = 0 ms
85+ if self . presyn_win_len > 0. :
86+ lbound = ((t - self . presyn_win_len ) < self . pre_tols . get () ) * 1.
9287 preSpike = lbound * m
9388 else :
94- check_spike = (pre_tols == t ) * 1.
89+ check_spike = (self . pre_tols . get () == t ) * 1.
9590 preSpike = check_spike * m
9691 ## this implements a generalization of the rule in eqn 18 of the paper
97- pos_shift = w_bound - (weights * (1. + lmbda ))
98- pos_shift = pos_shift * Aplus
99- neg_shift = - weights * (1. + lmbda )
100- neg_shift = neg_shift * Aminus
92+ pos_shift = self . w_bound - (self . weights . get () * (1. + self . lmbda ))
93+ pos_shift = pos_shift * self . Aplus
94+ neg_shift = - self . weights . get () * (1. + self . lmbda )
95+ neg_shift = neg_shift * self . Aminus
10196 dW = jnp .where (preSpike .T , pos_shift , neg_shift ) # at pre-spikes => LTP, else decay
102- dW = (dW * postSpike ) ## gate to make sure only post-spikes trigger updates
97+ dW = (dW * self . postSpike . get () ) ## gate to make sure only post-spikes trigger updates
10398 return dW
10499
105- @transition (output_compartments = ["weights" , "dWeights" ])
106- @staticmethod
107- def evolve (
108- t , lmbda , presyn_win_len , Aminus , Aplus , w_bound , pre_tols , postSpike , weights , eta
109- ):
110- dWeights = EventSTDPSynapse ._compute_update (
111- t , lmbda , presyn_win_len , Aminus , Aplus , w_bound , pre_tols , postSpike , weights
112- )
113- weights = weights + dWeights * eta # * (1. - w) * eta
114- weights = jnp .clip (weights , 0.01 , w_bound ) ## Note: this step not in source paper
115- return weights , dWeights
116-
117- @transition (output_compartments = ["inputs" , "outputs" , "pre_tols" , "postSpike" , "dWeights" ])
118- @staticmethod
119- def reset (batch_size , shape ):
120- preVals = jnp .zeros ((batch_size , shape [0 ]))
121- postVals = jnp .zeros ((batch_size , shape [1 ]))
122- inputs = preVals
123- outputs = postVals
124- pre_tols = preVals ## pre-synaptic time-of-last-spike(s) record
125- postSpike = postVals
126- dWeights = jnp .zeros (shape )
127- return inputs , outputs , pre_tols , postSpike , dWeights
100+ @compilable
101+ def evolve (self , t , dt ):
102+ dWeights = self ._compute_update (t , dt )
103+ weights = self .weights .get () + dWeights * self .eta # * (1. - w) * eta
104+ weights = jnp .clip (weights , 0.01 , self .w_bound ) ## Note: this step not in source paper
105+
106+ self .weights .set (weights )
107+ self .dWeights .set (dWeights )
108+
109+ @compilable
110+ def reset (self ):
111+ preVals = jnp .zeros ((self .batch_size .get (), self .shape .get ()[0 ]))
112+ postVals = jnp .zeros ((self .batch_size .get (), self .shape .get ()[1 ]))
113+
114+ if not self .inputs .targeted :
115+ self .inputs .set (preVals )
116+ self .outputs .set (postVals )
117+ self .pre_tols .set (preVals ) ## pre-synaptic time-of-last-spike(s) record
118+ self .postSpike .set (postVals )
119+ self .dWeights .set (jnp .zeros (self .shape .get ()))
128120
129121 @classmethod
130122 def help (cls ): ## component help function
@@ -166,20 +158,6 @@ def help(cls): ## component help function
166158 "hyperparameters" : hyperparams }
167159 return info
168160
169- def __repr__ (self ):
170- comps = [varname for varname in dir (self ) if Compartment .is_compartment (getattr (self , varname ))]
171- maxlen = max (len (c ) for c in comps ) + 5
172- lines = f"[{ self .__class__ .__name__ } ] PATH: { self .name } \n "
173- for c in comps :
174- stats = tensorstats (getattr (self , c ).value )
175- if stats is not None :
176- line = [f"{ k } : { v } " for k , v in stats .items ()]
177- line = ", " .join (line )
178- else :
179- line = "None"
180- lines += f" { f'({ c } )' .ljust (maxlen )} { line } \n "
181- return lines
182-
183161if __name__ == '__main__' :
184162 from ngcsimlib .context import Context
185163 with Context ("Bar" ) as bar :
0 commit comments