 from jax import random, numpy as jnp, jit
-from ngclearn import resolver, Component, Compartment
+from ngcsimlib.compilers.process import transition
+from ngcsimlib.component import Component
+from ngcsimlib.compartment import Compartment
+
 from ngclearn.components.synapses import DenseSynapse
 from ngclearn.utils import tensorstats
 
-def _calc_update(dt, pre, x_pre, post, x_post, W, w_bound=1., x_tar=0.0, mu=0.,
-                 Aplus=1., Aminus=0.):
-    if mu > 0.:
-        ## equations 3, 5, & 6 from Diehl and Cook - full power-law STDP
-        post_shift = jnp.power(w_bound - W, mu)
-        pre_shift = jnp.power(W, mu)
-        dWpost = (post_shift * jnp.matmul((x_pre - x_tar).T, post)) * Aplus
-        dWpre = 0.
-        if Aminus > 0.:
-            dWpre = -(pre_shift * jnp.matmul(pre.T, x_post)) * Aminus
-    else:
-        ## calculate post-synaptic term
-        dWpost = jnp.matmul((x_pre - x_tar).T, post * Aplus)
-        dWpre = 0.
-        if Aminus > 0.:
-            ## calculate pre-synaptic term
-            dWpre = -jnp.matmul(pre.T, x_post * Aminus)
-    ## calc final weighted adjustment
-    dW = (dWpost + dWpre)
-    return dW
 
 class TraceSTDPSynapse(DenseSynapse): # power-law / trace-based STDP
     """
@@ -83,9 +66,10 @@ class TraceSTDPSynapse(DenseSynapse): # power-law / trace-based STDP
8366 """
8467
8568 # Define Functions
86- def __init__ (self , name , shape , A_plus , A_minus , eta = 1. , mu = 0. ,
87- pretrace_target = 0. , weight_init = None , resist_scale = 1. ,
88- p_conn = 1. , w_bound = 1. , batch_size = 1 , ** kwargs ):
69+ def __init__ (
70+ self , name , shape , A_plus , A_minus , eta = 1. , mu = 0. , pretrace_target = 0. , weight_init = None , resist_scale = 1. ,
71+ p_conn = 1. , w_bound = 1. , batch_size = 1 , ** kwargs
72+ ):
8973 super ().__init__ (name , shape , weight_init , None , resist_scale ,
9074 p_conn , batch_size = batch_size , ** kwargs )
9175
@@ -109,19 +93,41 @@ def __init__(self, name, shape, A_plus, A_minus, eta=1., mu=0.,
         self.eta = Compartment(jnp.ones((1, 1)) * eta) ## global learning rate
 
     @staticmethod
-    def _compute_update(dt, w_bound, preTrace_target, mu, Aplus, Aminus,
-                        preSpike, postSpike, preTrace, postTrace, weights):
-        dW = _calc_update(dt, preSpike, preTrace, postSpike, postTrace, weights,
-                          w_bound=w_bound, x_tar=preTrace_target, mu=mu,
-                          Aplus=Aplus, Aminus=Aminus)
+    def _compute_update(
+            dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
+    ):
+        pre = preSpike
+        x_pre = preTrace
+        post = postSpike
+        x_post = postTrace
+        W = weights
+        x_tar = preTrace_target
+        if mu > 0.:
+            ## equations 3, 5, & 6 from Diehl and Cook - full power-law STDP
+            post_shift = jnp.power(w_bound - W, mu)
+            pre_shift = jnp.power(W, mu)
+            dWpost = (post_shift * jnp.matmul((x_pre - x_tar).T, post)) * Aplus
+            dWpre = 0.
+            if Aminus > 0.:
+                dWpre = -(pre_shift * jnp.matmul(pre.T, x_post)) * Aminus
+        else:
+            ## calculate post-synaptic term
+            dWpost = jnp.matmul((x_pre - x_tar).T, post * Aplus)
+            dWpre = 0.
+            if Aminus > 0.:
+                ## calculate pre-synaptic term
+                dWpre = -jnp.matmul(pre.T, x_post * Aminus)
+        ## calc final weighted adjustment
+        dW = (dWpost + dWpre)
         return dW
 
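For readability, the weight adjustment computed by `_compute_update` above can be written compactly as follows (a direct transcription of the added code, writing $w_b$ for `w_bound`, $x_{tar}$ for `preTrace_target`, and $\odot$ for elementwise multiplication; setting $\mu = 0$ collapses both power factors to 1, recovering the plain trace-based rule of the `else` branch):

$$
\Delta W \;=\; A_{+}\,(w_b - W)^{\mu} \odot \big[(x_{pre} - x_{tar})^{\top} s_{post}\big] \;-\; A_{-}\,W^{\mu} \odot \big[s_{pre}^{\top} x_{post}\big],
$$

where $s_{pre}$, $s_{post}$ are the pre- and post-synaptic spike vectors and $x_{pre}$, $x_{post}$ their traces; the $A_{-}$ term is skipped entirely when `Aminus` is zero.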
+    @transition(output_compartments=["weights", "dWeights"])
     @staticmethod
-    def _evolve(dt, w_bound, preTrace_target, mu, Aplus, Aminus,
-                preSpike, postSpike, preTrace, postTrace, weights, eta):
+    def evolve(
+            dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights, eta
+    ):
         dWeights = TraceSTDPSynapse._compute_update(
-            dt, w_bound, preTrace_target, mu, Aplus, Aminus,
-            preSpike, postSpike, preTrace, postTrace, weights
+            dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
         )
         ## do a gradient ascent update/shift
         weights = weights + dWeights * eta
@@ -130,13 +136,9 @@ def _evolve(dt, w_bound, preTrace_target, mu, Aplus, Aminus,
         weights = jnp.clip(weights, eps, w_bound - eps) # jnp.abs(w_bound))
         return weights, dWeights
 
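Taken together with the preceding hunk, one full `evolve` step is the soft-bounded ascent update

$$
W \;\leftarrow\; \operatorname{clip}\!\big(W + \eta\,\Delta W,\; \epsilon,\; w_b - \epsilon\big),
$$

with $\Delta W$ from `_compute_update` and `eps` a small positive constant defined in context lines this diff does not show.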
-    @resolver(_evolve)
-    def evolve(self, weights, dWeights):
-        self.weights.set(weights)
-        self.dWeights.set(dWeights)
-
+    @transition(output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace", "dWeights"])
     @staticmethod
-    def _reset(batch_size, shape):
+    def reset(batch_size, shape):
         preVals = jnp.zeros((batch_size, shape[0]))
         postVals = jnp.zeros((batch_size, shape[1]))
         inputs = preVals
@@ -148,16 +150,6 @@ def _reset(batch_size, shape):
         dWeights = jnp.zeros(shape)
         return inputs, outputs, preSpike, postSpike, preTrace, postTrace, dWeights
 
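Under the new API, the tuple returned by `reset` (and likewise by `evolve`) is routed, in order, into the compartments named in `output_compartments`; this is the bookkeeping that the removed `@resolver` methods performed by hand via explicit `.set()` calls. A rough by-hand equivalent, assuming an already-constructed instance `syn` (illustrative only, not the actual ngcsimlib internals):

    ## hypothetical manual wiring mirroring what @transition automates
    vals = TraceSTDPSynapse.reset(syn.batch_size, syn.shape)
    names = ["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace", "dWeights"]
    for name, val in zip(names, vals):
        getattr(syn, name).set(val)  ## each Compartment stores its refreshed value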
-    @resolver(_reset)
-    def reset(self, inputs, outputs, preSpike, postSpike, preTrace, postTrace, dWeights):
-        self.inputs.set(inputs)
-        self.outputs.set(outputs)
-        self.preSpike.set(preSpike)
-        self.postSpike.set(postSpike)
-        self.preTrace.set(preTrace)
-        self.postTrace.set(postTrace)
-        self.dWeights.set(dWeights)
-
     @classmethod
     def help(cls): ## component help function
         properties = {
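For orientation, constructing the component with the `__init__` signature shown above might look like the following sketch; the `Context` import path, the shapes, and the hyperparameter values are assumptions for illustration rather than part of this commit:

    from ngcsimlib.context import Context  ## assumed import path
    with Context("circuit") as ctx:
        syn = TraceSTDPSynapse(
            name="W1", shape=(64, 10), A_plus=1e-2, A_minus=1e-4,
            eta=1., mu=0., pretrace_target=0., w_bound=1., batch_size=1
        )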