Skip to content

Commit d2d4331

Browse files
committed
Fixed an execution bug
1 parent ea724e9 commit d2d4331

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

ngclearn/components/jaxComponent.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import time
22

3-
from typing import Union
3+
from typing import Union, Dict, Any
44
import jax
55
from jax import numpy as jnp
66
from jax import random
@@ -25,7 +25,6 @@ def __init__(self, name: str, key: Union[jax.Array, None] = None):
2525
self.key = Compartment(
2626
random.PRNGKey(time.time_ns()) if key is None else key)
2727

28-
2928
def save(self, directory: str):
3029
"""
3130
The default save method for JaxComponents, it stores the values of all

ngclearn/components/synapses/hebbian/traceSTDPSynapse.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __init__(
8383
if weight_mask is None:
8484
self.weight_mask = jnp.ones((1, 1))
8585
else:
86-
self.weight_mask = self.weight_mask
86+
self.weight_mask = weight_mask
8787

8888
self.weights.set(self.weights.get() * self.weight_mask)
8989

@@ -95,27 +95,27 @@ def __init__(
9595
self.preTrace = Compartment(preVals)
9696
self.postTrace = Compartment(postVals)
9797
self.dWeights = Compartment(self.weights.get() * 0)
98-
self.eta = jnp.ones((1, 1)) * eta ## global learning rate
98+
self.eta = eta ## global learning rate
9999

100100
def _compute_update(self):
101101
if self.mu > 0.:
102102
post_shift = jnp.power(self.w_bound - self.weights.get(), self.mu)
103103
pre_shift = jnp.power(self.weights.get(), self.mu)
104-
dWpost = (post_shift * jnp.matmul((self.preSpike.get() - self.preTrace_target).T, self.postSpike.get())) * self.Aplus
104+
dWpost = (post_shift * jnp.matmul((self.preTrace.get() - self.preTrace_target).T, self.postSpike.get())) * self.Aplus
105105

106106
if self.Aminus > 0.:
107107
dWpre = -(pre_shift * jnp.matmul(self.preSpike.get().T, self.postTrace.get())) * self.Aminus
108108
else:
109109
dWpre = 0.
110110

111111
else:
112-
dWpost = jnp.matmul((self.preSpike.get() - self.preTrace_target).T, self.postSpike.get() * self.Aplus)
112+
dWpost = jnp.matmul((self.preTrace.get() - self.preTrace_target).T, self.postSpike.get() * self.Aplus)
113113
if self.Aminus > 0.:
114114
dWpre = -jnp.matmul(self.preSpike.get().T, self.postTrace.get() * self.Aminus)
115115
else:
116116
dWpre = 0.
117117

118-
dW = (dWpost - dWpre)
118+
dW = (dWpost + dWpre)
119119
return dW
120120

121121
@compilable

0 commit comments

Comments
 (0)