File tree Expand file tree Collapse file tree 2 files changed +20
-0
lines changed
ngclearn/components/synapses/patched Expand file tree Collapse file tree 2 files changed +20
-0
lines changed Original file line number Diff line number Diff line change @@ -276,6 +276,17 @@ def evolve(self):
276276 self .dWeights .set (dWeights )
277277 self .dBiases .set (dBiases )
278278
279+ @compilable
280+ def reset (self , batch_size , shape ):
281+ preVals = jnp .zeros ((batch_size , shape [0 ]))
282+ postVals = jnp .zeros ((batch_size , shape [1 ]))
283+ self .inputs .set (preVals ) # inputs
284+ self .outputs .set (postVals ) # outputs
285+ self .pre .set (preVals ) # pre
286+ self .post .set (postVals ) # post
287+ self .dWeights .set (jnp .zeros (shape )) # dW
288+ self .dBiases .set (jnp .zeros (shape [1 ])) # db
289+
279290 @classmethod
280291 def help (cls ): ## component help function
281292 properties = {
Original file line number Diff line number Diff line change @@ -153,6 +153,15 @@ def advance_state(self):
153153 # Update compartment
154154 self .outputs .set (outputs )
155155
156+ @compilable
157+ def reset (self , batch_size , shape ):
158+ preVals = jnp .zeros ((batch_size , shape [0 ]))
159+ postVals = jnp .zeros ((batch_size , shape [1 ]))
160+ inputs = preVals
161+ outputs = postVals
162+ self .inputs .set (inputs )
163+ self .outputs .set (outputs )
164+
156165 def save (self , directory , ** kwargs ):
157166 file_name = directory + "/" + self .name + ".npz"
158167 if self .bias_init != None :
You can’t perform that action at this time.
0 commit comments