88from ngcsimlib .compilers .process import transition
99
1010@partial (jit , static_argnums = [3 , 4 , 5 , 6 , 7 , 8 , 9 ])
11- def _calc_update (pre , post , W , w_mask , w_bound , is_nonnegative = True , signVal = 1. ,
11+ def _calc_update (pre , post , W , mask , w_bound , is_nonnegative = True , signVal = 1. ,
1212 prior_type = None , prior_lmbda = 0. ,
1313 pre_wght = 1. , post_wght = 1. ):
1414 """
@@ -21,7 +21,7 @@ def _calc_update(pre, post, W, w_mask, w_bound, is_nonnegative=True, signVal=1.,
2121
2222 W: synaptic weight values (at time t)
2323
24- w_mask : synaptic weight masking matrix (same shape as W)
24+ mask : synaptic weight masking matrix (same shape as W)
2525
2626 w_bound: maximum value to enforce over newly computed efficacies
2727
@@ -64,21 +64,21 @@ def _calc_update(pre, post, W, w_mask, w_bound, is_nonnegative=True, signVal=1.,
6464
6565 dW = dW + prior_lmbda * dW_reg
6666
67- if w_mask != None :
68- dW = dW * w_mask
67+ if mask != None :
68+ dW = dW * mask
6969
7070 return dW * signVal , db * signVal
7171
7272@partial (jit , static_argnums = [1 ,2 , 3 ])
73- def _enforce_constraints (W , w_mask , w_bound , is_nonnegative = True ):
73+ def _enforce_constraints (W , block_mask , w_bound , is_nonnegative = True ):
7474 """
7575 Enforces constraints that the (synaptic) efficacies/values within matrix
7676 `W` must adhere to.
7777
7878 Args:
7979 W: synaptic weight values (at time t)
8080
81- w_mask : weight mask matrix
81+ block_mask : weight mask matrix
8282
8383 w_bound: maximum value to enforce over newly computed efficacies
8484
@@ -94,8 +94,8 @@ def _enforce_constraints(W, w_mask, w_bound, is_nonnegative=True):
9494 else :
9595 _W = jnp .clip (_W , - w_bound , w_bound )
9696
97- if w_mask != None :
98- _W = _W * w_mask
97+ if block_mask != None :
98+ _W = _W * block_mask
9999
100100 return _W
101101
@@ -138,7 +138,7 @@ class HebbianPatchedSynapse(PatchedSynapse):
138138 bias_init: a kernel to drive initialization of biases for this synaptic cable
139139 (Default: None, which turns off/disables biases)
140140
141- w_mask : weight mask matrix
141+ block_mask : weight mask matrix
142142
143143 w_bound: maximum weight to softly bound this cable's value matrix to; if
144144 set to 0, then no synaptic value bounding will be applied
@@ -186,10 +186,10 @@ class HebbianPatchedSynapse(PatchedSynapse):
186186 """
187187
188188 def __init__ (self , name , shape , n_sub_models = 1 , stride_shape = (0 ,0 ), eta = 0. , weight_init = None , bias_init = None ,
189- w_mask = None , w_bound = 1. , is_nonnegative = False , prior = (None , 0. ), sign_value = 1. ,
189+ block_mask = None , w_bound = 1. , is_nonnegative = False , prior = (None , 0. ), sign_value = 1. ,
190190 optim_type = "sgd" , pre_wght = 1. , post_wght = 1. , p_conn = 1. ,
191191 resist_scale = 1. , batch_size = 1 , ** kwargs ):
192- super ().__init__ (name , shape , n_sub_models , stride_shape , w_mask , weight_init , bias_init , resist_scale ,
192+ super ().__init__ (name , shape , n_sub_models , stride_shape , block_mask , weight_init , bias_init , resist_scale ,
193193 p_conn , batch_size = batch_size , ** kwargs )
194194
195195 prior_type , prior_lmbda = prior
@@ -221,7 +221,7 @@ def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), eta=0., weig
221221 self .postVals = jnp .zeros ((self .batch_size , self .shape [1 ]))
222222 self .pre = Compartment (self .preVals )
223223 self .post = Compartment (self .postVals )
224- self .w_mask = w_mask
224+ self .block_mask = block_mask
225225 self .dWeights = Compartment (jnp .zeros (self .shape ))
226226 self .dBiases = Compartment (jnp .zeros (self .shape [1 ]))
227227
@@ -231,23 +231,23 @@ def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), eta=0., weig
231231 if bias_init else [self .weights .value ]))
232232
233233 @staticmethod
234- def _compute_update (w_mask , w_bound , is_nonnegative , sign_value , prior_type , prior_lmbda , pre_wght ,
234+ def _compute_update (block_mask , w_bound , is_nonnegative , sign_value , prior_type , prior_lmbda , pre_wght ,
235235 post_wght , pre , post , weights ):
236236 ## calculate synaptic update values
237237 dW , db = _calc_update (
238- pre , post , weights , w_mask , w_bound , is_nonnegative = is_nonnegative ,
238+ pre , post , weights , block_mask , w_bound , is_nonnegative = is_nonnegative ,
239239 signVal = sign_value , prior_type = prior_type , prior_lmbda = prior_lmbda , pre_wght = pre_wght ,
240240 post_wght = post_wght )
241241
242242 return dW * jnp .where (0 != jnp .abs (weights ), 1 , 0 ) , db
243243
244244 @transition (output_compartments = ["opt_params" , "weights" , "biases" , "dWeights" , "dBiases" ])
245245 @staticmethod
246- def evolve (w_mask , opt , w_bound , is_nonnegative , sign_value , prior_type , prior_lmbda , pre_wght ,
246+ def evolve (block_mask , opt , w_bound , is_nonnegative , sign_value , prior_type , prior_lmbda , pre_wght ,
247247 post_wght , bias_init , pre , post , weights , biases , opt_params ):
248248 ## calculate synaptic update values
249249 dWeights , dBiases = HebbianPatchedSynapse ._compute_update (
250- w_mask , w_bound , is_nonnegative , sign_value , prior_type , prior_lmbda ,
250+ block_mask , w_bound , is_nonnegative , sign_value , prior_type , prior_lmbda ,
251251 pre_wght , post_wght , pre , post , weights
252252 )
253253 ## conduct a step of optimization - get newly evolved synaptic weight value matrix
@@ -257,7 +257,7 @@ def evolve(w_mask, opt, w_bound, is_nonnegative, sign_value, prior_type, prior_l
257257 # ignore db since no biases configured
258258 opt_params , [weights ] = opt (opt_params , [weights ], [dWeights ])
259259 ## ensure synaptic efficacies adhere to constraints
260- weights = _enforce_constraints (weights , w_mask , w_bound , is_nonnegative = is_nonnegative )
260+ weights = _enforce_constraints (weights , block_mask , w_bound , is_nonnegative = is_nonnegative )
261261 return opt_params , weights , biases , dWeights , dBiases
262262
263263 @transition (output_compartments = ["inputs" , "outputs" , "pre" , "post" , "dWeights" , "dBiases" ])
@@ -313,7 +313,7 @@ def help(cls): ## component help function
313313 "post_wght" : "Post-synaptic weighting coefficient (q_post)" ,
314314 "w_bound" : "Soft synaptic bound applied to synapses post-update" ,
315315 "prior" : "prior name and value for synaptic updating prior" ,
316- "w_mask " : "weight mask matrix" ,
316+ "block_mask " : "weight mask matrix" ,
317317 "optim_type" : "Choice of optimizer to adjust synaptic weights"
318318 }
319319 info = {cls .__name__ : properties ,
0 commit comments