22from functools import partial
33from ngclearn .utils .optim import get_opt_init_fn , get_opt_step_fn
44from ngclearn import resolver , Component , Compartment
5+ from ngcsimlib .compilers .process import transition
56from ngclearn .components .synapses import DenseSynapse
67from ngclearn .utils import tensorstats
78from ngcsimlib .deprecators import deprecate_args
@@ -216,8 +217,9 @@ def _compute_update(w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda
216217 post_wght = post_wght )
217218 return dW , db
218219
220+ @transition (output_compartments = ["opt_params" , "weights" , "biases" , "dWeights" , "dBiases" ])
219221 @staticmethod
220- def _evolve (opt , w_bound , is_nonnegative , sign_value , prior_type , prior_lmbda , pre_wght ,
222+ def evolve (opt , w_bound , is_nonnegative , sign_value , prior_type , prior_lmbda , pre_wght ,
221223 post_wght , bias_init , pre , post , weights , biases , opt_params ):
222224 ## calculate synaptic update values
223225 dWeights , dBiases = HebbianSynapse ._compute_update (
@@ -234,16 +236,9 @@ def _evolve(opt, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, p
234236 weights = _enforce_constraints (weights , w_bound , is_nonnegative = is_nonnegative )
235237 return opt_params , weights , biases , dWeights , dBiases
236238
237- @resolver (_evolve )
238- def evolve (self , opt_params , weights , biases , dWeights , dBiases ):
239- self .opt_params .set (opt_params )
240- self .weights .set (weights )
241- self .biases .set (biases )
242- self .dWeights .set (dWeights )
243- self .dBiases .set (dBiases )
244-
239+ @transition (output_compartments = ["inputs" , "outputs" , "pre" , "post" , "dWeights" , "dBiases" ])
245240 @staticmethod
246- def _reset (batch_size , shape ):
241+ def reset (batch_size , shape ):
247242 preVals = jnp .zeros ((batch_size , shape [0 ]))
248243 postVals = jnp .zeros ((batch_size , shape [1 ]))
249244 return (
@@ -255,15 +250,6 @@ def _reset(batch_size, shape):
255250 jnp .zeros (shape [1 ]), # db
256251 )
257252
258- @resolver (_reset )
259- def reset (self , inputs , outputs , pre , post , dWeights , dBiases ):
260- self .inputs .set (inputs )
261- self .outputs .set (outputs )
262- self .pre .set (pre )
263- self .post .set (post )
264- self .dWeights .set (dWeights )
265- self .dBiases .set (dBiases )
266-
267253 @classmethod
268254 def help (cls ): ## component help function
269255 properties = {
0 commit comments