77from ngclearn .utils import tensorstats
88
99@partial (jit , static_argnums = [3 , 4 , 5 , 6 , 7 , 8 , 9 ])
10- def _calc_update (pre , post , W , w_mask , w_bound , is_nonnegative = True , signVal = 1. , w_decay = 0. ,
10+ def _calc_update (pre , post , W , w_mask , w_bound , is_nonnegative = True , signVal = 1. ,
11+ prior_type = None , prior_lmbda = 0. ,
1112 pre_wght = 1. , post_wght = 1. ):
1213 """
1314 Compute a tensor of adjustments to be applied to a synaptic value matrix.
@@ -19,14 +20,18 @@ def _calc_update(pre, post, W, w_mask, w_bound, is_nonnegative=True, signVal=1.,
1920
2021 W: synaptic weight values (at time t)
2122
22- w_mask: weight mask matrix
23+ w_mask: synaptic weight masking matrix (same shape as W)
2324
2425 w_bound: maximum value to enforce over newly computed efficacies
2526
27+ is_nonnegative: (Unused)
28+
2629 signVal: multiplicative factor to modulate final update by (good for
2730 flipping the signs of a computed synaptic change matrix)
2831
29- w_decay: synaptic decay factor to apply to this update
32+ prior_type: prior type or name (Default: None)
33+
34+ prior_lmbda: prior parameter (Default: 0.0)
3035
3136 pre_wght: pre-synaptic weighting term (Default: 1.)
3237
@@ -35,14 +40,28 @@ def _calc_update(pre, post, W, w_mask, w_bound, is_nonnegative=True, signVal=1.,
3540 Returns:
3641 an update/adjustment matrix, an update adjustment vector (for biases)
3742 """
43+
3844 _pre = pre * pre_wght
3945 _post = post * post_wght
4046 dW = jnp .matmul (_pre .T , _post )
4147 db = jnp .sum (_post , axis = 0 , keepdims = True )
48+ dW_reg = 0.
49+
4250 if w_bound > 0. :
4351 dW = dW * (w_bound - jnp .abs (W ))
44- if w_decay > 0. :
45- dW = dW - W * w_decay
52+
53+ if prior_type == "l2" or prior_type == "ridge" :
54+ dW_reg = W
55+
56+ if prior_type == "l1" or prior_type == "lasso" :
57+ dW_reg = jnp .sign (W )
58+
59+ if prior_type == "l1l2" or prior_type == "elastic_net" :
60+ l1_ratio = prior_lmbda [1 ]
61+ prior_lmbda = prior_lmbda [0 ]
62+ dW_reg = jnp .sign (W ) * l1_ratio + W * (1 - l1_ratio )/ 2
63+
64+ dW = dW + prior_lmbda * dW_reg
4665
4766 if w_mask != None :
4867 dW = dW * w_mask
@@ -79,6 +98,7 @@ def _enforce_constraints(W, w_mask, w_bound, is_nonnegative=True):
7998
8099 return _W
81100
101+
82102class HebbianPatchedSynapse (PatchedSynapse ):
83103 """
84104 A synaptic cable that adjusts its efficacies via a two-factor Hebbian
@@ -93,7 +113,7 @@ class HebbianPatchedSynapse(PatchedSynapse):
93113 | --- Synaptic Plasticity Compartments: ---
94114 | pre - pre-synaptic signal to drive first term of Hebbian update (takes in external signals)
95115 | post - post-synaptic signal to drive 2nd term of Hebbian update (takes in external signals)
96- | dWweights - current delta matrix containing changes to be applied to synaptic efficacies
116+ | dWeights - current delta matrix containing changes to be applied to synaptic efficacies
97117 | dBiases - current delta vector containing changes to be applied to bias values
98118 | opt_params - locally-embedded optimizer statistics (e.g., Adam 1st/2nd moments if adam is used)
99119
@@ -104,7 +124,7 @@ class HebbianPatchedSynapse(PatchedSynapse):
104124 with number of inputs by number of outputs)
105125
106126 n_sub_models: The number of submodels in each layer
107-
127+
108128 stride_shape: Stride shape of overlapping synaptic weight value matrix
109129 (Default: (0, 0))
110130
@@ -125,9 +145,17 @@ class HebbianPatchedSynapse(PatchedSynapse):
125145 is_nonnegative: enforce that synaptic efficacies are always non-negative
126146 after each synaptic update (if False, no constraint will be applied)
127147
128- w_decay: degree to which (L2) synaptic weight decay is applied to the
129- computed Hebbian adjustment (Default: 0); note that decay is not
130- applied to any configured biases
148+
149+        prior: a kernel that places a prior over this synaptic cable's
150+            values; typically a tuple whose 1st element is a string naming
151+            the prior to use and whose 2nd element is the floating point
152+            prior parameter lambda (Default: (None, 0.)).
153+            Currently supported: "l1"/"lasso", "l2"/"ridge", "l1l2"/"elastic_net".
154+            Usage guide:
155+            prior = ('l1', 0.01) or prior = ('lasso', lmbda)
156+            prior = ('l2', 0.01) or prior = ('ridge', lmbda)
157+            prior = ('l1l2', (0.01, 0.01)) or prior = ('elastic_net', (lmbda, l1_ratio))
158+
131159
132160 sign_value: multiplicative factor to apply to final synaptic update before
133161 it is applied to synapses; this is useful if gradient descent style
@@ -157,12 +185,16 @@ class HebbianPatchedSynapse(PatchedSynapse):
157185 """
158186
159187 def __init__ (self , name , shape , n_sub_models , stride_shape = (0 ,0 ), eta = 0. , weight_init = None , bias_init = None ,
160- w_mask = None , w_bound = 1. , is_nonnegative = False , w_decay = 0. , sign_value = 1. ,
188+ w_mask = None , w_bound = 1. , is_nonnegative = False , prior = ( None , 0. ) , sign_value = 1. ,
161189 optim_type = "sgd" , pre_wght = 1. , post_wght = 1. , p_conn = 1. ,
162190 resist_scale = 1. , batch_size = 1 , ** kwargs ):
163191 super ().__init__ (name , shape , n_sub_models , stride_shape , w_mask , weight_init , bias_init , resist_scale ,
164192 p_conn , batch_size = batch_size , ** kwargs )
165193
194+ prior_type , prior_lmbda = prior
195+ self .prior_type = prior_type
196+ self .prior_lmbda = prior_lmbda
197+
166198 self .n_sub_models = n_sub_models
167199 self .sub_stride = stride_shape
168200
@@ -174,7 +206,6 @@ def __init__(self, name, shape, n_sub_models, stride_shape=(0,0), eta=0., weight
174206 ## synaptic plasticity properties and characteristics
175207 self .Rscale = resist_scale
176208 self .w_bound = w_bound
177- self .w_decay = w_decay ## synaptic decay
178209 self .pre_wght = pre_wght
179210 self .post_wght = post_wght
180211 self .eta = eta
@@ -199,22 +230,22 @@ def __init__(self, name, shape, n_sub_models, stride_shape=(0,0), eta=0., weight
199230 if bias_init else [self .weights .value ]))
200231
201232 @staticmethod
202- def _compute_update (w_mask , w_bound , is_nonnegative , sign_value , w_decay , pre_wght ,
233+ def _compute_update (w_mask , w_bound , is_nonnegative , sign_value , prior_type , prior_lmbda , pre_wght ,
203234 post_wght , pre , post , weights ):
204235 ## calculate synaptic update values
205236 dW , db = _calc_update (
206237 pre , post , weights , w_mask , w_bound , is_nonnegative = is_nonnegative ,
207- signVal = sign_value , w_decay = w_decay , pre_wght = pre_wght ,
238+ signVal = sign_value , prior_type = prior_type , prior_lmbda = prior_lmbda , pre_wght = pre_wght ,
208239 post_wght = post_wght )
209240
210241 return dW * jnp .where (0 != jnp .abs (weights ), 1 , 0 ) , db
211242
212243 @staticmethod
213- def _evolve (w_mask , opt , w_bound , is_nonnegative , sign_value , w_decay , pre_wght ,
244+ def _evolve (w_mask , opt , w_bound , is_nonnegative , sign_value , prior_type , prior_lmbda , pre_wght ,
214245 post_wght , bias_init , pre , post , weights , biases , opt_params ):
215246 ## calculate synaptic update values
216247 dWeights , dBiases = HebbianPatchedSynapse ._compute_update (
217- w_mask , w_bound , is_nonnegative , sign_value , w_decay ,
248+ w_mask , w_bound , is_nonnegative , sign_value , prior_type , prior_lmbda ,
218249 pre_wght , post_wght , pre , post , weights
219250 )
220251 ## conduct a step of optimization - get newly evolved synaptic weight value matrix
@@ -299,14 +330,14 @@ def help(cls): ## component help function
299330 "pre_wght" : "Pre-synaptic weighting coefficient (q_pre)" ,
300331 "post_wght" : "Post-synaptic weighting coefficient (q_post)" ,
301332 "w_bound" : "Soft synaptic bound applied to synapses post-update" ,
333+ "prior" : "prior name and value for synaptic updating prior" ,
302334 "w_mask" : "weight mask matrix" ,
303- "w_decay" : "Synaptic decay term" ,
304335 "optim_type" : "Choice of optimizer to adjust synaptic weights"
305336 }
306337 info = {cls .__name__ : properties ,
307338 "compartments" : compartment_props ,
308339 "dynamics" : "outputs = [(W * Rscale) * inputs] + b ;"
309- "dW_{ij}/dt = eta * [(z_j * q_pre) * (z_i * q_post)] - W_{ij} * w_decay " ,
340+ "dW_{ij}/dt = eta * [(z_j * q_pre) * (z_i * q_post)] - g( W_{ij}) * prior_lmbda " ,
310341 "hyperparameters" : hyperparams }
311342 return info
312343
@@ -336,12 +367,9 @@ def __repr__(self):
if __name__ == '__main__':
    from ngcsimlib.context import Context
    with Context("Bar") as bar:
        ## NOTE: the "l1l2"/"elastic_net" prior expects a (lmbda, l1_ratio)
        ## tuple -- a bare float would raise a TypeError inside _calc_update
        ## where it is indexed as prior_lmbda[0] / prior_lmbda[1]
        Wab = HebbianPatchedSynapse("Wab", (9, 30), 3, (0, 0), optim_type='adam',
                                    sign_value=-1.0, prior=("l1l2", (0.001, 0.5)))
        print(Wab)
        plt.imshow(Wab.weights.value, cmap='gray')
        plt.show()
344-
345-
346-
347-
0 commit comments