
Commit 20230ae: update old hebbian synapse
1 parent 2e1b2e8

1 file changed: 326 additions, 0 deletions

@@ -0,0 +1,326 @@
from jax import random, numpy as jnp, jit
from functools import partial
from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn
from ngclearn import resolver, Component, Compartment
from ngclearn.components.synapses import DenseSynapse
from ngclearn.utils import tensorstats
from ngcsimlib.deprecators import deprecate_args

@partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9])
def _calc_update(pre, post, W, w_bound, is_nonnegative=True, signVal=1.,
                 prior_type=None, prior_lmbda=0.,
                 pre_wght=1., post_wght=1.):
    """
    Compute a tensor of adjustments to be applied to a synaptic value matrix.

    Args:
        pre: pre-synaptic statistic to drive Hebbian update

        post: post-synaptic statistic to drive Hebbian update

        W: synaptic weight values (at time t)

        w_bound: maximum value to enforce over newly computed efficacies

        is_nonnegative: (Unused)

        signVal: multiplicative factor to modulate the final update by (useful
            for flipping the sign of a computed synaptic change matrix)

        prior_type: prior type or name (Default: None)

        prior_lmbda: prior parameter (Default: 0.0)

        pre_wght: pre-synaptic weighting term (Default: 1.)

        post_wght: post-synaptic weighting term (Default: 1.)

    Returns:
        an update/adjustment matrix and an update/adjustment vector (for biases)
    """
    _pre = pre * pre_wght
    _post = post * post_wght
    dW = jnp.matmul(_pre.T, _post)
    db = jnp.sum(_post, axis=0, keepdims=True)
    dW_reg = 0.

    if w_bound > 0.:
        dW = dW * (w_bound - jnp.abs(W))

    if prior_type == "l2" or prior_type == "ridge":
        dW_reg = W
    if prior_type == "l1" or prior_type == "lasso":
        dW_reg = jnp.sign(W)
    if prior_type == "l1l2" or prior_type == "elastic_net":
        l1_ratio = prior_lmbda[1]
        prior_lmbda = prior_lmbda[0]
        dW_reg = jnp.sign(W) * l1_ratio + W * (1 - l1_ratio) / 2

    dW = dW + prior_lmbda * dW_reg
    return dW * signVal, db * signVal
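
## Illustrative note: with q_pre = q_post = 1, the core two-factor term is an
## outer product accumulated over the batch, dW = _pre^T @ _post and
## db = sum_batch(_post). For example, pre = [[1., 0.]] and post = [[.5, .25, 0.]]
## yield a (2, 3) dW whose first row is [.5, .25, 0.] and whose second row is zero;
## when w_bound > 0, each entry is further scaled by (w_bound - |W_ij|), softly
## damping updates to weights that are already near the bound.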

@partial(jit, static_argnums=[1,2])
def _enforce_constraints(W, w_bound, is_nonnegative=True):
    """
    Enforces constraints that the (synaptic) efficacies/values within matrix
    `W` must adhere to.

    Args:
        W: synaptic weight values (at time t)

        w_bound: maximum value to enforce over newly computed efficacies

        is_nonnegative: ensure updated value matrix is strictly non-negative

    Returns:
        the newly evolved synaptic weight value matrix
    """
    _W = W
    if w_bound > 0.:
        if is_nonnegative:
            _W = jnp.clip(_W, 0., w_bound)
        else:
            _W = jnp.clip(_W, -w_bound, w_bound)
    return _W
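
## Illustrative note: with w_bound = 1. and is_nonnegative = True, a weight of 1.3
## is clipped to 1.0 and a weight of -0.2 to 0.0; with is_nonnegative = False, the
## same weights map to 1.0 and -0.2 (i.e., clipped into [-1., 1.]).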

class HebbianSynapse(DenseSynapse):
    """
    A synaptic cable that adjusts its efficacies via a two-factor Hebbian
    adjustment rule.

    | --- Synapse Compartments: ---
    | inputs - input (takes in external signals)
    | outputs - output signals (transformation induced by synapses)
    | weights - current value matrix of synaptic efficacies
    | biases - current value vector of synaptic bias values
    | key - JAX PRNG key
    | --- Synaptic Plasticity Compartments: ---
    | pre - pre-synaptic signal to drive first term of Hebbian update (takes in external signals)
    | post - post-synaptic signal to drive 2nd term of Hebbian update (takes in external signals)
    | dWeights - current delta matrix containing changes to be applied to synaptic efficacies
    | dBiases - current delta vector containing changes to be applied to bias values
    | opt_params - locally-embedded optimizer statistics (e.g., Adam 1st/2nd moments if adam is used)

    Args:
        name: the string name of this cell

        shape: tuple specifying shape of this synaptic cable (usually a 2-tuple
            with number of inputs by number of outputs)

        eta: global learning rate

        weight_init: a kernel to drive initialization of this synaptic cable's values;
            typically a tuple with 1st element as a string calling the name of
            initialization to use

        bias_init: a kernel to drive initialization of biases for this synaptic cable
            (Default: None, which turns off/disables biases)

        w_bound: maximum weight to softly bound this cable's value matrix to; if
            set to 0, then no synaptic value bounding will be applied

        is_nonnegative: enforce that synaptic efficacies are always non-negative
            after each synaptic update (if False, no constraint will be applied)

        prior: a kernel to drive the prior placed over this synaptic cable's values;
            typically a tuple with the 1st element as a string naming the prior to
            use and the 2nd element as the prior parameter lambda
            (Default: (None, 0.)). Currently supported priors are "l1"/"lasso",
            "l2"/"ridge", and "l1l2"/"elastic_net". Usage guide:
            prior = ('l1', 0.01) or prior = ('lasso', lmbda)
            prior = ('l2', 0.01) or prior = ('ridge', lmbda)
            prior = ('l1l2', (0.01, 0.01)) or prior = ('elastic_net', (lmbda, l1_ratio))

        sign_value: multiplicative factor to apply to the final synaptic update before
            it is applied to synapses; this is useful if gradient descent style
            optimization is required (as Hebbian rules typically yield
            adjustments for ascent)

        optim_type: optimization scheme to physically alter synaptic values
            once an update is computed (Default: "sgd"); supported schemes
            include "sgd" and "adam"

            :Note: if "sgd" or "adam" is used with `sign_value = 1`, the ascent
                form of the rule is employed; setting `sign_value = -1` (or
                using a negative learning rate) yields the descent form of the
                chosen `optim_type`

        pre_wght: pre-synaptic weighting factor (Default: 1.)

        post_wght: post-synaptic weighting factor (Default: 1.)

        resist_scale: a fixed scaling factor to apply to synaptic transform
            (Default: 1.), i.e., yields: out = ((W * Rscale) * in) + b

        p_conn: probability of a connection existing (Default: 1.); setting
            this to < 1. will result in a sparser synaptic structure
    """

    # Define Functions
    @deprecate_args(_rebind=False, w_decay='prior')
    def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None,
                 w_bound=1., is_nonnegative=False, prior=(None, 0.), w_decay=0., sign_value=1.,
                 optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1.,
                 resist_scale=1., batch_size=1, **kwargs):
        super().__init__(name, shape, weight_init, bias_init, resist_scale,
                         p_conn, batch_size=batch_size, **kwargs)

        if w_decay > 0.:
            prior = ('l2', w_decay)

        prior_type, prior_lmbda = prior
        ## synaptic plasticity properties and characteristics
        self.shape = shape
        self.Rscale = resist_scale
        self.prior_type = prior_type
        self.prior_lmbda = prior_lmbda
        self.w_bound = w_bound
        self.pre_wght = pre_wght
        self.post_wght = post_wght
        self.eta = eta
        self.is_nonnegative = is_nonnegative
        self.sign_value = sign_value

        ## optimization / adjustment properties (given learning dynamics above)
        self.opt = get_opt_step_fn(optim_type, eta=self.eta)

        # compartments (state of the cell, parameters, will be updated through stateless calls)
        self.preVals = jnp.zeros((self.batch_size, shape[0]))
        self.postVals = jnp.zeros((self.batch_size, shape[1]))
        self.pre = Compartment(self.preVals)
        self.post = Compartment(self.postVals)
        self.dWeights = Compartment(jnp.zeros(shape))
        self.dBiases = Compartment(jnp.zeros(shape[1]))

        #key, subkey = random.split(self.key.value)
        self.opt_params = Compartment(get_opt_init_fn(optim_type)(
            [self.weights.value, self.biases.value]
            if bias_init else [self.weights.value]))
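
    ## Note: `opt_params` holds the optimizer's internal statistics (e.g., Adam
    ## first/second moments when optim_type="adam") for the same parameter list
    ## handed to `self.opt`, i.e., [weights, biases] when biases are configured,
    ## otherwise just [weights].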

    @staticmethod
    def _compute_update(w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
                        post_wght, pre, post, weights):
        ## calculate synaptic update values
        dW, db = _calc_update(
            pre, post, weights, w_bound, is_nonnegative=is_nonnegative,
            signVal=sign_value, prior_type=prior_type, prior_lmbda=prior_lmbda, pre_wght=pre_wght,
            post_wght=post_wght)
        return dW, db

    @staticmethod
    def _evolve(opt, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
                post_wght, bias_init, pre, post, weights, biases, opt_params):
        ## calculate synaptic update values
        dWeights, dBiases = HebbianSynapse._compute_update(
            w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, post_wght,
            pre, post, weights
        )
        ## conduct a step of optimization - get newly evolved synaptic weight value matrix
        if bias_init is not None:
            opt_params, [weights, biases] = opt(opt_params, [weights, biases], [dWeights, dBiases])
        else:
            # ignore db since no biases configured
            opt_params, [weights] = opt(opt_params, [weights], [dWeights])
        ## ensure synaptic efficacies adhere to constraints
        weights = _enforce_constraints(weights, w_bound, is_nonnegative=is_nonnegative)
        return opt_params, weights, biases, dWeights, dBiases

    @resolver(_evolve)
    def evolve(self, opt_params, weights, biases, dWeights, dBiases):
        self.opt_params.set(opt_params)
        self.weights.set(weights)
        self.biases.set(biases)
        self.dWeights.set(dWeights)
        self.dBiases.set(dBiases)

    @staticmethod
    def _reset(batch_size, shape):
        preVals = jnp.zeros((batch_size, shape[0]))
        postVals = jnp.zeros((batch_size, shape[1]))
        return (
            preVals, # inputs
            postVals, # outputs
            preVals, # pre
            postVals, # post
            jnp.zeros(shape), # dW
            jnp.zeros(shape[1]), # db
        )

    @resolver(_reset)
    def reset(self, inputs, outputs, pre, post, dWeights, dBiases):
        self.inputs.set(inputs)
        self.outputs.set(outputs)
        self.pre.set(pre)
        self.post.set(post)
        self.dWeights.set(dWeights)
        self.dBiases.set(dBiases)

    @classmethod
    def help(cls): ## component help function
        properties = {
            "synapse_type": "HebbianSynapse - performs an adaptable synaptic "
                            "transformation of inputs to produce output signals; "
                            "synapses are adjusted via two-term/factor Hebbian adjustment"
        }
        compartment_props = {
            "inputs":
                {"inputs": "Takes in external input signal values",
                 "pre": "Pre-synaptic statistic for Hebb rule (z_j)",
                 "post": "Post-synaptic statistic for Hebb rule (z_i)"},
            "states":
                {"weights": "Synapse efficacy/strength parameter values",
                 "biases": "Base-rate/bias parameter values",
                 "key": "JAX PRNG key"},
            "analytics":
                {"dWeights": "Synaptic weight value adjustment matrix produced at time t",
                 "dBiases": "Synaptic bias/base-rate value adjustment vector produced at time t"},
            "outputs":
                {"outputs": "Output of synaptic transformation"},
        }
        hyperparams = {
            "shape": "Shape of synaptic weight value matrix; number inputs x number outputs",
            "batch_size": "Batch size dimension of this component",
            "weight_init": "Initialization conditions for synaptic weight (W) values",
            "bias_init": "Initialization conditions for bias/base-rate (b) values",
            "resist_scale": "Resistance level scaling factor (applied to output of transformation)",
            "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
            "is_nonnegative": "Should synapses be constrained to be non-negative post-updates?",
            "sign_value": "Scalar `flipping` constant -- changes direction to Hebbian descent if < 0",
            "eta": "Global (fixed) learning rate",
            "pre_wght": "Pre-synaptic weighting coefficient (q_pre)",
            "post_wght": "Post-synaptic weighting coefficient (q_post)",
            "w_bound": "Soft synaptic bound applied to synapses post-update",
            "prior": "Name and lambda value of the prior applied during synaptic updating",
            "optim_type": "Choice of optimizer to adjust synaptic weights"
        }
        info = {cls.__name__: properties,
                "compartments": compartment_props,
                "dynamics": "outputs = [(W * Rscale) * inputs] + b ; "
                            "dW_{ij}/dt = eta * [(z_j * q_pre) * (z_i * q_post)] - g(W_{ij}) * prior_lmbda",
                "hyperparameters": hyperparams}
        return info
307+
def __repr__(self):
308+
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
309+
maxlen = max(len(c) for c in comps) + 5
310+
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
311+
for c in comps:
312+
stats = tensorstats(getattr(self, c).value)
313+
if stats is not None:
314+
line = [f"{k}: {v}" for k, v in stats.items()]
315+
line = ", ".join(line)
316+
else:
317+
line = "None"
318+
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
319+
return lines
320+
321+
if __name__ == '__main__':
322+
from ngcsimlib.context import Context
323+
with Context("Bar") as bar:
324+
Wab = HebbianSynapse("Wab", (2, 3), 0.0004, optim_type='adam',
325+
sign_value=-1.0, prior=("l1l2", 0.001))
326+
print(Wab)
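        ## Illustrative sketch (toy values only, not prescribed by this component):
        ## feed an arbitrary pre/post statistic pair through the raw update helper
        ## to inspect the shape of the resulting Hebbian adjustment.
        toy_pre = jnp.ones((1, 2)) * 0.5    ## (batch, number of inputs)
        toy_post = jnp.ones((1, 3)) * 0.25  ## (batch, number of outputs)
        dW, db = HebbianSynapse._compute_update(
            1., False, -1.0, "l2", 0.001, 1., 1., toy_pre, toy_post,
            Wab.weights.value)
        print(dW.shape, db.shape)  ## expected: (2, 3) (1, 3)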
