44from ngclearn import resolver , Component , Compartment
55from ngclearn .components .synapses import DenseSynapse
66from ngclearn .utils import tensorstats
7+ from ngcsimlib .deprecators import deprecate_args
78
8- @partial (jit , static_argnums = [3 , 4 , 5 , 6 , 7 , 8 ])
9- def _calc_update (pre , post , W , w_bound , is_nonnegative = True , signVal = 1. , w_decay = 0. ,
9+ @partial (jit , static_argnums = [3 , 4 , 5 , 6 , 7 , 8 , 9 ])
10+ def _calc_update (pre , post , W , w_bound , is_nonnegative = True , signVal = 1. ,
11+ prior_type = None , prior_lmbda = 0. ,
1012 pre_wght = 1. , post_wght = 1. ):
1113 """
1214 Compute a tensor of adjustments to be applied to a synaptic value matrix.
@@ -25,7 +27,9 @@ def _calc_update(pre, post, W, w_bound, is_nonnegative=True, signVal=1., w_decay
2527 signVal: multiplicative factor to modulate final update by (good for
2628 flipping the signs of a computed synaptic change matrix)
2729
28- w_decay: synaptic decay factor to apply to this update
30+ prior_type: prior type or name (Default: None)
31+
32+ prior_lmbda: prior parameter (Default: 0.0)
2933
3034 pre_wght: pre-synaptic weighting term (Default: 1.)
3135
@@ -38,10 +42,21 @@ def _calc_update(pre, post, W, w_bound, is_nonnegative=True, signVal=1., w_decay
3842 _post = post * post_wght
3943 dW = jnp .matmul (_pre .T , _post )
4044 db = jnp .sum (_post , axis = 0 , keepdims = True )
45+ dW_reg = 0.
46+
4147 if w_bound > 0. :
4248 dW = dW * (w_bound - jnp .abs (W ))
43- if w_decay > 0. :
44- dW = dW - W * w_decay
49+
50+ if prior_type == "l2" or prior_type == "ridge" :
51+ dW_reg = W
52+ if prior_type == "l1" or prior_type == "lasso" :
53+ dW_reg = jnp .sign (W )
54+ if prior_type == "l1l2" or prior_type == "elastic_net" :
55+ l1_ratio = prior_lmbda [1 ]
56+ prior_lmbda = prior_lmbda [0 ]
57+ dW_reg = jnp .sign (W ) * l1_ratio + W * (1 - l1_ratio )/ 2
58+
59+ dW = dW + prior_lmbda * dW_reg
4560 return dW * signVal , db * signVal
4661
4762@partial (jit , static_argnums = [1 ,2 ])
@@ -68,6 +83,7 @@ def _enforce_constraints(W, w_bound, is_nonnegative=True):
6883 _W = jnp .clip (_W , - w_bound , w_bound )
6984 return _W
7085
86+
7187class HebbianSynapse (DenseSynapse ):
7288 """
7389 A synaptic cable that adjusts its efficacies via a two-factor Hebbian
@@ -107,9 +123,17 @@ class HebbianSynapse(DenseSynapse):
107123 is_nonnegative: enforce that synaptic efficacies are always non-negative
108124 after each synaptic update (if False, no constraint will be applied)
109125
110- w_decay: degree to which (L2) synaptic weight decay is applied to the
111- computed Hebbian adjustment (Default: 0); note that decay is not
112- applied to any configured biases
126+ prior: a kernel to drive prior of this synaptic cable's values;
127+ typically a tuple with 1st element as a string calling the name of
128+ prior to use and 2nd element as a floating point number
129+ calling the prior parameter lambda (Default: (None, 0.))
130+ currently it supports "l1" or "lasso" or "l2" or "ridge" or "l1l2" or "elastic_net".
131+ usage guide:
132+ prior = ('l1', 0.01) or prior = ('lasso', lmbda)
133+ prior = ('l2', 0.01) or prior = ('ridge', lmbda)
134+ prior = ('l1l2', (0.01, 0.01)) or prior = ('elastic_net', (lmbda, l1_ratio))
135+
136+
113137
114138 sign_value: multiplicative factor to apply to final synaptic update before
115139 it is applied to synapses; this is useful if gradient descent style
@@ -137,18 +161,24 @@ class HebbianSynapse(DenseSynapse):
137161 """
138162
139163 # Define Functions
164+ @deprecate_args (_rebind = False , w_decay = 'prior' )
140165 def __init__ (self , name , shape , eta = 0. , weight_init = None , bias_init = None ,
141- w_bound = 1. , is_nonnegative = False , w_decay = 0. , sign_value = 1. ,
166+ w_bound = 1. , is_nonnegative = False , prior = ( None , 0. ), w_decay = 0. , sign_value = 1. ,
142167 optim_type = "sgd" , pre_wght = 1. , post_wght = 1. , p_conn = 1. ,
143168 resist_scale = 1. , batch_size = 1 , ** kwargs ):
144169 super ().__init__ (name , shape , weight_init , bias_init , resist_scale ,
145170 p_conn , batch_size = batch_size , ** kwargs )
146171
172+ if w_decay > 0. :
173+ prior = ('l2' , w_decay )
174+
175+ prior_type , prior_lmbda = prior
147176 ## synaptic plasticity properties and characteristics
148177 self .shape = shape
149178 self .Rscale = resist_scale
179+ self .prior_type = prior_type
180+ self .prior_lmbda = prior_lmbda
150181 self .w_bound = w_bound
151- self .w_decay = w_decay ## synaptic decay
152182 self .pre_wght = pre_wght
153183 self .post_wght = post_wght
154184 self .eta = eta
@@ -172,21 +202,21 @@ def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None,
172202 if bias_init else [self .weights .value ]))
173203
174204 @staticmethod
175- def _compute_update (w_bound , is_nonnegative , sign_value , w_decay , pre_wght ,
205+ def _compute_update (w_bound , is_nonnegative , sign_value , prior_type , prior_lmbda , pre_wght ,
176206 post_wght , pre , post , weights ):
177207 ## calculate synaptic update values
178208 dW , db = _calc_update (
179209 pre , post , weights , w_bound , is_nonnegative = is_nonnegative ,
180- signVal = sign_value , w_decay = w_decay , pre_wght = pre_wght ,
210+ signVal = sign_value , prior_type = prior_type , prior_lmbda = prior_lmbda , pre_wght = pre_wght ,
181211 post_wght = post_wght )
182212 return dW , db
183213
184214 @staticmethod
185- def _evolve (opt , w_bound , is_nonnegative , sign_value , w_decay , pre_wght ,
215+ def _evolve (opt , w_bound , is_nonnegative , sign_value , prior_type , prior_lmbda , pre_wght ,
186216 post_wght , bias_init , pre , post , weights , biases , opt_params ):
187217 ## calculate synaptic update values
188218 dWeights , dBiases = HebbianSynapse ._compute_update (
189- w_bound , is_nonnegative , sign_value , w_decay , pre_wght , post_wght ,
219+ w_bound , is_nonnegative , sign_value , prior_type , prior_lmbda , pre_wght , post_wght ,
190220 pre , post , weights
191221 )
192222 ## conduct a step of optimization - get newly evolved synaptic weight value matrix
@@ -264,13 +294,13 @@ def help(cls): ## component help function
264294 "pre_wght" : "Pre-synaptic weighting coefficient (q_pre)" ,
265295 "post_wght" : "Post-synaptic weighting coefficient (q_post)" ,
266296 "w_bound" : "Soft synaptic bound applied to synapses post-update" ,
267- "w_decay " : "Synaptic decay term " ,
297+ "prior " : "prior name and value for synaptic updating prior " ,
268298 "optim_type" : "Choice of optimizer to adjust synaptic weights"
269299 }
270300 info = {cls .__name__ : properties ,
271301 "compartments" : compartment_props ,
272302 "dynamics" : "outputs = [(W * Rscale) * inputs] + b ;"
273- "dW_{ij}/dt = eta * [(z_j * q_pre) * (z_i * q_post)] - W_{ij} * w_decay " ,
303+ "dW_{ij}/dt = eta * [(z_j * q_pre) * (z_i * q_post)] - g( W_{ij}) * prior_lmbda " ,
274304 "hyperparameters" : hyperparams }
275305 return info
276306
@@ -292,5 +322,5 @@ def __repr__(self):
292322 from ngcsimlib .context import Context
293323 with Context ("Bar" ) as bar :
294324 Wab = HebbianSynapse ("Wab" , (2 , 3 ), 0.0004 , optim_type = 'adam' ,
295- sign_value = - 1.0 , bias_init = ("constant " , 0. , 0. ))
325+ sign_value = - 1.0 , prior = ("l1l2 " , 0.001 ))
296326 print (Wab )
0 commit comments