site/en/gemma/docs/paligemma/fine-tuning-paligemma.ipynb (10 changes: 5 additions & 5 deletions)
@@ -105,7 +105,7 @@
"Before using PaliGemma for the first time, you must request access to the model through Kaggle by completing the following steps:\n",
"\n",
"1. Log in to [Kaggle](https://www.kaggle.com), or create a new Kaggle account if you don't already have one.\n",
"1. Go to the [PaliGemma model card](https://www.kaggle.com/models/google/paligemma/) and click **Request Access**.\n",
"1. Go to the [PaliGemma model card](https://www.kaggle.com/models/google/paligemma-2) and click **Request Access**.\n",
"1. Complete the consent form and accept the terms and conditions."
]
},
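Once access is granted, the notebook still has to authenticate to Kaggle before it can pull the weights. A minimal sketch of that step, assuming a Colab runtime with your credentials stored as Colab secrets (the secret names here are illustrative; adapt to your environment):

```python
# Minimal sketch: expose Kaggle credentials to the runtime before downloading.
# Assumes Colab secrets named KAGGLE_USERNAME and KAGGLE_KEY (illustrative names).
import os
from google.colab import userdata

os.environ["KAGGLE_USERNAME"] = userdata.get("KAGGLE_USERNAME")
os.environ["KAGGLE_KEY"] = userdata.get("KAGGLE_KEY")
```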
@@ -287,7 +287,7 @@
"source": [
"### Download the model checkpoint\n",
"\n",
"PaliGemma includes several model variations. For this tutorial, you'll use the base [JAX/FLAX PaliGemma 3B weight model](https://www.kaggle.com/models/google/paligemma/jax/paligemma-3b-pt-224).\n",
"PaliGemma includes several model variations. For this tutorial, you'll use the base [JAX/FLAX PaliGemma 3B weight model](https://www.kaggle.com/models/google/paligemma-2/jax/paligemma2-3b-pt-224).\n",
"\n",
"Download the model checkpoint from Kaggle by running the following code. This process takes several minutes to complete."
]
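The download cell itself is not part of this diff. As a rough sketch of what it could look like with the `kagglehub` package, using a model handle inferred from the updated Kaggle URL:

```python
# Rough sketch (not the notebook's actual cell): download the PaliGemma 2
# JAX checkpoint from Kaggle. The handle mirrors the updated model-card URL.
import kagglehub

ckpt_dir = kagglehub.model_download("google/paligemma-2/jax/paligemma2-3b-pt-224")
print(f"Checkpoint files downloaded to: {ckpt_dir}")
```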
@@ -365,7 +365,7 @@
"source": [
"# Define model\n",
"\n",
"# IMPORTANT: Gemma-2 has a \"final_logits_softcap\" property, we set it to 0.0\n",
"# IMPORTANT: Gemma-2 has a \"final_logits_softcap\" property. Set it to 0.0\n",
"# for better transfer results.\n",
"model_config = ml_collections.FrozenConfigDict({\n",
" \"llm\": {\"vocab_size\": 257_152, \"variant\": LLM_VARIANT, \"final_logits_softcap\": 0.0},\n",
@@ -434,7 +434,7 @@
"\n",
"# Loading all params in simultaneous - albeit much faster and more succinct -\n",
"# requires more RAM than the T4 colab runtimes have by default.\n",
"# Instead we do it param by param.\n",
"# Instead, do it param by param.\n",
"params, treedef = jax.tree.flatten(params)\n",
"sharding_leaves = jax.tree.leaves(params_sharding)\n",
"trainable_leaves = jax.tree.leaves(trainable_mask)\n",
@@ -708,7 +708,7 @@
" targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])\n",
"\n",
" # Compute the loss per example. i.e. the mean of per token pplx.\n",
" # Since each example has a different number of tokens we normalize it.\n",
" # Since each example has a different number of tokens, normalize it.\n",
" token_pplx = jnp.sum(logp * targets, axis=-1) # sum across vocab_size.\n",
" example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1) # sum across seq_len.\n",
" example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1) # weight by num of tokens.\n",