Skip to content

Commit de10028

Browse files
committed
update patched synapse reset
1 parent a685fcd commit de10028

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

ngclearn/components/synapses/patched/hebbianPatchedSynapse.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,17 @@ def evolve(self):
276276
self.dWeights.set(dWeights)
277277
self.dBiases.set(dBiases)
278278

279+
@compilable
280+
def reset(self, batch_size, shape):
281+
preVals = jnp.zeros((batch_size, shape[0]))
282+
postVals = jnp.zeros((batch_size, shape[1]))
283+
self.inputs.set(preVals) # inputs
284+
self.outputs.set(postVals) # outputs
285+
self.pre.set(preVals) # pre
286+
self.post.set(postVals) # post
287+
self.dWeights.set(jnp.zeros(shape)) # dW
288+
self.dBiases.set(jnp.zeros(shape[1])) # db
289+
279290
@classmethod
280291
def help(cls): ## component help function
281292
properties = {

ngclearn/components/synapses/patched/patchedSynapse.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,15 @@ def advance_state(self):
153153
# Update compartment
154154
self.outputs.set(outputs)
155155

156+
@compilable
157+
def reset(self, batch_size, shape):
158+
preVals = jnp.zeros((batch_size, shape[0]))
159+
postVals = jnp.zeros((batch_size, shape[1]))
160+
inputs = preVals
161+
outputs = postVals
162+
self.inputs.set(inputs)
163+
self.outputs.set(outputs)
164+
156165
def save(self, directory, **kwargs):
157166
file_name = directory + "/" + self.name + ".npz"
158167
if self.bias_init != None:

0 commit comments

Comments
 (0)