I'm trying to implement the Transformer architecture from scratch. I've found three issues while training.
I'll attach the code for reference:

```python
import equinox as eqx
import jax
import jax.numpy as jnp


class SelfAttention(eqx.Module):
    def __call__(self, query, key, value, mask):
        # Scaled dot-product attention over batched (batch, seq, d_model) inputs.
        scaled_dot_prod = query @ jnp.transpose(key, (0, 2, 1)) / jnp.sqrt(query.shape[-1])
        # Additive mask: 0 where attention is allowed, -inf where it is blocked.
        scaled_dot_prod = mask + scaled_dot_prod
        return jax.nn.softmax(scaled_dot_prod) @ value
```
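For context, the module can be exercised in isolation like this (the shapes are my assumption: `(batch, seq, d_model)` inputs with an additive mask broadcastable to the `(batch, seq, seq)` attention scores):

```python
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
q = k = v = jax.random.normal(rng, (2, 5, 8))   # (batch, seq, d_model)
mask = jnp.zeros((2, 1, 5))                     # 0 = attend, -inf = blocked
out = SelfAttention()(q, k, v, mask)            # -> (2, 5, 8)
```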
```python
def create_mask(arr):
    # Additive mask: -inf at padding positions (token id 0), 0 elsewhere.
    return jnp.where(arr == 0, -jnp.inf, 0)
```
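One thing worth illustrating: if a mask row is all `-inf` (e.g., a sequence that is entirely padding), `softmax` turns the whole row into NaN, because internally it computes `x - max(x)`, and `-inf - (-inf)` is NaN. A minimal repro of that failure mode:

```python
import jax
import jax.numpy as jnp

row = jnp.full((4,), -jnp.inf)   # a fully-masked attention row
print(jax.nn.softmax(row))       # [nan nan nan nan]
```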
```python
def loss(model, X, y, X_mask, y_mask, labels):
    y_pred = jnp.log(predict(model, X, y, X_mask, y_mask))
    # Zero out padding positions (label id 0) and gather per-token log-probs.
    y_pred = jnp.where(labels == 0, 0, jnp.take(y_pred, labels, axis=-1))
    count = jnp.count_nonzero(y_pred)
    # Mean negative log-likelihood over non-padding tokens.
    return -jnp.sum(y_pred) / count
```
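As an aside, `jnp.log(softmax(...))` can underflow to `-inf` on its own; a more stable formulation uses `jax.nn.log_softmax` and gathers per-token log-probabilities with `jnp.take_along_axis`. A minimal sketch of that alternative (my rewrite, assuming access to pre-softmax logits of shape `(batch, seq, vocab)`, whereas `predict` above appears to return probabilities):

```python
import jax
import jax.numpy as jnp

def masked_cross_entropy(logits, labels):
    """logits: (batch, seq, vocab); labels: (batch, seq), with 0 = padding."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)  # numerically stable log-softmax
    # Pick out the log-probability of each target token.
    token_lp = jnp.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    mask = labels != 0                               # ignore padding positions
    return -jnp.sum(token_lp * mask) / jnp.maximum(mask.sum(), 1)
```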
```python
with jax.disable_jit():
    for e in range(EPOCHS):
        total_loss = 0
        num_batches = 0
        total_tokens = 0
        for i, (Xbt, ybt, labelbt) in enumerate(dataloader(Xtr, ytr, SEQ_LEN)):
            # Count non-padding target tokens in this batch.
            total_tokens += len([token for seq in labelbt for token in list(filter(lambda x: x != 0, seq))])
            Xbt, ybt, labelbt = [jnp.array(x) for x in (Xbt, ybt, labelbt)]
            Xmask, ymask = [create_mask(x) for x in (Xbt, ybt)]
            model, opt_state, batch_loss = step(model, opt_state, Xbt, ybt, Xmask, ymask, labelbt)
            total_loss += batch_loss
            num_batches += 1
            if num_batches % 20 == 0:
                print(f"Batches trained: {num_batches} | Avg. batch loss: {total_loss/num_batches}")
        epoch_loss = total_loss / num_batches
        print(f"Epoch {e} | loss: {epoch_loss}")
```

The error:

```
def _softmax_deprecated(
    478     x: ArrayLike,
    479     axis: Optional[Union[int, tuple[int, ...]]] = -1,
    480     where: Optional[ArrayLike] = None,
    481     initial: Optional[ArrayLike] = None) -> Array:
    482   x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
--> 483   unnormalized = jnp.exp(x - lax.stop_gradient(x_max))
    484   result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
    485   if where is not None:

FloatingPointError: invalid value (nan) encountered in jit(sub)
```
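(For context: JAX only raises a `FloatingPointError` like this when NaN debugging is turned on; I'm assuming the run was configured along these lines, which together with `jax.disable_jit()` makes JAX stop at the first primitive that produces a NaN:)

```python
import jax

# With this flag set, JAX raises FloatingPointError at the first op that
# produces a NaN instead of silently propagating it.
jax.config.update("jax_debug_nans", True)
```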
Answered by svarunid, Dec 19, 2023
Hello - I think I answered your question on StackOverflow this morning here: https://stackoverflow.com/a/77680900/2937831. One question: could you say more about what leads you to believe that…
I resolved my issue. The error was occurring due to a bug in data preprocessing, so the forward pass produced invalid (NaN) values.
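For anyone hitting the same error: a cheap guard is to validate batches before training so that no sequence is entirely padding, since an all-padding sequence fully masks a softmax row (see the NaN repro above). A minimal sketch, assuming 0 is the padding id:

```python
import numpy as np

def check_batch(X, name="X"):
    # Every sequence must contain at least one non-padding token; otherwise
    # its attention-mask row is all -inf and softmax yields NaN.
    empty = ~np.any(np.asarray(X) != 0, axis=-1)
    assert not empty.any(), f"{name} has {empty.sum()} all-padding sequence(s)"
```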