11from jax import random , numpy as jnp , jit
2- from ngclearn import resolver , Component , Compartment
2+ from ngcsimlib .compilers .process import transition
3+ from ngcsimlib .component import Component
4+ from ngcsimlib .compartment import Compartment
5+
6+ from ngclearn .utils .weight_distribution import initialize_params
7+ from ngcsimlib .logger import info
38from ngclearn .components .synapses .hebbian import TraceSTDPSynapse
49from ngclearn .utils import tensorstats
510
611class MSTDPETSynapse (TraceSTDPSynapse ): # modulated trace-based STDP w/ eligility traces
712 """
8- A synaptic cable that adjusts its efficacies via trace-based form of
9- three-factor learning, i.e., modulated spike-timing-dependent plasticity
10- (M-STDP) or modulated STDP with eligibility traces (M-STDP-ET).
13+ A synaptic cable that adjusts its efficacies via trace-based form of three-factor learning, i.e., modulated
14+ spike-timing-dependent plasticity (M-STDP) or modulated STDP with eligibility traces (M-STDP-ET).
1115
1216 | --- Synapse Compartments: ---
1317 | inputs - input (takes in external signals)
@@ -20,11 +24,14 @@ class MSTDPETSynapse(TraceSTDPSynapse): # modulated trace-based STDP w/ eligilit
2024 | postSpike - post-synaptic spike to drive 2nd term of STDP update (takes in external signals)
2125 | preTrace - pre-synaptic trace value to drive 1st term of STDP update (takes in external signals)
2226 | postTrace - post-synaptic trace value to drive 2nd term of STDP update (takes in external signals)
23- | dWeights - current delta matrix containing changes to be applied to synaptic efficacies
27+ | dWeights - current delta matrix containing (MS-STDP/MS-STDP-ET) changes to be applied to synaptic efficacies
2428 | eligibility - current state of eligibility trace
25- | eta - global learning rate (multiplier beyond A_plus and A_minus )
29+ | eta - global learning rate (applied to change in weights for final MS-STDP/MS-STDP-ET adjustment )
2630
2731 | References:
32+ | Florian, Răzvan V. "Reinforcement learning through modulation of spike-timing-dependent synaptic plasticity."
33+ | Neural computation 19.6 (2007): 1468-1502.
34+ |
2835 | Morrison, Abigail, Ad Aertsen, and Markus Diesmann. "Spike-timing-dependent
2936 | plasticity in balanced random networks." Neural computation 19.6 (2007): 1437-1467.
3037 |
@@ -66,29 +73,30 @@ class MSTDPETSynapse(TraceSTDPSynapse): # modulated trace-based STDP w/ eligilit
6673 """
6774
6875 # Define Functions
69- def __init__ (self , name , shape , A_plus , A_minus , eta = 1. , mu = 0. ,
70- pretrace_target = 0. , tau_elg = 0. , elg_decay = 1. ,
71- weight_init = None , resist_scale = 1. , p_conn = 1. , w_bound = 1. ,
72- batch_size = 1 , ** kwargs ):
73- super ().__init__ (name , shape , A_plus , A_minus , eta = eta , mu = mu ,
74- pretrace_target = pretrace_target , weight_init = weight_init ,
75- resist_scale = resist_scale , p_conn = p_conn , w_bound = w_bound ,
76- batch_size = batch_size , ** kwargs )
76+ def __init__ (
77+ self , name , shape , A_plus , A_minus , eta = 1. , mu = 0. , pretrace_target = 0. , tau_elg = 0. , elg_decay = 1. ,
78+ weight_init = None , resist_scale = 1. , p_conn = 1. , w_bound = 1. , batch_size = 1 , ** kwargs
79+ ):
80+ super ().__init__ (
81+ name , shape , A_plus , A_minus , eta = eta , mu = mu , pretrace_target = pretrace_target , weight_init = weight_init ,
82+ resist_scale = resist_scale , p_conn = p_conn , w_bound = w_bound , batch_size = batch_size , ** kwargs
83+ )
7784 ## MSTDP/MSTDP-ET meta-parameters
7885 self .tau_elg = tau_elg
7986 self .elg_decay = elg_decay
8087 ## MSTDP/MSTDP-ET compartments
8188 self .modulator = Compartment (jnp .zeros ((self .batch_size , 1 )))
8289 self .eligibility = Compartment (jnp .zeros (shape ))
8390
91+ @transition (output_compartments = ["weights" , "dWeights" , "eligibility" ])
8492 @staticmethod
85- def _evolve (dt , w_bound , preTrace_target , mu , Aplus , Aminus , tau_elg ,
86- elg_decay , preSpike , postSpike , preTrace , postTrace , weights ,
87- eta , modulator , eligibility ):
93+ def evolve (
94+ dt , w_bound , preTrace_target , mu , Aplus , Aminus , tau_elg , elg_decay , preSpike , postSpike , preTrace ,
95+ postTrace , weights , eta , modulator , eligibility
96+ ):
8897 ## compute local synaptic update (via STDP)
8998 dW_dt = TraceSTDPSynapse ._compute_update (
90- dt , w_bound , preTrace_target , mu , Aplus , Aminus ,
91- preSpike , postSpike , preTrace , postTrace , weights
99+ dt , w_bound , preTrace_target , mu , Aplus , Aminus , preSpike , postSpike , preTrace , postTrace , weights
92100 ) ## produce dW/dt (ODE for synaptic change dynamics)
93101 if tau_elg > 0. : ## perform dynamics of M-STDP-ET
94102 ## update eligibility trace given current local update
@@ -107,14 +115,11 @@ def _evolve(dt, w_bound, preTrace_target, mu, Aplus, Aminus, tau_elg,
107115 weights = jnp .clip (weights , eps , w_bound - eps ) # jnp.abs(w_bound))
108116 return weights , dWeights , eligibility
109117
110- @resolver (_evolve )
111- def evolve (self , weights , dWeights , eligibility ):
112- self .weights .set (weights )
113- self .dWeights .set (dWeights )
114- self .eligibility .set (eligibility )
115-
118+ @transition (
119+ output_compartments = ["inputs" , "outputs" , "preSpike" , "postSpike" , "preTrace" , "postTrace" , "dWeights" , "eligibility" ]
120+ )
116121 @staticmethod
117- def _reset (batch_size , shape ):
122+ def reset (batch_size , shape ):
118123 preVals = jnp .zeros ((batch_size , shape [0 ]))
119124 postVals = jnp .zeros ((batch_size , shape [1 ]))
120125 synVals = jnp .zeros (shape )
@@ -126,20 +131,7 @@ def _reset(batch_size, shape):
126131 postTrace = postVals
127132 dWeights = synVals
128133 eligibility = synVals
129- return (inputs , outputs , preSpike , postSpike , preTrace , postTrace ,
130- dWeights , eligibility )
131-
132- @resolver (_reset )
133- def reset (self , inputs , outputs , preSpike , postSpike , preTrace , postTrace ,
134- dWeights , eligibility ):
135- self .inputs .set (inputs )
136- self .outputs .set (outputs )
137- self .preSpike .set (preSpike )
138- self .postSpike .set (postSpike )
139- self .preTrace .set (preTrace )
140- self .postTrace .set (postTrace )
141- self .dWeights .set (dWeights )
142- self .eligibility .set (eligibility )
134+ return inputs , outputs , preSpike , postSpike , preTrace , postTrace , dWeights , eligibility
143135
144136 @classmethod
145137 def help (cls ): ## component help function
0 commit comments