@@ -133,7 +133,9 @@ def run_attention_probe(params, encodings, mask, n_heads: int, dropout: float =
     features = features + skip
     outs = jnp.matmul(features, Wy) + by
     if use_softmax: ## apply softmax output nonlinearity
-        outs = softmax(outs)
+        # NOTE: Viet: please check the softmax function; it might cause the
+        # gradient to be nan since there is a potential division by zero
+        outs = jax.nn.softmax(outs)
     return outs, features
 
 @bind(jax.jit, static_argnums=[4, 5, 6, 7])
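
The NOTE above refers to the hand-rolled `softmax` removed in this hunk. That helper is not shown here, so the naive version below is only an assumption about roughly what it computed; the sketch illustrates the failure mode and why `jax.nn.softmax`, which subtracts the per-row max before exponentiating, avoids it.

```python
import jax
import jax.numpy as jnp

def naive_softmax(x, axis=-1):
    # If every logit is very negative, exp() underflows to 0, the sum is 0,
    # and 0/0 = nan -- the division by zero flagged in the NOTE. Very large
    # logits fail the other way: exp() overflows to inf and inf/inf = nan.
    e = jnp.exp(x)
    return e / jnp.sum(e, axis=axis, keepdims=True)

logits = jnp.array([[-2000.0, -2010.0, -2005.0]])
print(naive_softmax(logits))   # [[nan nan nan]]
# jax.nn.softmax subtracts the per-row max first, so the largest term is
# exp(0) = 1 and the denominator is never 0 (or inf) for finite inputs.
print(jax.nn.softmax(logits))  # roughly [[0.9933 0.000045 0.0067]]
```
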
@@ -165,7 +167,7 @@ def eval_attention_probe(params, encodings, labels, mask, n_heads: int, dropout:
     # encodings: (B, hw, dim)
     outs, _ = run_attention_probe(params, encodings, mask, n_heads, dropout, use_LN, use_softmax)
     if use_softmax: ## Multinoulli log likelihood for 1-of-K predictions
-        L = -jnp.mean(jnp.sum(jnp.log(outs) * labels, axis=1, keepdims=True))
+        L = -jnp.mean(jnp.sum(jnp.log(outs.clip(min=1e-5)) * labels, axis=1, keepdims=True))
     else: ## MSE for real-valued outputs
         L = jnp.mean(jnp.sum(jnp.square(outs - labels), axis=1, keepdims=True))
     return L, outs #, features
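
A standalone sketch (hypothetical tensors, not from this repo) of why the clip matters for the multinoulli loss: once a softmax output saturates to exactly 0, `log(0)` is `-inf`, `0 * -inf` is NaN, and that NaN poisons the mean and every gradient behind it. Clipping before the log keeps both the loss and its gradient finite.

```python
import jax.numpy as jnp

def multinoulli_nll(probs, labels, eps=None):
    # Same form as the loss above: -mean(sum(log(p) * y)), with an optional clip.
    p = probs if eps is None else jnp.clip(probs, eps, 1.0)
    return -jnp.mean(jnp.sum(jnp.log(p) * labels, axis=1, keepdims=True))

probs = jnp.array([[1.0, 0.0, 0.0]])    # a fully saturated softmax output
labels = jnp.array([[1.0, 0.0, 0.0]])   # one-hot target
print(multinoulli_nll(probs, labels))            # nan: 0 * log(0) = 0 * -inf
print(multinoulli_nll(probs, labels, eps=1e-5))  # -0.0: finite loss
```

An alternative that avoids the clip entirely is to keep the raw logits and apply `jax.nn.log_softmax` inside the loss, which is numerically stable by construction; whether that fits here depends on how `outs` is consumed elsewhere.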