
Commit e94e728

Update links to PaliGemma 2 in fine-tuning docs (#545)
* Update links to PaliGemma 2

* Update fine-tuning-paligemma.ipynb

Removing second person ("we" language) that is causing lint error and failure to sync with DevSite.

---------

Co-authored-by: Omar Sanseviero <[email protected]>
Co-authored-by: Joe Fernandez <[email protected]>
1 parent 87e7cde commit e94e728

File tree

1 file changed: +5, -5 lines changed


site/en/gemma/docs/paligemma/fine-tuning-paligemma.ipynb

Lines changed: 5 additions & 5 deletions
@@ -105,7 +105,7 @@
 "Before using PaliGemma for the first time, you must request access to the model through Kaggle by completing the following steps:\n",
 "\n",
 "1. Log in to [Kaggle](https://www.kaggle.com), or create a new Kaggle account if you don't already have one.\n",
-"1. Go to the [PaliGemma model card](https://www.kaggle.com/models/google/paligemma/) and click **Request Access**.\n",
+"1. Go to the [PaliGemma model card](https://www.kaggle.com/models/google/paligemma-2) and click **Request Access**.\n",
 "1. Complete the consent form and accept the terms and conditions."
 ]
 },
@@ -287,7 +287,7 @@
 "source": [
 "### Download the model checkpoint\n",
 "\n",
-"PaliGemma includes several model variations. For this tutorial, you'll use the base [JAX/FLAX PaliGemma 3B weight model](https://www.kaggle.com/models/google/paligemma/jax/paligemma-3b-pt-224).\n",
+"PaliGemma includes several model variations. For this tutorial, you'll use the base [JAX/FLAX PaliGemma 3B weight model](https://www.kaggle.com/models/google/paligemma-2/jax/paligemma2-3b-pt-224).\n",
 "\n",
 "Download the model checkpoint from Kaggle by running the following code. This process takes several minutes to complete."
 ]
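For readers following along outside the notebook, a checkpoint download like the one this hunk describes typically goes through the kagglehub client. A minimal sketch under that assumption (the exact call in the notebook may differ; the handle here is taken from the model URL in the hunk above):

import kagglehub

# Download the PaliGemma 2 JAX checkpoint from Kaggle.
# The handle mirrors the model URL referenced in the diff above;
# treat it as an assumption, not the notebook's literal code.
MODEL_PATH = kagglehub.model_download("google/paligemma-2/jax/paligemma2-3b-pt-224")
print("Checkpoint files downloaded to:", MODEL_PATH)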
@@ -365,7 +365,7 @@
 "source": [
 "# Define model\n",
 "\n",
-"# IMPORTANT: Gemma-2 has a \"final_logits_softcap\" property, we set it to 0.0\n",
+"# IMPORTANT: Gemma-2 has a \"final_logits_softcap\" property. Set it to 0.0\n",
 "# for better transfer results.\n",
 "model_config = ml_collections.FrozenConfigDict({\n",
 " \"llm\": {\"vocab_size\": 257_152, \"variant\": LLM_VARIANT, \"final_logits_softcap\": 0.0},\n",
@@ -434,7 +434,7 @@
 "\n",
 "# Loading all params in simultaneous - albeit much faster and more succinct -\n",
 "# requires more RAM than the T4 colab runtimes have by default.\n",
-"# Instead we do it param by param.\n",
+"# Instead, do it param by param.\n",
 "params, treedef = jax.tree.flatten(params)\n",
 "sharding_leaves = jax.tree.leaves(params_sharding)\n",
 "trainable_leaves = jax.tree.leaves(trainable_mask)\n",
@@ -708,7 +708,7 @@
 " targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])\n",
 "\n",
 " # Compute the loss per example. i.e. the mean of per token pplx.\n",
-" # Since each example has a different number of tokens we normalize it.\n",
+" # Since each example has a different number of tokens, normalize it.\n",
 " token_pplx = jnp.sum(logp * targets, axis=-1) # sum across vocab_size.\n",
 " example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1) # sum across seq_len.\n",
 " example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1) # weight by num of tokens.\n",
