@@ -40,23 +40,24 @@ def _calc_update(pre, post, W, w_bound, is_nonnegative=True, signVal=1.,
4040 """
4141 _pre = pre * pre_wght
4242 _post = post * post_wght
43- dW = jnp .matmul (_pre .T , _post )
44- db = jnp .sum (_post , axis = 0 , keepdims = True )
45- dW_reg = 0.
43+ dW = jnp .matmul (_pre .T , _post ) ## calc Hebbian adjustment
44+ db = jnp .sum (_post , axis = 0 , keepdims = True ) ## calc Hebbian adjustment to bias/base-rates
45+ dW_reg = 0. ## synaptic decay term
4646
47- if w_bound > 0. :
47+ if w_bound > 0. : ## induce any synaptic value bounding
4848 dW = dW * (w_bound - jnp .abs (W ))
49-
49+ ## apply synaptic priors
5050 if prior_type == "l2" or prior_type == "ridge" :
51- dW_reg = - W
51+ dW_reg = - W * prior_lmbda
5252 if prior_type == "l1" or prior_type == "lasso" :
53- dW_reg = - jnp .sign (W )
53+ dW_reg = - jnp .sign (W ) * prior_lmbda
5454 if prior_type == "l1l2" or prior_type == "elastic_net" :
5555 l1_ratio = prior_lmbda [1 ]
56- prior_lmbda = prior_lmbda [0 ]
56+ prior_scale = prior_lmbda [0 ]
5757 dW_reg = - jnp .sign (W ) * l1_ratio - W * (1 - l1_ratio )/ 2
58-
59- dW = dW + prior_lmbda * dW_reg
58+ dW_reg = dW_reg * prior_scale
59+ ## produce final update/adjustment
60+ dW = dW + dW_reg
6061 return dW * signVal , db * signVal
6162
6263@partial (jit , static_argnums = [1 ,2 ])
0 commit comments