Skip to content

Commit eeb057a

Browse files
authored
Add l1 decay term to update calculation (#84)
* Update hebbianSynapse.py
* Update main: update main at the end
* Update hebbianSynapse.py: add regularization argument; w_decay is deprecated.
* Update hebbianSynapse.py: add elastic_net
* Update hebbianSynapse.py
* Update hebbianSynapse.py
1 parent 2295ba5 commit eeb057a

File tree

1 file changed

+47
-17
lines changed

1 file changed

+47
-17
lines changed

ngclearn/components/synapses/hebbian/hebbianSynapse.py

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
from ngclearn import resolver, Component, Compartment
55
from ngclearn.components.synapses import DenseSynapse
66
from ngclearn.utils import tensorstats
7+
from ngcsimlib.deprecators import deprecate_args
78

8-
@partial(jit, static_argnums=[3, 4, 5, 6, 7, 8])
9-
def _calc_update(pre, post, W, w_bound, is_nonnegative=True, signVal=1., w_decay=0.,
9+
@partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9])
10+
def _calc_update(pre, post, W, w_bound, is_nonnegative=True, signVal=1.,
11+
prior_type=None, prior_lmbda=0.,
1012
pre_wght=1., post_wght=1.):
1113
"""
1214
Compute a tensor of adjustments to be applied to a synaptic value matrix.
@@ -25,7 +27,9 @@ def _calc_update(pre, post, W, w_bound, is_nonnegative=True, signVal=1., w_decay
2527
signVal: multiplicative factor to modulate final update by (good for
2628
flipping the signs of a computed synaptic change matrix)
2729
28-
w_decay: synaptic decay factor to apply to this update
30+
prior_type: prior type or name (Default: None)
31+
32+
prior_lmbda: prior parameter (Default: 0.0)
2933
3034
pre_wght: pre-synaptic weighting term (Default: 1.)
3135
@@ -38,10 +42,21 @@ def _calc_update(pre, post, W, w_bound, is_nonnegative=True, signVal=1., w_decay
3842
_post = post * post_wght
3943
dW = jnp.matmul(_pre.T, _post)
4044
db = jnp.sum(_post, axis=0, keepdims=True)
45+
dW_reg = 0.
46+
4147
if w_bound > 0.:
4248
dW = dW * (w_bound - jnp.abs(W))
43-
if w_decay > 0.:
44-
dW = dW - W * w_decay
49+
50+
if prior_type == "l2" or prior_type == "ridge":
51+
dW_reg = W
52+
if prior_type == "l1" or prior_type == "lasso":
53+
dW_reg = jnp.sign(W)
54+
if prior_type == "l1l2" or prior_type == "elastic_net":
55+
l1_ratio = prior_lmbda[1]
56+
prior_lmbda = prior_lmbda[0]
57+
dW_reg = jnp.sign(W) * l1_ratio + W * (1-l1_ratio)/2
58+
59+
dW = dW + prior_lmbda * dW_reg
4560
return dW * signVal, db * signVal
4661

4762
@partial(jit, static_argnums=[1,2])
@@ -68,6 +83,7 @@ def _enforce_constraints(W, w_bound, is_nonnegative=True):
6883
_W = jnp.clip(_W, -w_bound, w_bound)
6984
return _W
7085

86+
7187
class HebbianSynapse(DenseSynapse):
7288
"""
7389
A synaptic cable that adjusts its efficacies via a two-factor Hebbian
@@ -107,9 +123,17 @@ class HebbianSynapse(DenseSynapse):
107123
is_nonnegative: enforce that synaptic efficacies are always non-negative
108124
after each synaptic update (if False, no constraint will be applied)
109125
110-
w_decay: degree to which (L2) synaptic weight decay is applied to the
111-
computed Hebbian adjustment (Default: 0); note that decay is not
112-
applied to any configured biases
126+
prior: a prior (regularization term) imposed on this synaptic cable's values;
127+
given as a tuple whose 1st element is a string naming the prior to use
128+
and whose 2nd element is the prior parameter lambda -- a floating point
129+
number, or a (lambda, l1_ratio) tuple for elastic net (Default: (None, 0.)).
130+
Currently supported names: "l1" or "lasso", "l2" or "ridge", and "l1l2" or "elastic_net".
131+
usage guide:
132+
prior = ('l1', 0.01) or prior = ('lasso', lmbda)
133+
prior = ('l2', 0.01) or prior = ('ridge', lmbda)
134+
prior = ('l1l2', (0.01, 0.01)) or prior = ('elastic_net', (lmbda, l1_ratio))
135+
136+
113137
114138
sign_value: multiplicative factor to apply to final synaptic update before
115139
it is applied to synapses; this is useful if gradient descent style
@@ -137,18 +161,24 @@ class HebbianSynapse(DenseSynapse):
137161
"""
138162

139163
# Define Functions
164+
@deprecate_args(_rebind=False, w_decay='prior')
140165
def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None,
141-
w_bound=1., is_nonnegative=False, w_decay=0., sign_value=1.,
166+
w_bound=1., is_nonnegative=False, prior=(None, 0.), w_decay=0., sign_value=1.,
142167
optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1.,
143168
resist_scale=1., batch_size=1, **kwargs):
144169
super().__init__(name, shape, weight_init, bias_init, resist_scale,
145170
p_conn, batch_size=batch_size, **kwargs)
146171

172+
if w_decay > 0.:
173+
prior = ('l2', w_decay)
174+
175+
prior_type, prior_lmbda = prior
147176
## synaptic plasticity properties and characteristics
148177
self.shape = shape
149178
self.Rscale = resist_scale
179+
self.prior_type = prior_type
180+
self.prior_lmbda = prior_lmbda
150181
self.w_bound = w_bound
151-
self.w_decay = w_decay ## synaptic decay
152182
self.pre_wght = pre_wght
153183
self.post_wght = post_wght
154184
self.eta = eta
@@ -172,21 +202,21 @@ def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None,
172202
if bias_init else [self.weights.value]))
173203

174204
@staticmethod
175-
def _compute_update(w_bound, is_nonnegative, sign_value, w_decay, pre_wght,
205+
def _compute_update(w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
176206
post_wght, pre, post, weights):
177207
## calculate synaptic update values
178208
dW, db = _calc_update(
179209
pre, post, weights, w_bound, is_nonnegative=is_nonnegative,
180-
signVal=sign_value, w_decay=w_decay, pre_wght=pre_wght,
210+
signVal=sign_value, prior_type=prior_type, prior_lmbda=prior_lmbda, pre_wght=pre_wght,
181211
post_wght=post_wght)
182212
return dW, db
183213

184214
@staticmethod
185-
def _evolve(opt, w_bound, is_nonnegative, sign_value, w_decay, pre_wght,
215+
def _evolve(opt, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
186216
post_wght, bias_init, pre, post, weights, biases, opt_params):
187217
## calculate synaptic update values
188218
dWeights, dBiases = HebbianSynapse._compute_update(
189-
w_bound, is_nonnegative, sign_value, w_decay, pre_wght, post_wght,
219+
w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, post_wght,
190220
pre, post, weights
191221
)
192222
## conduct a step of optimization - get newly evolved synaptic weight value matrix
@@ -264,13 +294,13 @@ def help(cls): ## component help function
264294
"pre_wght": "Pre-synaptic weighting coefficient (q_pre)",
265295
"post_wght": "Post-synaptic weighting coefficient (q_post)",
266296
"w_bound": "Soft synaptic bound applied to synapses post-update",
267-
"w_decay": "Synaptic decay term",
297+
"prior": "prior name and value for synaptic updating prior",
268298
"optim_type": "Choice of optimizer to adjust synaptic weights"
269299
}
270300
info = {cls.__name__: properties,
271301
"compartments": compartment_props,
272302
"dynamics": "outputs = [(W * Rscale) * inputs] + b ;"
273-
"dW_{ij}/dt = eta * [(z_j * q_pre) * (z_i * q_post)] - W_{ij} * w_decay",
303+
"dW_{ij}/dt = eta * [(z_j * q_pre) * (z_i * q_post)] - g(W_{ij}) * prior_lmbda",
274304
"hyperparameters": hyperparams}
275305
return info
276306

@@ -292,5 +322,5 @@ def __repr__(self):
292322
from ngcsimlib.context import Context
293323
with Context("Bar") as bar:
294324
Wab = HebbianSynapse("Wab", (2, 3), 0.0004, optim_type='adam',
295-
sign_value=-1.0, bias_init=("constant", 0., 0.))
325+
sign_value=-1.0, prior=("l1l2", 0.001))
296326
print(Wab)

0 commit comments

Comments
 (0)