 from jax import random, numpy as jnp, jit
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
 from ngcsimlib.compartment import Compartment
-
-from .convSynapse import ConvSynapse
+from ngcsimlib.parser import compilable
 from ngclearn.utils.weight_distribution import initialize_params
-from ngcsimlib.logger import info
-from ngclearn.utils import tensorstats
 import ngclearn.utils.weight_distribution as dist
+
+from ngclearn.components.synapses.convolution.convSynapse import ConvSynapse
+
 from ngclearn.components.synapses.convolution.ngcconv import (_conv_same_transpose_padding,
                                                               _conv_valid_transpose_padding)
 from ngclearn.components.synapses.convolution.ngcconv import (conv2d, _calc_dX_conv,
                                                               _calc_dK_conv, calc_dX_conv,
                                                               calc_dK_conv)
 from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn
 
 class HebbianConvSynapse(ConvSynapse): ## Hebbian-evolved convolutional cable
     """
-    A synaptic convolutional cable that adjusts its efficacies via a two-factor
-    Hebbian adjustment rule.
+    A specialized synaptic convolutional cable that adjusts its efficacies via a two-factor Hebbian adjustment rule.
 
     | --- Synapse Compartments: ---
     | inputs - input (takes in external signals)
@@ -88,10 +85,11 @@ class HebbianConvSynapse(ConvSynapse): ## Hebbian-evolved convolutional cable
     """
 
     # Define Functions
-    def __init__(self, name, shape, x_shape, eta=0., filter_init=None, bias_init=None,
-                 stride=1, padding=None, resist_scale=1., w_bound=0.,
-                 is_nonnegative=False, w_decay=0., sign_value=1., optim_type="sgd",
-                 batch_size=1, **kwargs):
+    def __init__(
+            self, name, shape, x_shape, eta=0., filter_init=None, bias_init=None, stride=1, padding=None,
+            resist_scale=1., w_bound=0., is_nonnegative=False, w_decay=0., sign_value=1., optim_type="sgd",
+            batch_size=1, **kwargs
+    ):
         super().__init__(
             name, shape, x_shape=x_shape, filter_init=filter_init, bias_init=bias_init, resist_scale=resist_scale,
             stride=stride, padding=padding, batch_size=batch_size, **kwargs
@@ -107,9 +105,9 @@ def __init__(self, name, shape, x_shape, eta=0., filter_init=None, bias_init=Non
 
         ######################### set up compartments ##########################
         ## Compartment setup and shape computation
-        self.dWeights = Compartment(self.weights.value * 0)
+        self.dWeights = Compartment(self.weights.get() * 0)
         self.dInputs = Compartment(jnp.zeros(self.in_shape))
-        self.dBiases = Compartment(self.biases.value * 0)
+        self.dBiases = Compartment(self.biases.get() * 0)
         self.pre = Compartment(jnp.zeros(self.in_shape))
         self.post = Compartment(jnp.zeros(self.out_shape))
 
@@ -120,103 +118,97 @@ def __init__(self, name, shape, x_shape, eta=0., filter_init=None, bias_init=Non
         self.antiPad = None
         k_size, k_size, n_in_chan, n_out_chan = self.shape
         if padding == "SAME":
-            self.antiPad = _conv_same_transpose_padding(self.post.value.shape[1],
+            self.antiPad = _conv_same_transpose_padding(self.post.get().shape[1],
                                                         self.x_size, k_size, stride)
         elif padding == "VALID":
-            self.antiPad = _conv_valid_transpose_padding(self.post.value.shape[1],
+            self.antiPad = _conv_valid_transpose_padding(self.post.get().shape[1],
                                                          self.x_size, k_size, stride)
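+        ## editorial note: antiPad provides the transpose-convolution padding
+        ## needed so that signals back-transmitted through the filters recover
+        ## the input's spatial size (consumed by backtransmit() below)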
 
         ########################################################################
 
         ## set up outer optimization compartments
         self.opt_params = Compartment(get_opt_init_fn(optim_type)(
-            [self.weights.value, self.biases.value]
-            if bias_init else [self.weights.value]))
+            [self.weights.get(), self.biases.get()]
+            if bias_init else [self.weights.get()])
+        )
 
     def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights):
         k_size, k_size, n_in_chan, n_out_chan = shape
         _x = jnp.zeros((batch_size, x_size, x_size, n_in_chan))
-        _d = conv2d(_x, weights.value, stride_size=stride, padding=padding) * 0
+        _d = conv2d(_x, weights.get(), stride_size=stride, padding=padding) * 0
         _dK = _calc_dK_conv(_x, _d, stride_size=stride, padding=pad_args)
         ## get filter update correction
-        dx = _dK.shape[0] - weights.value.shape[0]
-        dy = _dK.shape[1] - weights.value.shape[1]
+        dx = _dK.shape[0] - weights.get().shape[0]
+        dy = _dK.shape[1] - weights.get().shape[1]
         self.delta_shape = (max(dx, 0), max(dy, 0))
         ## get input update correction
-        _dx = _calc_dX_conv(weights.value, _d, stride_size=stride,
-                            anti_padding=pad_args)
+        _dx = _calc_dX_conv(weights.get(), _d, stride_size=stride, anti_padding=pad_args)
         dx = (_dx.shape[1] - _x.shape[1])
         dy = (_dx.shape[2] - _x.shape[2])
         self.x_delta_shape = (dx, dy)
 
-    @staticmethod
-    def _compute_update(
-            sign_value, w_decay, bias_init, stride, pad_args, delta_shape, pre, post, weights
-    ): ## synaptic kernel adjustment calculation co-routine
+    def _compute_update(self):
+        ## synaptic kernel adjustment calculation co-routine
         ## compute adjustment to filters
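+        ## editorial note: this is the two-factor Hebbian rule, roughly
+        ## dK ~ correlation of pre-synaptic activity (input patches) with
+        ## post-synaptic activity, computed as a convolution so that dK
+        ## matches the filter tensor's shape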
-        dWeights = calc_dK_conv(pre, post, delta_shape=delta_shape, stride_size=stride, padding=pad_args)
-        dWeights = dWeights * sign_value
-        if w_decay > 0.: ## apply synaptic decay
-            dWeights = dWeights - weights * w_decay
+        dWeights = calc_dK_conv(
+            self.pre.get(), self.post.get(), delta_shape=self.delta_shape,
+            stride_size=self.stride, padding=self.pad_args
+        )
+        dWeights = dWeights * self.sign_value
+        if self.w_decay > 0.: ## apply synaptic decay
+            dWeights = dWeights - self.weights.get() * self.w_decay
         ## compute adjustment to base-rates (if applicable)
         dBiases = 0. # jnp.zeros((1,1))
-        if bias_init != None:
-            dBiases = jnp.sum(post, axis=0, keepdims=True) * sign_value
+        if self.bias_init is not None:
+            dBiases = jnp.sum(self.post.get(), axis=0, keepdims=True) * self.sign_value
         return dWeights, dBiases
 
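+    ## editorial note: under the compilable API, the routines below read state
+    ## through Compartment.get() and write results back with Compartment.set(),
+    ## rather than returning values as the old @transition style required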
-    @transition(output_compartments=["opt_params", "weights", "biases", "dWeights", "dBiases"])
-    @staticmethod
-    def evolve(
-            opt, sign_value, w_decay, w_bounds, is_nonnegative, bias_init, stride, pad_args, delta_shape, pre, post,
-            weights, biases, opt_params
-    ):
+    @compilable
+    def evolve(self):
         ## calc dFilters / dBiases - update to filters and biases
-        dWeights, dBiases = HebbianConvSynapse._compute_update(
-            sign_value, w_decay, bias_init, stride, pad_args, delta_shape, pre, post, weights
-        )
-        if bias_init != None:
-            opt_params, [weights, biases] = opt(opt_params, [weights, biases], [dWeights, dBiases])
+        dWeights, dBiases = self._compute_update()
+        if self.bias_init is not None:
+            opt_params, [weights, biases] = self.opt(
+                self.opt_params.get(), [self.weights.get(), self.biases.get()], [dWeights, dBiases]
+            )
         else: ## ignore dBiases since no biases configured
-            opt_params, [weights] = opt(opt_params, [weights], [dWeights])
-
+            opt_params, [weights] = self.opt(self.opt_params.get(), [self.weights.get()], [dWeights])
+            biases = self.biases.get() ## keep the existing biases value rather than clobbering it with None
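+        ## editorial note: self.opt is assumed to be the optimizer step function
+        ## (presumably built via get_opt_step_fn(optim_type) in __init__); it maps
+        ## (opt_state, param_list, grad_list) -> (new_opt_state, new_param_list)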
         ## apply any enforced filter constraints
-        if w_bounds > 0.:
-            if is_nonnegative:
-                weights = jnp.clip(weights, 0., w_bounds)
+        if self.w_bounds > 0.:
+            if self.is_nonnegative:
+                weights = jnp.clip(weights, 0., self.w_bounds)
             else:
-                weights = jnp.clip(weights, -w_bounds, w_bounds)
-        return opt_params, weights, biases, dWeights, dBiases
-
-    @transition(output_compartments=["dInputs"])
-    @staticmethod
-    def backtransmit(
-            sign_value, x_size, shape, stride, padding, x_delta_shape, antiPad, post, weights
-    ): ## action-backpropagating routine
+                weights = jnp.clip(weights, -self.w_bounds, self.w_bounds)
+
+        self.opt_params.set(opt_params)
+        self.weights.set(weights)
+        self.biases.set(biases)
+        self.dWeights.set(dWeights)
+        self.dBiases.set(dBiases)
+
+    @compilable
+    def backtransmit(self): ## action-backpropagating co-routine
         ## calc dInputs - adjustment w.r.t. input signal
-        k_size, k_size, n_in_chan, n_out_chan = shape
+        k_size, k_size, n_in_chan, n_out_chan = self.shape
         # antiPad = None
         # if padding == "SAME":
         #     antiPad = _conv_same_transpose_padding(post.shape[1], x_size,
         #                                            k_size, stride)
         # elif padding == "VALID":
         #     antiPad = _conv_valid_transpose_padding(post.shape[1], x_size,
         #                                             k_size, stride)
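+        ## editorial note: the commented-out block above is superseded by the
+        ## antiPad value precomputed once in __init__; it is retained for reference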
-        dInputs = calc_dX_conv(weights, post, delta_shape=x_delta_shape, stride_size=stride, anti_padding=antiPad)
+        dInputs = calc_dX_conv(
+            self.weights.get(), self.post.get(), delta_shape=self.x_delta_shape,
+            stride_size=self.stride, anti_padding=self.antiPad
+        )
         ## flip sign of back-transmitted signal (if applicable)
-        dInputs = dInputs * sign_value
-        return dInputs
-
-    @transition(output_compartments=["inputs", "outputs", "pre", "post", "dInputs"])
-    @staticmethod
-    def reset(in_shape, out_shape):
-        preVals = jnp.zeros(in_shape)
-        postVals = jnp.zeros(out_shape)
-        inputs = preVals
-        outputs = postVals
-        pre = preVals
-        post = postVals
-        dInputs = preVals
-        return inputs, outputs, pre, post, dInputs
+        dInputs = dInputs * self.sign_value
+        self.dInputs.set(dInputs)
+
+    @compilable
+    def reset(self):
+        preVals = jnp.zeros(self.in_shape)
+        postVals = jnp.zeros(self.out_shape)
+        self.inputs.set(preVals)
+        self.outputs.set(postVals)
+        self.pre.set(preVals)
+        self.post.set(postVals)
+        self.dInputs.set(preVals)
 
     @classmethod
     def help(cls): ## component help function
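
For orientation, a minimal driver sketch of the migrated component follows. This is a hypothetical illustration, not part of the diff: it assumes the standard ngclearn Context setup, that @compilable methods remain directly callable in eager mode, that the inherited forward routine is named advance_state, and illustrative values throughout (the name "W1", a 3x3 filter bank, 28x28 single-channel inputs):

    from jax import numpy as jnp
    from ngcsimlib.context import Context
    from ngclearn.components.synapses.convolution.hebbianConvSynapse import HebbianConvSynapse

    with Context("demo") as ctx:
        ## 3x3 filters, 1 input channel -> 8 output channels, over 28x28 inputs
        syn = HebbianConvSynapse("W1", shape=(3, 3, 1, 8), x_shape=28, eta=1e-3,
                                 stride=1, padding="SAME", batch_size=1)

    syn.reset()                               ## zero out stateful compartments
    syn.inputs.set(jnp.ones((1, 28, 28, 1)))  ## clamp an input signal
    syn.advance_state()                       ## forward pass (inherited from ConvSynapse)
    syn.pre.set(syn.inputs.get())             ## first Hebbian factor
    syn.post.set(syn.outputs.get())           ## second Hebbian factor
    syn.evolve()                              ## apply the two-factor update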