Skip to content

Commit 2a71b7f

Browse files
author
Alexander Ororbia
committed
minor tweak to attentive probe code comments
1 parent f402d98 commit 2a71b7f

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

ngclearn/utils/analysis/attentive_probe.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -232,24 +232,24 @@ def __init__(
232232
bvs = random.normal(subkeys[13], (1, learnable_query_dim)) * sigma
233233
Wouts = random.normal(subkeys[14], (learnable_query_dim, learnable_query_dim)) * sigma
234234
bouts = random.normal(subkeys[15], (1, learnable_query_dim)) * sigma
235-
Wlnattn_mu = jnp.zeros((1, learnable_query_dim))
236-
Wlnattn_scale = jnp.ones((1, learnable_query_dim))
235+
Wlnattn_mu = jnp.zeros((1, learnable_query_dim)) ## LN parameter (applied to output of attention)
236+
Wlnattn_scale = jnp.ones((1, learnable_query_dim)) ## LN parameter (applied to output of attention)
237237
self_attn_params = (Wqs, bqs, Wks, bks, Wvs, bvs, Wouts, bouts, Wlnattn_mu, Wlnattn_scale)
238238
learnable_query = jnp.zeros((batch_size, 1, learnable_query_dim)) # (B, T, D)
239239
self.mask = np.zeros((batch_size, target_seq_length, source_seq_length)).astype(bool) ## mask tensor
240240
## MLP parameters
241241
Whid1 = random.normal(subkeys[16], (learnable_query_dim, learnable_query_dim)) * sigma
242242
bhid1 = random.normal(subkeys[17], (1, learnable_query_dim)) * sigma
243-
Wln_mu1 = jnp.zeros((1, learnable_query_dim))
244-
Wln_scale1 = jnp.ones((1, learnable_query_dim))
243+
Wln_mu1 = jnp.zeros((1, learnable_query_dim)) ## LN parameter
244+
Wln_scale1 = jnp.ones((1, learnable_query_dim)) ## LN parameter
245245
Whid2 = random.normal(subkeys[18], (learnable_query_dim, learnable_query_dim * 4)) * sigma
246246
bhid2 = random.normal(subkeys[19], (1, learnable_query_dim * 4)) * sigma
247-
Wln_mu2 = jnp.zeros((1, learnable_query_dim))
248-
Wln_scale2 = jnp.ones((1, learnable_query_dim))
247+
Wln_mu2 = jnp.zeros((1, learnable_query_dim)) ## LN parameter
248+
Wln_scale2 = jnp.ones((1, learnable_query_dim)) ## LN parameter
249249
Whid3 = random.normal(subkeys[20], (learnable_query_dim * 4, learnable_query_dim)) * sigma
250250
bhid3 = random.normal(subkeys[21], (1, learnable_query_dim)) * sigma
251-
Wln_mu3 = jnp.zeros((1, learnable_query_dim * 4))
252-
Wln_scale3 = jnp.ones((1, learnable_query_dim * 4))
251+
Wln_mu3 = jnp.zeros((1, learnable_query_dim * 4)) ## LN parameter
252+
Wln_scale3 = jnp.ones((1, learnable_query_dim * 4)) ## LN parameter
253253
Wy = random.normal(subkeys[22], (learnable_query_dim, out_dim)) * sigma
254254
by = random.normal(subkeys[23], (1, out_dim)) * sigma
255255
mlp_params = (Whid1, bhid1, Wln_mu1, Wln_scale1, Whid2, bhid2, Wln_mu2, Wln_scale2, Whid3, bhid3, Wln_mu3, Wln_scale3, Wy, by)

0 commit comments

Comments
 (0)