|
365 | 365 | "source": [ |
366 | 366 | "# Define model\n", |
367 | 367 | "\n", |
368 | | - "# IMPORTANT: Gemma-2 has a \"final_logits_softcap\" property, we set it to 0.0\n", |
| 368 | + "# IMPORTANT: Gemma-2 has a \"final_logits_softcap\" property. Set it to 0.0\n", |
369 | 369 | "# for better transfer results.\n", |
370 | 370 | "model_config = ml_collections.FrozenConfigDict({\n", |
371 | 371 | " \"llm\": {\"vocab_size\": 257_152, \"variant\": LLM_VARIANT, \"final_logits_softcap\": 0.0},\n", |
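For context, Gemma-2's final-logit soft-capping squashes the last-layer logits smoothly into a bounded range via `cap * tanh(logits / cap)`, and the model ships with a nonzero cap by default. A minimal sketch of the operation, assuming the common convention that a cap of 0.0 means "disabled" (the function name `softcap_logits` is illustrative, not from the notebook):

```python
import jax.numpy as jnp

def softcap_logits(logits, cap: float):
  # Gemma-2-style soft-capping: bounds logits smoothly to (-cap, cap).
  # ASSUMPTION: cap == 0.0 is treated as "capping disabled", which is
  # what setting final_logits_softcap to 0.0 in the config relies on.
  if cap <= 0.0:
    return logits
  return cap * jnp.tanh(logits / cap)
```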
|
434 | 434 | "\n", |
435 | 435 | "# Loading all params simultaneously - albeit much faster and more succinct -\n", |
436 | 436 | "# requires more RAM than the T4 colab runtimes have by default.\n", |
437 | | - "# Instead we do it param by param.\n", |
| 437 | + "# Instead, do it param by param.\n", |
438 | 438 | "params, treedef = jax.tree.flatten(params)\n", |
439 | 439 | "sharding_leaves = jax.tree.leaves(params_sharding)\n", |
440 | 440 | "trainable_leaves = jax.tree.leaves(trainable_mask)\n", |
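The flattened leaves can then be transferred one at a time. A minimal sketch of what such a loop could look like, assuming the `params`, `sharding_leaves`, `trainable_leaves`, and `treedef` values from the cell above and the notebook's existing `jax`/`jnp` imports (the float32 cast for trainable leaves is an assumption, not confirmed by this hunk):

```python
for i, (param, sharding, trainable) in enumerate(
    zip(params, sharding_leaves, trainable_leaves)):
  # Move one leaf at a time onto its target sharding, so peak host RAM
  # stays near the size of the largest single parameter, not the tree.
  param = jax.device_put(param, sharding)
  if trainable:
    # ASSUMPTION: keep trainable leaves in float32 for stable updates.
    param = param.astype(jnp.float32)
  params[i] = param

# Rebuild the original pytree structure from the transferred leaves.
params = jax.tree.unflatten(treedef, params)
```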
|
708 | 708 | " targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])\n", |
709 | 709 | "\n", |
710 | 710 | "  # Compute the loss per example, i.e. the mean of per-token pplx.\n", |
711 | | - " # Since each example has a different number of tokens we normalize it.\n", |
| 711 | + " # Since each example has a different number of tokens, normalize it.\n", |
712 | 712 | " token_pplx = jnp.sum(logp * targets, axis=-1) # sum across vocab_size.\n", |
713 | 713 | " example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1) # sum across seq_len.\n", |
714 | 714 | " example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1) # weight by num of tokens.\n", |
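Packaged as a self-contained function with illustrative shape comments, the normalization above reads as follows (the name `masked_example_loss` and the shape letters are made up for this sketch; the arithmetic mirrors the cell):

```python
import jax.numpy as jnp

def masked_example_loss(logp, targets, mask_loss):
  # logp:      [B, S, V] log-probabilities over the vocabulary.
  # targets:   [B, S, V] one-hot next-token targets.
  # mask_loss: [B, S]    1.0 where a token counts toward the loss.
  token_pplx = jnp.sum(logp * targets, axis=-1)             # [B, S]
  example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1)  # [B]
  # Divide by the token count; the clip guards against empty examples.
  return example_loss / jnp.clip(jnp.sum(mask_loss, -1), 1)
```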
|