11from jax import numpy as jnp , jit
2- from ngclearn import resolver , Component , Compartment
3- from ngclearn .utils import tensorstats
4- ## import parent synapse class/component
2+ from ngcsimlib .compilers .process import transition
3+ from ngcsimlib .component import Component
4+ from ngcsimlib .compartment import Compartment
5+
56from ngclearn .components .synapses import DenseSynapse
7+ from ngclearn .utils import tensorstats
68
79class EventSTDPSynapse (DenseSynapse ): # event-driven, post-synaptic STDP
810 """
@@ -80,8 +82,9 @@ def __init__(self, name, shape, eta, lmbda=0.01, A_plus=1., A_minus=1.,
8082 self .eta = Compartment (jnp .ones ((1 , 1 )) * eta ) ## global learning rate governing plasticity
8183
8284 @staticmethod
83- def _compute_update (t , lmbda , presyn_win_len , Aminus , Aplus , w_bound , pre_tols ,
84- postSpike , weights ):
85+ def _compute_update (
86+ t , lmbda , presyn_win_len , Aminus , Aplus , w_bound , pre_tols , postSpike , weights
87+ ): ## synaptic adjustment calculation co-routine
8588 ## check if a spike occurred in window of (t - presyn_win_len, t]
8689 m = (pre_tols > 0. ) * 1. ## ignore default value of tols = 0 ms
8790 if presyn_win_len > 0. :
@@ -99,40 +102,30 @@ def _compute_update(t, lmbda, presyn_win_len, Aminus, Aplus, w_bound, pre_tols,
99102 dW = (dW * postSpike ) ## gate to make sure only post-spikes trigger updates
100103 return dW
101104
105+ @transition (output_compartments = ["weights" , "dWeights" ])
102106 @staticmethod
103- def _evolve (t , lmbda , presyn_win_len , Aminus , Aplus , w_bound , pre_tols ,
104- postSpike , weights , eta ):
107+ def evolve (
108+ t , lmbda , presyn_win_len , Aminus , Aplus , w_bound , pre_tols , postSpike , weights , eta
109+ ):
105110 dWeights = EventSTDPSynapse ._compute_update (
106111 t , lmbda , presyn_win_len , Aminus , Aplus , w_bound , pre_tols , postSpike , weights
107112 )
108113 weights = weights + dWeights * eta # * (1. - w) * eta
109- weights = jnp .clip (weights , 0.01 , w_bound ) # not in source paper
114+ weights = jnp .clip (weights , 0.01 , w_bound ) ## Note: this step not in source paper
110115 return weights , dWeights
111116
112- @resolver (_evolve )
113- def evolve (self , weights , dWeights ):
114- self .weights .set (weights )
115- self .dWeights .set (dWeights )
116-
117+ @transition (output_compartments = ["inputs" , "outputs" , "pre_tols" , "postSpike" , "dWeights" ])
117118 @staticmethod
118- def _reset (batch_size , shape ):
119+ def reset (batch_size , shape ):
119120 preVals = jnp .zeros ((batch_size , shape [0 ]))
120121 postVals = jnp .zeros ((batch_size , shape [1 ]))
121122 inputs = preVals
122123 outputs = postVals
123- pre_tols = preVals ## pre-synaptic time-of-last-spike record
124+ pre_tols = preVals ## pre-synaptic time-of-last-spike(s) record
124125 postSpike = postVals
125126 dWeights = jnp .zeros (shape )
126127 return inputs , outputs , pre_tols , postSpike , dWeights
127128
128- @resolver (_reset )
129- def reset (self , inputs , outputs , pre_tols , postSpike , dWeights ):
130- self .inputs .set (inputs )
131- self .outputs .set (outputs )
132- self .pre_tols .set (pre_tols )
133- self .postSpike .set (postSpike )
134- self .dWeights .set (dWeights )
135-
136129 @classmethod
137130 def help (cls ): ## component help function
138131 properties = {
0 commit comments