 from jax import random, numpy as jnp, jit
-from ngclearn import resolver, Component, Compartment
+from ngcsimlib.compilers.process import transition
+from ngcsimlib.component import Component
+from ngcsimlib.compartment import Compartment
+
 from ngclearn.components.synapses import DenseSynapse
 from ngclearn.utils import tensorstats
 
-def _calc_update(dt, pre, x_pre, post, x_post, W, w_bound=1., x_tar=0.7,
-                 exp_beta=1., Aplus=1., Aminus=0.): ## internal dynamics method
-    ## equations 4 from Diehl and Cook - full exponential weight-dependent STDP
-    ## calculate post-synaptic term
-    post_term1 = jnp.exp(-exp_beta * W) * jnp.matmul(x_pre.T, post)
-    x_tar_vec = x_pre * 0 + x_tar # need to broadcast scalar x_tar to mat/vec form
-    post_term2 = jnp.exp(-exp_beta * (w_bound - W)) * jnp.matmul(x_tar_vec.T,
-                                                                 post)
-    dWpost = (post_term1 - post_term2) * Aplus
-    ## calculate pre-synaptic term
-    dWpre = 0.
-    if Aminus > 0.:
-        dWpre = -jnp.exp(-exp_beta * W) * jnp.matmul(pre.T, x_post) * Aminus
-    ## calc final weighted adjustment
-    dW = (dWpost + dWpre)
-    return dW
-
 class ExpSTDPSynapse(DenseSynapse):
     """
     A synaptic cable that adjusts its efficacies via trace-based form of
@@ -78,9 +64,10 @@ class ExpSTDPSynapse(DenseSynapse):
     """
 
     # Define Functions
-    def __init__(self, name, shape, A_plus, A_minus, exp_beta, eta=1.,
-                 pretrace_target=0., weight_init=None, resist_scale=1.,
-                 p_conn=1., w_bound=1., batch_size=1, **kwargs):
+    def __init__(
+            self, name, shape, A_plus, A_minus, exp_beta, eta=1., pretrace_target=0., weight_init=None, resist_scale=1.,
+            p_conn=1., w_bound=1., batch_size=1, **kwargs
+    ):
         super().__init__(name, shape, weight_init, None, resist_scale,
                          p_conn, batch_size=batch_size, **kwargs)
 
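The weight-update math itself is unchanged by this commit; it simply moves from the module-level `_calc_update` into the class's `_compute_update` (next hunk). For reference, a self-contained JAX sketch of that exponential weight-dependent STDP rule, exercised on toy shapes (the toy spikes, traces, and weight values below are illustrative, not part of the commit):

```python
from jax import numpy as jnp, random

def exp_stdp_update(pre, x_pre, post, x_post, W, w_bound=1., x_tar=0.7,
                    exp_beta=1., Aplus=1., Aminus=0.):
    ## potentiation: gated by post-synaptic spikes and the pre-synaptic trace,
    ## damped exponentially as W grows and pulled toward the trace target x_tar
    post_term1 = jnp.exp(-exp_beta * W) * jnp.matmul(x_pre.T, post)
    post_term2 = jnp.exp(-exp_beta * (w_bound - W)) * jnp.matmul((x_pre * 0 + x_tar).T, post)
    dWpost = (post_term1 - post_term2) * Aplus
    ## depression: gated by pre-synaptic spikes and the post-synaptic trace
    dWpre = -jnp.exp(-exp_beta * W) * jnp.matmul(pre.T, x_post) * Aminus if Aminus > 0. else 0.
    return dWpost + dWpre

pre = jnp.asarray([[1., 0., 1.]])              # (batch, n_pre) spike vector
post = jnp.asarray([[0., 1.]])                 # (batch, n_post) spike vector
x_pre, x_post = pre * 0.5, post * 0.5          # toy pre-/post-synaptic traces
W = random.uniform(random.PRNGKey(0), (3, 2))  # (n_pre, n_post) synaptic matrix
print(exp_stdp_update(pre, x_pre, post, x_post, W, Aminus=1.))  # -> (3, 2) weight update
```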
@@ -105,16 +92,36 @@ def __init__(self, name, shape, A_plus, A_minus, exp_beta, eta=1.,
         self.eta = Compartment(jnp.ones((1, 1)) * eta) ## global learning rate governing plasticity
 
     @staticmethod
-    def _compute_update(dt, w_bound, preTrace_target, exp_beta, Aplus, Aminus,
-                        preSpike, postSpike, preTrace, postTrace, weights):
-        dW = _calc_update(dt, preSpike, preTrace, postSpike, postTrace, weights,
-                          w_bound=w_bound, x_tar=preTrace_target, exp_beta=exp_beta,
-                          Aplus=Aplus, Aminus=Aminus)
+    def _compute_update(
+            dt, w_bound, preTrace_target, exp_beta, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
+    ):
+        pre = preSpike
+        x_pre = preTrace
+        post = postSpike
+        x_post = postTrace
+        W = weights
+        x_tar = preTrace_target
+        ## equations 4 from Diehl and Cook - full exponential weight-dependent STDP
+        ## calculate post-synaptic term
+        post_term1 = jnp.exp(-exp_beta * W) * jnp.matmul(x_pre.T, post)
+        x_tar_vec = x_pre * 0 + x_tar # need to broadcast scalar x_tar to mat/vec form
+        post_term2 = jnp.exp(-exp_beta * (w_bound - W)) * jnp.matmul(x_tar_vec.T,
+                                                                     post)
+        dWpost = (post_term1 - post_term2) * Aplus
+        ## calculate pre-synaptic term
+        dWpre = 0.
+        if Aminus > 0.:
+            dWpre = -jnp.exp(-exp_beta * W) * jnp.matmul(pre.T, x_post) * Aminus
+        ## calc final weighted adjustment
+        dW = (dWpost + dWpre)
         return dW
 
+    @transition(output_compartments=["weights", "dWeights"])
     @staticmethod
-    def _evolve(dt, w_bound, preTrace_target, exp_beta, Aplus, Aminus,
-                preSpike, postSpike, preTrace, postTrace, weights, eta):
+    def evolve(
+            dt, w_bound, preTrace_target, exp_beta, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace,
+            weights, eta
+    ):
         dW = ExpSTDPSynapse._compute_update(
             dt, w_bound, preTrace_target, exp_beta, Aplus, Aminus,
             preSpike, postSpike, preTrace, postTrace, weights
@@ -126,13 +133,9 @@ def _evolve(dt, w_bound, preTrace_target, exp_beta, Aplus, Aminus,
         _W = jnp.clip(_W, eps, w_bound - eps)
         return weights, dW
 
-    @resolver(_evolve)
-    def evolve(self, weights, dWeights):
-        self.weights.set(weights)
-        self.dWeights.set(dWeights)
-
+    @transition(output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace", "dWeights"])
     @staticmethod
-    def _reset(batch_size, shape):
+    def reset(batch_size, shape):
         preVals = jnp.zeros((batch_size, shape[0]))
         postVals = jnp.zeros((batch_size, shape[1]))
         inputs = preVals
@@ -144,16 +147,6 @@ def _reset(batch_size, shape):
         dWeights = jnp.zeros(shape)
         return inputs, outputs, preSpike, postSpike, preTrace, postTrace, dWeights
 
-    @resolver(_reset)
-    def reset(self, inputs, outputs, preSpike, postSpike, preTrace, postTrace, dWeights):
-        self.inputs.set(inputs)
-        self.outputs.set(outputs)
-        self.preSpike.set(preSpike)
-        self.postSpike.set(postSpike)
-        self.preTrace.set(preTrace)
-        self.postTrace.set(postTrace)
-        self.dWeights.set(dWeights)
-
     @classmethod
     def help(cls): ## component help function
         properties = {
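Seen as a whole, the commit retires the `@resolver` pattern, where each private kernel (`_evolve`, `_reset`) needed a companion wrapper that manually called `.set()` on every compartment, in favour of `@transition`-decorated static methods: the decorator names the output compartments, and the tuple the method returns is written into them in order. A rough sketch of that shape on a made-up component follows; the component, its compartments, and how non-compartment arguments such as `dt` are supplied are assumptions for illustration only, not part of this commit or a guaranteed ngcsimlib contract:

```python
from jax import numpy as jnp
from ngcsimlib.compilers.process import transition
from ngcsimlib.component import Component
from ngcsimlib.compartment import Compartment

class ToyAccumulator(Component):  ## hypothetical component, for illustration only
    def __init__(self, name, gain=2., **kwargs):
        super().__init__(name, **kwargs)
        self.gain = gain                             ## plain (non-compartment) parameter
        self.total = Compartment(jnp.zeros((1, 1)))  ## state lives in compartments
        self.delta = Compartment(jnp.zeros((1, 1)))

    @transition(output_compartments=["total", "delta"])
    @staticmethod
    def advance_state(dt, gain, total):
        ## the returned tuple is written into `total` and `delta` in order,
        ## replacing the old hand-written @resolver wrapper full of .set() calls
        delta = jnp.ones((1, 1)) * gain * dt
        return total + delta, delta
```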