@@ -83,7 +83,7 @@ def __init__(
8383 if weight_mask is None :
8484 self .weight_mask = jnp .ones ((1 , 1 ))
8585 else :
86- self .weight_mask = self . weight_mask
86+ self .weight_mask = weight_mask
8787
8888 self .weights .set (self .weights .get () * self .weight_mask )
8989
@@ -95,27 +95,27 @@ def __init__(
9595 self .preTrace = Compartment (preVals )
9696 self .postTrace = Compartment (postVals )
9797 self .dWeights = Compartment (self .weights .get () * 0 )
98- self .eta = jnp . ones (( 1 , 1 )) * eta ## global learning rate
98+ self .eta = eta ## global learning rate
9999
100100 def _compute_update (self ):
101101 if self .mu > 0. :
102102 post_shift = jnp .power (self .w_bound - self .weights .get (), self .mu )
103103 pre_shift = jnp .power (self .weights .get (), self .mu )
104- dWpost = (post_shift * jnp .matmul ((self .preSpike .get () - self .preTrace_target ).T , self .postSpike .get ())) * self .Aplus
104+ dWpost = (post_shift * jnp .matmul ((self .preTrace .get () - self .preTrace_target ).T , self .postSpike .get ())) * self .Aplus
105105
106106 if self .Aminus > 0. :
107107 dWpre = - (pre_shift * jnp .matmul (self .preSpike .get ().T , self .postTrace .get ())) * self .Aminus
108108 else :
109109 dWpre = 0.
110110
111111 else :
112- dWpost = jnp .matmul ((self .preSpike .get () - self .preTrace_target ).T , self .postSpike .get () * self .Aplus )
112+ dWpost = jnp .matmul ((self .preTrace .get () - self .preTrace_target ).T , self .postSpike .get () * self .Aplus )
113113 if self .Aminus > 0. :
114114 dWpre = - jnp .matmul (self .preSpike .get ().T , self .postTrace .get () * self .Aminus )
115115 else :
116116 dWpre = 0.
117117
118- dW = (dWpost - dWpre )
118+ dW = (dWpost + dWpre )
119119 return dW
120120
121121 @compilable
0 commit comments