
Commit eabab9f

Update fine-tuning-paligemma.ipynb
Removing second-person ("we") language that is causing a lint error and a failure to sync with DevSite.
1 parent 4b8c1ae commit eabab9f

File tree

1 file changed: +3 -3 lines changed

site/en/gemma/docs/paligemma/fine-tuning-paligemma.ipynb

Lines changed: 3 additions & 3 deletions
@@ -365,7 +365,7 @@
 "source": [
 "# Define model\n",
 "\n",
-"# IMPORTANT: Gemma-2 has a \"final_logits_softcap\" property, we set it to 0.0\n",
+"# IMPORTANT: Gemma-2 has a \"final_logits_softcap\" property. Set it to 0.0\n",
 "# for better transfer results.\n",
 "model_config = ml_collections.FrozenConfigDict({\n",
 " \"llm\": {\"vocab_size\": 257_152, \"variant\": LLM_VARIANT, \"final_logits_softcap\": 0.0},\n",
@@ -434,7 +434,7 @@
 "\n",
 "# Loading all params in simultaneous - albeit much faster and more succinct -\n",
 "# requires more RAM than the T4 colab runtimes have by default.\n",
-"# Instead we do it param by param.\n",
+"# Instead, do it param by param.\n",
 "params, treedef = jax.tree.flatten(params)\n",
 "sharding_leaves = jax.tree.leaves(params_sharding)\n",
 "trainable_leaves = jax.tree.leaves(trainable_mask)\n",
@@ -708,7 +708,7 @@
 " targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])\n",
 "\n",
 " # Compute the loss per example. i.e. the mean of per token pplx.\n",
-" # Since each example has a different number of tokens we normalize it.\n",
+" # Since each example has a different number of tokens, normalize it.\n",
 " token_pplx = jnp.sum(logp * targets, axis=-1) # sum across vocab_size.\n",
 " example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1) # sum across seq_len.\n",
 " example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1) # weight by num of tokens.\n",

0 commit comments
