from jax import random, numpy as jnp, jit
from functools import partial
from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn
from ngclearn import resolver, Component, Compartment
from ngclearn.components.synapses import DenseSynapse
from ngclearn.utils import tensorstats
from ngcsimlib.deprecators import deprecate_args

@partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9])
def _calc_update(pre, post, W, w_bound, is_nonnegative=True, signVal=1.,
                 prior_type=None, prior_lmbda=0.,
                 pre_wght=1., post_wght=1.):
13+ """
14+ Compute a tensor of adjustments to be applied to a synaptic value matrix.
15+
16+ Args:
17+ pre: pre-synaptic statistic to drive Hebbian update
18+
19+ post: post-synaptic statistic to drive Hebbian update
20+
21+ W: synaptic weight values (at time t)
22+
23+ w_bound: maximum value to enforce over newly computed efficacies
24+
25+ is_nonnegative: (Unused)
26+
27+ signVal: multiplicative factor to modulate final update by (good for
28+ flipping the signs of a computed synaptic change matrix)
29+
30+ prior_type: prior type or name (Default: None)
31+
32+ prior_lmbda: prior parameter (Default: 0.0)
33+
34+ pre_wght: pre-synaptic weighting term (Default: 1.)
35+
36+ post_wght: post-synaptic weighting term (Default: 1.)
37+
38+ Returns:
39+ an update/adjustment matrix, an update adjustment vector (for biases)
40+ """
    _pre = pre * pre_wght
    _post = post * post_wght
    dW = jnp.matmul(_pre.T, _post)
    db = jnp.sum(_post, axis=0, keepdims=True)
    dW_reg = 0.

    if w_bound > 0.:
        dW = dW * (w_bound - jnp.abs(W))

    if prior_type == "l2" or prior_type == "ridge":
        dW_reg = W
    if prior_type == "l1" or prior_type == "lasso":
        dW_reg = jnp.sign(W)
    if prior_type == "l1l2" or prior_type == "elastic_net":
        l1_ratio = prior_lmbda[1]
        prior_lmbda = prior_lmbda[0]
        dW_reg = jnp.sign(W) * l1_ratio + W * (1 - l1_ratio) / 2

    dW = dW + prior_lmbda * dW_reg
    return dW * signVal, db * signVal
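
## Note: when the "l1l2"/"elastic_net" prior is used, `prior_lmbda` is expected
## to be a 2-tuple (lmbda, l1_ratio); the regularization term folded into dW is
##   lmbda * [l1_ratio * sign(W) + (1 - l1_ratio) * W / 2]
## For the other priors, `prior_lmbda` is a single scalar coefficient.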

@partial(jit, static_argnums=[1, 2])
def _enforce_constraints(W, w_bound, is_nonnegative=True):
    """
    Enforces constraints that the (synaptic) efficacies/values within matrix
    `W` must adhere to.

    Args:
        W: synaptic weight values (at time t)

        w_bound: maximum value to enforce over newly computed efficacies

        is_nonnegative: ensure updated value matrix is strictly non-negative

    Returns:
        the newly evolved synaptic weight value matrix
    """
    _W = W
    if w_bound > 0.:
        if is_nonnegative:
            _W = jnp.clip(_W, 0., w_bound)
        else:
            _W = jnp.clip(_W, -w_bound, w_bound)
    return _W
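
## Note: `_calc_update` applies a soft bound by scaling raw updates with
## (w_bound - |W|), whereas `_enforce_constraints` applies a hard clip of the
## evolved weights to [0, w_bound] (non-negative case) or [-w_bound, w_bound].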


class HebbianSynapse(DenseSynapse):
    """
    A synaptic cable that adjusts its efficacies via a two-factor Hebbian
    adjustment rule.

    | --- Synapse Compartments: ---
    | inputs - input (takes in external signals)
    | outputs - output signals (transformation induced by synapses)
    | weights - current value matrix of synaptic efficacies
    | biases - current value vector of synaptic bias values
    | key - JAX PRNG key
    | --- Synaptic Plasticity Compartments: ---
    | pre - pre-synaptic signal to drive first term of Hebbian update (takes in external signals)
    | post - post-synaptic signal to drive second term of Hebbian update (takes in external signals)
    | dWeights - current delta matrix containing changes to be applied to synaptic efficacies
    | dBiases - current delta vector containing changes to be applied to bias values
    | opt_params - locally-embedded optimizer statistics (e.g., Adam 1st/2nd moments if Adam is used)

    Args:
        name: the string name of this cell

        shape: tuple specifying the shape of this synaptic cable (usually a 2-tuple
            with number of inputs by number of outputs)

        eta: global learning rate

        weight_init: a kernel to drive initialization of this synaptic cable's values;
            typically a tuple with the 1st element as a string naming the
            initialization to use

        bias_init: a kernel to drive initialization of biases for this synaptic cable
            (Default: None, which turns off/disables biases)

        w_bound: maximum weight to softly bound this cable's value matrix to; if
            set to 0, then no synaptic value bounding will be applied

        is_nonnegative: enforce that synaptic efficacies are always non-negative
            after each synaptic update (if False, no constraint will be applied)

        prior: a kernel to drive the prior over this synaptic cable's values;
            typically a tuple with the 1st element as a string naming the prior
            to use and the 2nd element as the prior parameter lambda
            (Default: (None, 0.)). Currently supported priors are "l1"/"lasso",
            "l2"/"ridge", and "l1l2"/"elastic_net".
            Usage guide:
            prior = ('l1', 0.01) or prior = ('lasso', lmbda)
            prior = ('l2', 0.01) or prior = ('ridge', lmbda)
            prior = ('l1l2', (0.01, 0.01)) or prior = ('elastic_net', (lmbda, l1_ratio))

        sign_value: multiplicative factor to apply to the final synaptic update before
            it is applied to the synapses; this is useful if gradient-descent-style
            optimization is required (as Hebbian rules typically yield
            adjustments for ascent)

        optim_type: optimization scheme to physically alter synaptic values
            once an update is computed (Default: "sgd"); supported schemes
            include "sgd" and "adam"

            :Note: technically, if "sgd" or "adam" is used with `sign_value = 1`,
                then the ascent form of the rule is employed; setting
                `sign_value = -1` (or using a negative learning rate) yields
                the descent form of the chosen `optim_type`

        pre_wght: pre-synaptic weighting factor (Default: 1.)

        post_wght: post-synaptic weighting factor (Default: 1.)

        resist_scale: a fixed scaling factor to apply to the synaptic transform
            (Default: 1.), i.e., yields: out = ((W * Rscale) * in) + b

        p_conn: probability of a connection existing (Default: 1.); setting
            this to < 1. will result in a sparser synaptic structure
    """

    # Define Functions
    @deprecate_args(_rebind=False, w_decay='prior')
    def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None,
                 w_bound=1., is_nonnegative=False, prior=(None, 0.), w_decay=0., sign_value=1.,
                 optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1.,
                 resist_scale=1., batch_size=1, **kwargs):
        super().__init__(name, shape, weight_init, bias_init, resist_scale,
                         p_conn, batch_size=batch_size, **kwargs)

        if w_decay > 0.:
            prior = ('l2', w_decay)

        prior_type, prior_lmbda = prior
        ## synaptic plasticity properties and characteristics
        self.shape = shape
        self.Rscale = resist_scale
        self.prior_type = prior_type
        self.prior_lmbda = prior_lmbda
        self.w_bound = w_bound
        self.pre_wght = pre_wght
        self.post_wght = post_wght
        self.eta = eta
        self.is_nonnegative = is_nonnegative
        self.sign_value = sign_value

        ## optimization / adjustment properties (given learning dynamics above)
        self.opt = get_opt_step_fn(optim_type, eta=self.eta)

        # compartments (state of the cell, parameters, will be updated through stateless calls)
        self.preVals = jnp.zeros((self.batch_size, shape[0]))
        self.postVals = jnp.zeros((self.batch_size, shape[1]))
        self.pre = Compartment(self.preVals)
        self.post = Compartment(self.postVals)
        self.dWeights = Compartment(jnp.zeros(shape))
        self.dBiases = Compartment(jnp.zeros(shape[1]))

        #key, subkey = random.split(self.key.value)
        self.opt_params = Compartment(get_opt_init_fn(optim_type)(
            [self.weights.value, self.biases.value]
            if bias_init else [self.weights.value]))
    @staticmethod
    def _compute_update(w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
                        post_wght, pre, post, weights):
        ## calculate synaptic update values
        dW, db = _calc_update(
            pre, post, weights, w_bound, is_nonnegative=is_nonnegative,
            signVal=sign_value, prior_type=prior_type, prior_lmbda=prior_lmbda, pre_wght=pre_wght,
            post_wght=post_wght)
        return dW, db

    @staticmethod
    def _evolve(opt, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
                post_wght, bias_init, pre, post, weights, biases, opt_params):
        ## calculate synaptic update values
        dWeights, dBiases = HebbianSynapse._compute_update(
            w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, post_wght,
            pre, post, weights
        )
        ## conduct a step of optimization - get newly evolved synaptic weight value matrix
        if bias_init is not None:
            opt_params, [weights, biases] = opt(opt_params, [weights, biases], [dWeights, dBiases])
        else:
            # ignore db since no biases configured
            opt_params, [weights] = opt(opt_params, [weights], [dWeights])
        ## ensure synaptic efficacies adhere to constraints
        weights = _enforce_constraints(weights, w_bound, is_nonnegative=is_nonnegative)
        return opt_params, weights, biases, dWeights, dBiases

    @resolver(_evolve)
    def evolve(self, opt_params, weights, biases, dWeights, dBiases):
        self.opt_params.set(opt_params)
        self.weights.set(weights)
        self.biases.set(biases)
        self.dWeights.set(dWeights)
        self.dBiases.set(dBiases)

    @staticmethod
    def _reset(batch_size, shape):
        preVals = jnp.zeros((batch_size, shape[0]))
        postVals = jnp.zeros((batch_size, shape[1]))
        return (
            preVals,  # inputs
            postVals,  # outputs
            preVals,  # pre
            postVals,  # post
            jnp.zeros(shape),  # dW
            jnp.zeros(shape[1]),  # db
        )

    @resolver(_reset)
    def reset(self, inputs, outputs, pre, post, dWeights, dBiases):
        self.inputs.set(inputs)
        self.outputs.set(outputs)
        self.pre.set(pre)
        self.post.set(post)
        self.dWeights.set(dWeights)
        self.dBiases.set(dBiases)

    @classmethod
    def help(cls): ## component help function
        properties = {
            "synapse_type": "HebbianSynapse - performs an adaptable synaptic "
                            "transformation of inputs to produce output signals; "
                            "synapses are adjusted via a two-term/factor Hebbian adjustment"
        }
        compartment_props = {
            "inputs":
                {"inputs": "Takes in external input signal values",
                 "pre": "Pre-synaptic statistic for Hebb rule (z_j)",
                 "post": "Post-synaptic statistic for Hebb rule (z_i)"},
            "states":
                {"weights": "Synapse efficacy/strength parameter values",
                 "biases": "Base-rate/bias parameter values",
                 "key": "JAX PRNG key"},
            "analytics":
                {"dWeights": "Synaptic weight value adjustment matrix produced at time t",
                 "dBiases": "Synaptic bias/base-rate value adjustment vector produced at time t"},
            "outputs":
                {"outputs": "Output of synaptic transformation"},
        }
        hyperparams = {
            "shape": "Shape of synaptic weight value matrix; number of inputs x number of outputs",
            "batch_size": "Batch size dimension of this component",
            "weight_init": "Initialization conditions for synaptic weight (W) values",
            "bias_init": "Initialization conditions for bias/base-rate (b) values",
            "resist_scale": "Resistance level scaling factor (applied to output of transformation)",
            "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
            "is_nonnegative": "Should synapses be constrained to be non-negative post-updates?",
            "sign_value": "Scalar `flipping` constant -- changes direction to Hebbian descent if < 0",
            "eta": "Global (fixed) learning rate",
            "pre_wght": "Pre-synaptic weighting coefficient (q_pre)",
            "post_wght": "Post-synaptic weighting coefficient (q_post)",
            "w_bound": "Soft synaptic bound applied to synapses post-update",
            "prior": "Prior name and lambda value used for the synaptic update prior",
            "optim_type": "Choice of optimizer to adjust synaptic weights"
        }
        info = {cls.__name__: properties,
                "compartments": compartment_props,
                "dynamics": "outputs = [(W * Rscale) * inputs] + b; "
                            "dW_{ij}/dt = eta * [(z_j * q_pre) * (z_i * q_post)] - g(W_{ij}) * prior_lmbda",
                "hyperparameters": hyperparams}
        return info

    def __repr__(self):
        comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
        maxlen = max(len(c) for c in comps) + 5
        lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
        for c in comps:
            stats = tensorstats(getattr(self, c).value)
            if stats is not None:
                line = [f"{k}: {v}" for k, v in stats.items()]
                line = ", ".join(line)
            else:
                line = "None"
            lines += f"  {f'({c})'.ljust(maxlen)}{line}\n"
        return lines

if __name__ == '__main__':
    from ngcsimlib.context import Context
    with Context("Bar") as bar:
        Wab = HebbianSynapse("Wab", (2, 3), 0.0004, optim_type='adam',
                             sign_value=-1.0,
                             prior=("l1l2", (0.001, 0.001)))  ## (lmbda, l1_ratio)
        print(Wab)
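
        ## Minimal illustrative sketch (not part of the component API): exercise
        ## the raw Hebbian update kernel directly to check the shapes it yields
        ## for a toy 2 -> 3 synaptic cable driven by a batch of 5 samples.
        k1, k2 = random.split(random.PRNGKey(1234), 2)
        pre = random.uniform(k1, (5, 2))   ## pre-synaptic statistics (batch, n_pre)
        post = random.uniform(k2, (5, 3))  ## post-synaptic statistics (batch, n_post)
        W = jnp.zeros((2, 3))
        ## args: pre, post, W, w_bound, is_nonnegative, signVal, prior_type,
        ##       prior_lmbda, pre_wght, post_wght
        dW, db = _calc_update(pre, post, W, 1., True, 1., "l2", 0.001, 1., 1.)
        print(dW.shape, db.shape)  ## expected: (2, 3) (1, 3)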