 from jax import random, numpy as jnp, jit
-from ngclearn import resolver, Component, Compartment
+from ngcsimlib.compilers.process import transition
+from ngcsimlib.component import Component
+from ngcsimlib.compartment import Compartment
+
 from .deconvSynapse import DeconvSynapse
+from ngclearn.utils.weight_distribution import initialize_params
+from ngcsimlib.logger import info
+from ngclearn.utils import tensorstats
+import ngclearn.utils.weight_distribution as dist
 from ngclearn.components.synapses.convolution.ngcconv import (deconv2d, _calc_dX_deconv,
                                                               _calc_dK_deconv, calc_dX_deconv,
                                                               calc_dK_deconv)
@@ -79,13 +86,15 @@ class HebbianDeconvSynapse(DeconvSynapse): ## Hebbian-evolved deconvolutional ca
     """
 
     # Define Functions
-    def __init__(self, name, shape, x_shape, eta=0., filter_init=None, bias_init=None,
-                 stride=1, padding=None, resist_scale=1., w_bound=0., is_nonnegative=False,
-                 w_decay=0., sign_value=1., optim_type="sgd", batch_size=1, **kwargs):
-        super().__init__(name, shape, x_shape=x_shape, filter_init=filter_init,
-                         bias_init=bias_init, resist_scale=resist_scale,
-                         stride=stride, padding=padding, batch_size=batch_size,
-                         **kwargs)
+    def __init__(
+            self, name, shape, x_shape, eta=0., filter_init=None, bias_init=None, stride=1, padding=None,
+            resist_scale=1., w_bound=0., is_nonnegative=False, w_decay=0., sign_value=1., optim_type="sgd",
+            batch_size=1, **kwargs
+    ):
+        super().__init__(
+            name, shape, x_shape=x_shape, filter_init=filter_init, bias_init=bias_init, resist_scale=resist_scale,
+            stride=stride, padding=padding, batch_size=batch_size, **kwargs
+        )
 
         self.eta = eta
         self.w_bounds = w_bound
@@ -112,8 +121,7 @@ def __init__(self, name, shape, x_shape, eta=0., filter_init=None, bias_init=Non
             [self.weights.value, self.biases.value]
             if bias_init else [self.weights.value]))
 
-    def _init(self, batch_size, x_size, shape, stride, padding, pad_args,
-              weights):
+    def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights):
         k_size, k_size, n_in_chan, n_out_chan = shape
         _x = jnp.zeros((batch_size, x_size, x_size, n_in_chan))
         _d = deconv2d(_x, self.weights.value, stride_size=self.stride,
@@ -132,8 +140,7 @@ def _init(self, batch_size, x_size, shape, stride, padding, pad_args,
         self.x_delta_shape = (dx, dy)
 
     @staticmethod
-    def _compute_update(sign_value, w_decay, bias_init, shape, stride, padding,
-                        delta_shape, pre, post, weights):
+    def _compute_update(sign_value, w_decay, bias_init, shape, stride, padding, delta_shape, pre, post, weights):
         k_size, k_size, n_in_chan, n_out_chan = shape
         ## compute adjustment to filters
         dWeights = calc_dK_deconv(pre, post, delta_shape=delta_shape,
@@ -148,10 +155,12 @@ def _compute_update(sign_value, w_decay, bias_init, shape, stride, padding,
             dBiases = jnp.sum(post, axis=0, keepdims=True) * sign_value
         return dWeights, dBiases
 
+    @transition(output_compartments=["opt_params", "weights", "biases", "dWeights", "dBiases"])
     @staticmethod
-    def _evolve(opt, sign_value, w_decay, w_bounds, is_nonnegative, bias_init,
-                shape, stride, padding, delta_shape, pre, post, weights, biases,
-                opt_params):
+    def evolve(
+            opt, sign_value, w_decay, w_bounds, is_nonnegative, bias_init, shape, stride, padding, delta_shape,
+            pre, post, weights, biases, opt_params
+    ):
         dWeights, dBiases = HebbianDeconvSynapse._compute_update(
             sign_value, w_decay, bias_init, shape, stride, padding, delta_shape,
             pre, post, weights
@@ -169,30 +178,19 @@ def _evolve(opt, sign_value, w_decay, w_bounds, is_nonnegative, bias_init,
                 weights = jnp.clip(weights, -w_bounds, w_bounds)
         return opt_params, weights, biases, dWeights, dBiases
 
-    @resolver(_evolve)
-    def evolve(self, opt_params, weights, biases, dWeights, dBiases):
-        self.opt_params.set(opt_params)
-        self.weights.set(weights)
-        self.biases.set(biases)
-        self.dWeights.set(dWeights)
-        self.dBiases.set(dBiases)
-
+    @transition(output_compartments=["dInputs"])
     @staticmethod
-    def _backtransmit(sign_value, stride, padding, x_delta_shape, pre, post,
-                      weights): ## action-backpropagating routine
+    def backtransmit(sign_value, stride, padding, x_delta_shape, pre, post, weights): ## action-backpropagating routine
         ## calc dInputs
         dInputs = calc_dX_deconv(weights, post, delta_shape=x_delta_shape,
                                  stride_size=stride, padding=padding)
         ## flip sign of back-transmitted signal (if applicable)
         dInputs = dInputs * sign_value
         return dInputs
 
-    @resolver(_backtransmit)
-    def backtransmit(self, dInputs):
-        self.dInputs.set(dInputs)
-
+    @transition(output_compartments=["inputs", "outputs", "pre", "post", "dInputs"])
     @staticmethod
-    def _reset(in_shape, out_shape):
+    def reset(in_shape, out_shape):
         preVals = jnp.zeros(in_shape)
         postVals = jnp.zeros(out_shape)
         inputs = preVals
@@ -202,14 +200,6 @@ def _reset(in_shape, out_shape):
         dInputs = preVals
         return inputs, outputs, pre, post, dInputs
 
-    @resolver(_reset)
-    def reset(self, inputs, outputs, pre, post, dInputs):
-        self.inputs.set(inputs)
-        self.outputs.set(outputs)
-        self.pre.set(pre)
-        self.post.set(post)
-        self.dInputs.set(dInputs)
-
     @classmethod
     def help(cls): ## component help function
         properties = {
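
For orientation, a minimal plain-JAX sketch of the Hebbian update that the evolve/_compute_update methods in this diff perform; a dense batch outer product stands in for calc_dK_deconv, and the helper name hebbian_update plus the dense form are hypothetical, not part of ngclearn:

import jax.numpy as jnp

def hebbian_update(pre, post, weights, biases, w_bound=1., sign_value=1.):
    # dense stand-in for calc_dK_deconv: batch outer product of pre/post activities
    dWeights = jnp.einsum("bi,bj->ij", pre, post) * sign_value
    # bias adjustment mirrors _compute_update: sum post-synaptic activity over the batch
    dBiases = jnp.sum(post, axis=0, keepdims=True) * sign_value
    # apply the update directly here; the component instead routes it through an optimizer
    weights = weights + dWeights
    biases = biases + dBiases
    if w_bound > 0.:  # hard magnitude bound, as in the non-negative=False branch of evolve
        weights = jnp.clip(weights, -w_bound, w_bound)
    return weights, biases, dWeights, dBiases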