
Commit dc8c127

update code to fix nan bug

1 parent: 8682954

1 file changed: +4 −2 lines

ngclearn/utils/analysis/attentive_probe.py

Lines changed: 4 additions & 2 deletions
@@ -133,7 +133,9 @@ def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float =
     features = features + skip
     outs = jnp.matmul(features, Wy) + by
     if use_softmax: ## apply softmax output nonlinearity
-        outs = softmax(outs)
+        # NOTE: Viet: please check the softmax function, it might potentially
+        # cause the gradient to be nan since there is a potential division by zero
+        outs = jax.nn.softmax(outs)
     return outs, features

 @bind(jax.jit, static_argnums=[4, 5, 6, 7])
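For context (not part of the commit): the NOTE above worries about a division by zero inside softmax. A hand-rolled softmax that exponentiates raw logits can overflow and yield nan, whereas jax.nn.softmax subtracts the per-row maximum before exponentiating, so its denominator stays at least 1 for finite inputs. A minimal sketch with made-up logit values, assuming JAX's default float32; the helper naive_softmax is hypothetical and only for illustration:

import jax
import jax.numpy as jnp

def naive_softmax(x):
    # exponentiates raw logits; exp() overflows to inf for large inputs
    e = jnp.exp(x)
    return e / jnp.sum(e, axis=-1, keepdims=True)

x = jnp.array([[100.0, 0.0, -100.0]])  # illustrative logits that overflow exp() in float32

print(naive_softmax(x))   # [[nan, 0., 0.]]  -- inf / inf
print(jax.nn.softmax(x))  # [[1., ~3.8e-44, 0.]] -- max-subtracted, stays finite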
@@ -165,7 +167,7 @@ def eval_attention_probe(params, encodings, labels, mask, n_heads: int, dropout:
     # encodings: (B, hw, dim)
     outs, _ = run_attention_probe(params, encodings, mask, n_heads, dropout, use_LN, use_softmax)
     if use_softmax: ## Multinoulli log likelihood for 1-of-K predictions
-        L = -jnp.mean(jnp.sum(jnp.log(outs) * labels, axis=1, keepdims=True))
+        L = -jnp.mean(jnp.sum(jnp.log(outs.clip(min=1e-5)) * labels, axis=1, keepdims=True))
     else: ## MSE for real-valued outputs
         L = jnp.mean(jnp.sum(jnp.square(outs - labels), axis=1, keepdims=True))
     return L, outs #, features
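Similarly (again not part of the commit), the clip guards the Multinoulli log likelihood against log(0) when a predicted probability underflows to exactly zero: the unclipped loss becomes inf and the backward pass produces non-finite gradients. A minimal sketch with illustrative values; the helper nll is hypothetical, it only mirrors the changed line, and it assumes a JAX version where Array.clip accepts the min keyword (as the committed code does):

import jax
import jax.numpy as jnp

labels = jnp.array([[0.0, 1.0, 0.0]])   # one-hot target
outs = jnp.array([[0.9, 0.0, 0.1]])     # the target class probability underflowed to 0

def nll(p, clip):
    # negative Multinoulli log likelihood, with an optional clip before the log
    p = p.clip(min=1e-5) if clip else p
    return -jnp.mean(jnp.sum(jnp.log(p) * labels, axis=1, keepdims=True))

print(nll(outs, clip=False))            # inf, since log(0) = -inf
print(nll(outs, clip=True))             # ~11.51 = -log(1e-5), finite
print(jax.grad(nll)(outs, clip=False))  # target entry's grad is -inf; chained through
                                        # the upstream softmax it turns into nan
print(jax.grad(nll)(outs, clip=True))   # finite (the clip zeroes the grad at the
                                        # clipped entry)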
