@@ -252,40 +252,39 @@ def advance_state(self, dt):
252252 self .zF .set (zF )
253253
254254 @compilable
255- def reset (self , batch_size , shape ): #n_units
256- _shape = (batch_size , shape [0 ])
257- if len (shape ) > 1 :
258- _shape = (batch_size , shape [0 ], shape [1 ], shape [2 ])
255+ def reset (self ): # , batch_size, shape): #n_units
256+ _shape = (self . batch_size , self . shape [0 ])
257+ if len (self . shape ) > 1 :
258+ _shape = (self . batch_size , self . shape [0 ], self . shape [1 ], self . shape [2 ])
259259 restVals = jnp .zeros (_shape )
260260 self .j .set (restVals )
261261 self .j_td .set (restVals )
262262 self .z .set (restVals )
263263 self .zF .set (restVals )
264264
265-
266- def save (self , directory , ** kwargs ):
267- ## do a protected save of constants, depending on whether they are floats or arrays
268- tau_m = (self .tau_m if isinstance (self .tau_m , float )
269- else jnp .ones ([[self .tau_m ]]))
270- priorLeakRate = (self .priorLeakRate if isinstance (self .priorLeakRate , float )
271- else jnp .ones ([[self .priorLeakRate ]]))
272- resist_scale = (self .resist_scale if isinstance (self .resist_scale , float )
273- else jnp .ones ([[self .resist_scale ]]))
274-
275- file_name = directory + "/" + self .name + ".npz"
276- jnp .savez (file_name ,
277- tau_m = tau_m , priorLeakRate = priorLeakRate ,
278- resist_scale = resist_scale ) #, key=self.key.value)
279-
280- def load (self , directory , seeded = False , ** kwargs ):
281- file_name = directory + "/" + self .name + ".npz"
282- data = jnp .load (file_name )
283- ## constants loaded in
284- self .tau_m = data ['tau_m' ]
285- self .priorLeakRate = data ['priorLeakRate' ]
286- self .resist_scale = data ['resist_scale' ]
287- #if seeded:
288- # self.key.set(data['key'])
265+ # def save(self, directory, **kwargs):
266+ # ## do a protected save of constants, depending on whether they are floats or arrays
267+ # tau_m = (self.tau_m if isinstance(self.tau_m, float)
268+ # else jnp.ones([[self.tau_m]]))
269+ # priorLeakRate = (self.priorLeakRate if isinstance(self.priorLeakRate, float)
270+ # else jnp.ones([[self.priorLeakRate]]))
271+ # resist_scale = (self.resist_scale if isinstance(self.resist_scale, float)
272+ # else jnp.ones([[self.resist_scale]]))
273+ #
274+ # file_name = directory + "/" + self.name + ".npz"
275+ # jnp.savez(file_name,
276+ # tau_m=tau_m, priorLeakRate=priorLeakRate,
277+ # resist_scale=resist_scale) #, key=self.key.value)
278+ #
279+ # def load(self, directory, seeded=False, **kwargs):
280+ # file_name = directory + "/" + self.name + ".npz"
281+ # data = jnp.load(file_name)
282+ # ## constants loaded in
283+ # self.tau_m = data['tau_m']
284+ # self.priorLeakRate = data['priorLeakRate']
285+ # self.resist_scale = data['resist_scale']
286+ # #if seeded:
287+ # # self.key.set(data['key'])
289288
290289 @classmethod
291290 def help (cls ): ## component help function
0 commit comments