
Commit 2d8c6e4

Author: Alexander Ororbia (committed)
Merge branch 'main' of github.com:NACLab/ngc-learn
2 parents: cf7ee73 + 91789a5

File tree: 3 files changed (+24, −24)

ngclearn/components/synapses/denseSynapse.py

Lines changed: 2 additions & 2 deletions
@@ -62,8 +62,8 @@ def __init__(
             self.weight_init = {"dist": "uniform", "amin": 0.025, "amax": 0.8}
         weights = initialize_params(subkeys[0], self.weight_init, shape)
         if 0. < p_conn < 1.: ## only non-zero and <1 probs allowed
-            mask = random.bernoulli(subkeys[1], p=p_conn, shape=shape)
-            weights = weights * mask ## sparsify matrix
+            p_mask = random.bernoulli(subkeys[1], p=p_conn, shape=shape)
+            weights = weights * p_mask ## sparsify matrix
 
         self.batch_size = batch_size #1
         ## Compartment setup
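For readers outside the library, the behavior this hunk touches is simple: when 0 < p_conn < 1, each freshly initialized synapse survives independently with probability p_conn. A minimal standalone sketch of that Bernoulli sparsification in plain JAX (the function name sparsify_weights is illustrative, not ngc-learn API):

import jax.numpy as jnp
from jax import random

def sparsify_weights(key, weights, p_conn):
    ## keep each synapse independently with probability p_conn
    if 0. < p_conn < 1.:  ## only non-zero and <1 probs allowed
        p_mask = random.bernoulli(key, p=p_conn, shape=weights.shape)
        weights = weights * p_mask  ## sparsify matrix
    return weights

key = random.PRNGKey(0)
k_init, k_mask = random.split(key)
W = random.uniform(k_init, (8, 8), minval=0.025, maxval=0.8)
W_sparse = sparsify_weights(k_mask, W, p_conn=0.5)  ## roughly half the entries zeroed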

ngclearn/components/synapses/patched/hebbianPatchedSynapse.py

Lines changed: 18 additions & 18 deletions
@@ -8,7 +8,7 @@
 from ngcsimlib.compilers.process import transition
 
 @partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9])
-def _calc_update(pre, post, W, w_mask, w_bound, is_nonnegative=True, signVal=1.,
+def _calc_update(pre, post, W, mask, w_bound, is_nonnegative=True, signVal=1.,
                  prior_type=None, prior_lmbda=0.,
                  pre_wght=1., post_wght=1.):
     """
@@ -21,7 +21,7 @@ def _calc_update(pre, post, W, w_mask, w_bound, is_nonnegative=True, signVal=1.,
 
         W: synaptic weight values (at time t)
 
-        w_mask: synaptic weight masking matrix (same shape as W)
+        mask: synaptic weight masking matrix (same shape as W)
 
         w_bound: maximum value to enforce over newly computed efficacies
 
@@ -64,21 +64,21 @@ def _calc_update(pre, post, W, w_mask, w_bound, is_nonnegative=True, signVal=1.,
 
         dW = dW + prior_lmbda * dW_reg
 
-    if w_mask!=None:
-        dW = dW * w_mask
+    if mask!=None:
+        dW = dW * mask
 
     return dW * signVal, db * signVal
 
 @partial(jit, static_argnums=[1,2, 3])
-def _enforce_constraints(W, w_mask, w_bound, is_nonnegative=True):
+def _enforce_constraints(W, block_mask, w_bound, is_nonnegative=True):
     """
     Enforces constraints that the (synaptic) efficacies/values within matrix
     `W` must adhere to.
 
     Args:
         W: synaptic weight values (at time t)
 
-        w_mask: weight mask matrix
+        block_mask: weight mask matrix
 
         w_bound: maximum value to enforce over newly computed efficacies
 
@@ -94,8 +94,8 @@ def _enforce_constraints(W, w_mask, w_bound, is_nonnegative=True):
     else:
         _W = jnp.clip(_W, -w_bound, w_bound)
 
-    if w_mask!=None:
-        _W = _W * w_mask
+    if block_mask!=None:
+        _W = _W * block_mask
 
     return _W
 
@@ -138,7 +138,7 @@ class HebbianPatchedSynapse(PatchedSynapse):
         bias_init: a kernel to drive initialization of biases for this synaptic cable
             (Default: None, which turns off/disables biases)
 
-        w_mask: weight mask matrix
+        block_mask: weight mask matrix
 
         w_bound: maximum weight to softly bound this cable's value matrix to; if
             set to 0, then no synaptic value bounding will be applied
@@ -186,10 +186,10 @@ class HebbianPatchedSynapse(PatchedSynapse):
     """
 
     def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), eta=0., weight_init=None, bias_init=None,
-                 w_mask=None, w_bound=1., is_nonnegative=False, prior=(None, 0.), sign_value=1.,
+                 block_mask=None, w_bound=1., is_nonnegative=False, prior=(None, 0.), sign_value=1.,
                  optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1.,
                  resist_scale=1., batch_size=1, **kwargs):
-        super().__init__(name, shape, n_sub_models, stride_shape, w_mask, weight_init, bias_init, resist_scale,
+        super().__init__(name, shape, n_sub_models, stride_shape, block_mask, weight_init, bias_init, resist_scale,
                          p_conn, batch_size=batch_size, **kwargs)
 
         prior_type, prior_lmbda = prior
@@ -221,7 +221,7 @@ def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), eta=0., weig
         self.postVals = jnp.zeros((self.batch_size, self.shape[1]))
         self.pre = Compartment(self.preVals)
         self.post = Compartment(self.postVals)
-        self.w_mask = w_mask
+        self.block_mask = block_mask
         self.dWeights = Compartment(jnp.zeros(self.shape))
         self.dBiases = Compartment(jnp.zeros(self.shape[1]))
 
@@ -231,23 +231,23 @@ def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), eta=0., weig
                                 if bias_init else [self.weights.value]))
 
     @staticmethod
-    def _compute_update(w_mask, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
+    def _compute_update(block_mask, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
                         post_wght, pre, post, weights):
         ## calculate synaptic update values
         dW, db = _calc_update(
-            pre, post, weights, w_mask, w_bound, is_nonnegative=is_nonnegative,
+            pre, post, weights, block_mask, w_bound, is_nonnegative=is_nonnegative,
             signVal=sign_value, prior_type=prior_type, prior_lmbda=prior_lmbda, pre_wght=pre_wght,
             post_wght=post_wght)
 
         return dW * jnp.where(0 != jnp.abs(weights), 1, 0) , db
 
     @transition(output_compartments=["opt_params", "weights", "biases", "dWeights", "dBiases"])
     @staticmethod
-    def evolve(w_mask, opt, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
+    def evolve(block_mask, opt, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
                post_wght, bias_init, pre, post, weights, biases, opt_params):
         ## calculate synaptic update values
         dWeights, dBiases = HebbianPatchedSynapse._compute_update(
-            w_mask, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda,
+            block_mask, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda,
             pre_wght, post_wght, pre, post, weights
         )
         ## conduct a step of optimization - get newly evolved synaptic weight value matrix
@@ -257,7 +257,7 @@ def evolve(w_mask, opt, w_bound, is_nonnegative, sign_value, prior_type, prior_l
             # ignore db since no biases configured
             opt_params, [weights] = opt(opt_params, [weights], [dWeights])
         ## ensure synaptic efficacies adhere to constraints
-        weights = _enforce_constraints(weights, w_mask, w_bound, is_nonnegative=is_nonnegative)
+        weights = _enforce_constraints(weights, block_mask, w_bound, is_nonnegative=is_nonnegative)
         return opt_params, weights, biases, dWeights, dBiases
 
     @transition(output_compartments=["inputs", "outputs", "pre", "post", "dWeights", "dBiases"])
@@ -313,7 +313,7 @@ def help(cls): ## component help function
             "post_wght": "Post-synaptic weighting coefficient (q_post)",
             "w_bound": "Soft synaptic bound applied to synapses post-update",
             "prior": "prior name and value for synaptic updating prior",
-            "w_mask": "weight mask matrix",
+            "block_mask": "weight mask matrix",
             "optim_type": "Choice of optimizer to adjust synaptic weights"
         }
         info = {cls.__name__: properties,
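The renamed argument threads through a two-stage pattern: compute a Hebbian update gated by the mask, then clip and re-mask the weights after the optimizer step. A condensed plain-JAX sketch of that pattern (helper names here are hypothetical; the sketch also uses the idiomatic `is not None` in place of the `!=None` comparison seen in the diff):

import jax.numpy as jnp
from jax import random

def hebbian_update(pre, post, mask=None, sign_val=1.):
    ## outer-product Hebbian update, optionally gated by a fixed mask
    dW = jnp.matmul(pre.T, post)  ## (pre_dim, post_dim)
    if mask is not None:  ## zero out updates for masked-off synapses
        dW = dW * mask
    return dW * sign_val

def enforce_constraints(W, mask=None, w_bound=1., is_nonnegative=True):
    ## clip efficacies into bounds, then re-apply the mask post-step
    _W = jnp.clip(W, 0., w_bound) if is_nonnegative else jnp.clip(W, -w_bound, w_bound)
    if mask is not None:
        _W = _W * mask
    return _W

key = random.PRNGKey(0)
k1, k2, k3 = random.split(key, 3)
pre = random.normal(k1, (8, 16))   ## batch of pre-synaptic activities
post = random.normal(k2, (8, 10))  ## batch of post-synaptic activities
W = random.uniform(k3, (16, 10))
block_mask = jnp.ones((16, 10))    ## trivially all-ones in this toy case
W = enforce_constraints(W + 0.01 * hebbian_update(pre, post, block_mask), block_mask)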

ngclearn/components/synapses/patched/patchedSynapse.py

Lines changed: 4 additions & 4 deletions
@@ -79,7 +79,7 @@ class PatchedSynapse(JaxComponent): ## base patched synaptic cable
         bias_init: a kernel to drive initialization of biases for this synaptic cable
             (Default: None, which turns off/disables biases)
 
-        w_mask: weight mask matrix
+        block_mask: weight mask matrix
 
         pre_wght: pre-synaptic weighting factor (Default: 1.)
 
@@ -92,7 +92,7 @@ class PatchedSynapse(JaxComponent): ## base patched synaptic cable
             this to < 1. will result in a sparser synaptic structure
     """
 
-    def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), w_mask=None, weight_init=None, bias_init=None,
+    def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), block_mask=None, weight_init=None, bias_init=None,
                  resist_scale=1., p_conn=1., batch_size=1, **kwargs):
         super().__init__(name, **kwargs)
 
@@ -112,7 +112,7 @@ def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), w_mask=None,
         weights = create_multi_patch_synapses(key=subkeys, shape=shape, n_sub_models=self.n_sub_models, sub_stride=self.sub_stride,
                                               weight_init=self.weight_init)
 
-        self.w_mask = jnp.where(weights!=0, 1, 0)
+        self.block_mask = jnp.where(weights!=0, 1, 0)
         self.sub_shape = (shape[0]//n_sub_models, shape[1]//n_sub_models)
 
         self.shape = weights.shape
@@ -192,7 +192,7 @@ def help(cls): ## component help function
             "weight_init": "Initialization conditions for synaptic weight (W) values",
             "bias_init": "Initialization conditions for bias/base-rate (b) values",
             "resist_scale": "Resistance level scaling factor (Rscale); applied to output of transformation",
-            "w_mask": "weight mask matrix",
+            "block_mask": "weight mask matrix",
             "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)"
         }
         info = {cls.__name__: properties,
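The key line in this file is `self.block_mask = jnp.where(weights!=0, 1, 0)`: the mask is not user-supplied structure but is derived from wherever the patch initializer placed non-zero weights. A toy illustration of that derivation (block_diag stands in for create_multi_patch_synapses, which is not reproduced here):

import jax.numpy as jnp
from jax import random
from jax.scipy.linalg import block_diag

## build two 4x5 patches and lay them out block-diagonally, as a stand-in
## for the library's multi-patch weight construction
key = random.PRNGKey(42)
k1, k2 = random.split(key)
W = block_diag(random.uniform(k1, (4, 5), minval=0.025, maxval=0.8),
               random.uniform(k2, (4, 5), minval=0.025, maxval=0.8))

## block_mask marks which synapses exist: 1 wherever a weight is non-zero
block_mask = jnp.where(W != 0, 1, 0)
print(W.shape)           # (8, 10)
print(block_mask.sum())  # 40 live synapses (two 4x5 patches)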
