Skip to content

Commit 1d15f1f

Browse files
authored
add prior for hebbian patched synapse (#96)
* Replaced w_decay with a configurable prior in hebbianPatchedSynapse.py: removed w_decay, added prior_type and prior_lmbda * Fixed typo in hebbianSynapse.py: dWweights → dWeights
1 parent 6e8261e commit 1d15f1f

File tree

2 files changed

+52
-24
lines changed

2 files changed

+52
-24
lines changed

ngclearn/components/synapses/hebbian/hebbianSynapse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ class HebbianSynapse(DenseSynapse):
9898
| --- Synaptic Plasticity Compartments: ---
9999
| pre - pre-synaptic signal to drive first term of Hebbian update (takes in external signals)
100100
| post - post-synaptic signal to drive 2nd term of Hebbian update (takes in external signals)
101-
| dWweights - current delta matrix containing changes to be applied to synaptic efficacies
101+
| dWeights - current delta matrix containing changes to be applied to synaptic efficacies
102102
| dBiases - current delta vector containing changes to be applied to bias values
103103
| opt_params - locally-embedded optimizer statistics (e.g., Adam 1st/2nd moments if adam is used)
104104

ngclearn/components/synapses/patched/hebbianPatchedSynapse.py

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from ngclearn.utils import tensorstats
88

99
@partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9])
10-
def _calc_update(pre, post, W, w_mask, w_bound, is_nonnegative=True, signVal=1., w_decay=0.,
10+
def _calc_update(pre, post, W, w_mask, w_bound, is_nonnegative=True, signVal=1.,
11+
prior_type=None, prior_lmbda=0.,
1112
pre_wght=1., post_wght=1.):
1213
"""
1314
Compute a tensor of adjustments to be applied to a synaptic value matrix.
@@ -19,14 +20,18 @@ def _calc_update(pre, post, W, w_mask, w_bound, is_nonnegative=True, signVal=1.,
1920
2021
W: synaptic weight values (at time t)
2122
22-
w_mask: weight mask matrix
23+
w_mask: synaptic weight masking matrix (same shape as W)
2324
2425
w_bound: maximum value to enforce over newly computed efficacies
2526
27+
is_nonnegative: (Unused)
28+
2629
signVal: multiplicative factor to modulate final update by (good for
2730
flipping the signs of a computed synaptic change matrix)
2831
29-
w_decay: synaptic decay factor to apply to this update
32+
prior_type: prior type or name (Default: None)
33+
34+
prior_lmbda: prior parameter (Default: 0.0)
3035
3136
pre_wght: pre-synaptic weighting term (Default: 1.)
3237
@@ -35,14 +40,28 @@ def _calc_update(pre, post, W, w_mask, w_bound, is_nonnegative=True, signVal=1.,
3540
Returns:
3641
an update/adjustment matrix, an update adjustment vector (for biases)
3742
"""
43+
3844
_pre = pre * pre_wght
3945
_post = post * post_wght
4046
dW = jnp.matmul(_pre.T, _post)
4147
db = jnp.sum(_post, axis=0, keepdims=True)
48+
dW_reg = 0.
49+
4250
if w_bound > 0.:
4351
dW = dW * (w_bound - jnp.abs(W))
44-
if w_decay > 0.:
45-
dW = dW - W * w_decay
52+
53+
if prior_type == "l2" or prior_type == "ridge":
54+
dW_reg = W
55+
56+
if prior_type == "l1" or prior_type == "lasso":
57+
dW_reg = jnp.sign(W)
58+
59+
if prior_type == "l1l2" or prior_type == "elastic_net":
60+
l1_ratio = prior_lmbda[1]
61+
prior_lmbda = prior_lmbda[0]
62+
dW_reg = jnp.sign(W) * l1_ratio + W * (1-l1_ratio)/2
63+
64+
dW = dW + prior_lmbda * dW_reg
4665

4766
if w_mask!=None:
4867
dW = dW * w_mask
@@ -79,6 +98,7 @@ def _enforce_constraints(W, w_mask, w_bound, is_nonnegative=True):
7998

8099
return _W
81100

101+
82102
class HebbianPatchedSynapse(PatchedSynapse):
83103
"""
84104
A synaptic cable that adjusts its efficacies via a two-factor Hebbian
@@ -93,7 +113,7 @@ class HebbianPatchedSynapse(PatchedSynapse):
93113
| --- Synaptic Plasticity Compartments: ---
94114
| pre - pre-synaptic signal to drive first term of Hebbian update (takes in external signals)
95115
| post - post-synaptic signal to drive 2nd term of Hebbian update (takes in external signals)
96-
| dWweights - current delta matrix containing changes to be applied to synaptic efficacies
116+
| dWeights - current delta matrix containing changes to be applied to synaptic efficacies
97117
| dBiases - current delta vector containing changes to be applied to bias values
98118
| opt_params - locally-embedded optimizer statistics (e.g., Adam 1st/2nd moments if adam is used)
99119
@@ -104,7 +124,7 @@ class HebbianPatchedSynapse(PatchedSynapse):
104124
with number of inputs by number of outputs)
105125
106126
n_sub_models: The number of submodels in each layer
107-
127+
108128
stride_shape: Stride shape of overlapping synaptic weight value matrix
109129
(Default: (0, 0))
110130
@@ -125,9 +145,17 @@ class HebbianPatchedSynapse(PatchedSynapse):
125145
is_nonnegative: enforce that synaptic efficacies are always non-negative
126146
after each synaptic update (if False, no constraint will be applied)
127147
128-
w_decay: degree to which (L2) synaptic weight decay is applied to the
129-
computed Hebbian adjustment (Default: 0); note that decay is not
130-
applied to any configured biases
148+
149+
prior: a kernel to drive prior of this synaptic cable's values;
150+
typically a tuple with 1st element as a string calling the name of
151+
prior to use and 2nd element as a floating point number
152+
calling the prior parameter lambda (Default: (None, 0.))
153+
currently supported prior types: "l1"/"lasso", "l2"/"ridge", and "l1l2"/"elastic_net".
154+
usage guide:
155+
prior = ('l1', 0.01) or prior = ('lasso', lmbda)
156+
prior = ('l2', 0.01) or prior = ('ridge', lmbda)
157+
prior = ('l1l2', (0.01, 0.01)) or prior = ('elastic_net', (lmbda, l1_ratio))
158+
131159
132160
sign_value: multiplicative factor to apply to final synaptic update before
133161
it is applied to synapses; this is useful if gradient descent style
@@ -157,12 +185,16 @@ class HebbianPatchedSynapse(PatchedSynapse):
157185
"""
158186

159187
def __init__(self, name, shape, n_sub_models, stride_shape=(0,0), eta=0., weight_init=None, bias_init=None,
160-
w_mask=None, w_bound=1., is_nonnegative=False, w_decay=0., sign_value=1.,
188+
w_mask=None, w_bound=1., is_nonnegative=False, prior=(None, 0.), sign_value=1.,
161189
optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1.,
162190
resist_scale=1., batch_size=1, **kwargs):
163191
super().__init__(name, shape, n_sub_models, stride_shape, w_mask, weight_init, bias_init, resist_scale,
164192
p_conn, batch_size=batch_size, **kwargs)
165193

194+
prior_type, prior_lmbda = prior
195+
self.prior_type = prior_type
196+
self.prior_lmbda = prior_lmbda
197+
166198
self.n_sub_models = n_sub_models
167199
self.sub_stride = stride_shape
168200

@@ -174,7 +206,6 @@ def __init__(self, name, shape, n_sub_models, stride_shape=(0,0), eta=0., weight
174206
## synaptic plasticity properties and characteristics
175207
self.Rscale = resist_scale
176208
self.w_bound = w_bound
177-
self.w_decay = w_decay ## synaptic decay
178209
self.pre_wght = pre_wght
179210
self.post_wght = post_wght
180211
self.eta = eta
@@ -199,22 +230,22 @@ def __init__(self, name, shape, n_sub_models, stride_shape=(0,0), eta=0., weight
199230
if bias_init else [self.weights.value]))
200231

201232
@staticmethod
202-
def _compute_update(w_mask, w_bound, is_nonnegative, sign_value, w_decay, pre_wght,
233+
def _compute_update(w_mask, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
203234
post_wght, pre, post, weights):
204235
## calculate synaptic update values
205236
dW, db = _calc_update(
206237
pre, post, weights, w_mask, w_bound, is_nonnegative=is_nonnegative,
207-
signVal=sign_value, w_decay=w_decay, pre_wght=pre_wght,
238+
signVal=sign_value, prior_type=prior_type, prior_lmbda=prior_lmbda, pre_wght=pre_wght,
208239
post_wght=post_wght)
209240

210241
return dW * jnp.where(0 != jnp.abs(weights), 1, 0) , db
211242

212243
@staticmethod
213-
def _evolve(w_mask, opt, w_bound, is_nonnegative, sign_value, w_decay, pre_wght,
244+
def _evolve(w_mask, opt, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
214245
post_wght, bias_init, pre, post, weights, biases, opt_params):
215246
## calculate synaptic update values
216247
dWeights, dBiases = HebbianPatchedSynapse._compute_update(
217-
w_mask, w_bound, is_nonnegative, sign_value, w_decay,
248+
w_mask, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda,
218249
pre_wght, post_wght, pre, post, weights
219250
)
220251
## conduct a step of optimization - get newly evolved synaptic weight value matrix
@@ -299,14 +330,14 @@ def help(cls): ## component help function
299330
"pre_wght": "Pre-synaptic weighting coefficient (q_pre)",
300331
"post_wght": "Post-synaptic weighting coefficient (q_post)",
301332
"w_bound": "Soft synaptic bound applied to synapses post-update",
333+
"prior": "prior name and value for synaptic updating prior",
302334
"w_mask": "weight mask matrix",
303-
"w_decay": "Synaptic decay term",
304335
"optim_type": "Choice of optimizer to adjust synaptic weights"
305336
}
306337
info = {cls.__name__: properties,
307338
"compartments": compartment_props,
308339
"dynamics": "outputs = [(W * Rscale) * inputs] + b ;"
309-
"dW_{ij}/dt = eta * [(z_j * q_pre) * (z_i * q_post)] - W_{ij} * w_decay",
340+
"dW_{ij}/dt = eta * [(z_j * q_pre) * (z_i * q_post)] - g(W_{ij}) * prior_lmbda",
310341
"hyperparameters": hyperparams}
311342
return info
312343

@@ -336,12 +367,9 @@ def __repr__(self):
336367
if __name__ == '__main__':
337368
from ngcsimlib.context import Context
338369
with Context("Bar") as bar:
339-
Wab = HebbianPatchedSynapse("Wab", (9, 30), 3)
370+
Wab = HebbianPatchedSynapse("Wab", (9, 30), 3, (0, 0), optim_type='adam',
371+
sign_value=-1.0, prior=("l1l2", 0.001))
340372
print(Wab)
341373
plt.imshow(Wab.weights.value, cmap='gray')
342374
plt.show()
343375

344-
345-
346-
347-

0 commit comments

Comments
 (0)