|
105 | 105 | "Before using PaliGemma for the first time, you must request access to the model through Kaggle by completing the following steps:\n", |
106 | 106 | "\n", |
107 | 107 | "1. Log in to [Kaggle](https://www.kaggle.com), or create a new Kaggle account if you don't already have one.\n", |
108 | | - "1. Go to the [PaliGemma model card](https://www.kaggle.com/models/google/paligemma/) and click **Request Access**.\n", |
| 108 | + "1. Go to the [PaliGemma model card](https://www.kaggle.com/models/google/paligemma-2) and click **Request Access**.\n", |
109 | 109 | "1. Complete the consent form and accept the terms and conditions." |
110 | 110 | ] |
111 | 111 | }, |
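
Once access is granted, the runtime also needs Kaggle API credentials before any download will succeed. The notebook's own authentication cell is not part of this diff; the sketch below shows one common way to supply credentials via environment variables (a `KAGGLE_USERNAME`/`KAGGLE_KEY` pair created under "Settings > API" on kaggle.com). The tutorial may instead rely on Colab secrets or `kagglehub.login()`.

```python
import os

# Assumption: credentials come from a Kaggle API token; the notebook may
# authenticate differently (e.g., Colab secrets). Placeholders only.
os.environ["KAGGLE_USERNAME"] = "your_kaggle_username"
os.environ["KAGGLE_KEY"] = "your_kaggle_api_key"
```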
|
287 | 287 | "source": [ |
288 | 288 | "### Download the model checkpoint\n", |
289 | 289 | "\n", |
290 | | - "PaliGemma includes several model variations. For this tutorial, you'll use the base [JAX/FLAX PaliGemma 3B weight model](https://www.kaggle.com/models/google/paligemma/jax/paligemma-3b-pt-224).\n", |
| 290 | + "PaliGemma includes several model variations. For this tutorial, you'll use the base [JAX/FLAX PaliGemma 3B weight model](https://www.kaggle.com/models/google/paligemma-2/jax/paligemma2-3b-pt-224).\n", |
291 | 291 | "\n", |
292 | 292 | "Download the model checkpoint from Kaggle by running the following code. This process takes several minutes to complete." |
293 | 293 | ] |
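
The download cell itself falls outside this hunk. As a rough sketch of what it does, `kagglehub` can fetch the checkpoint referenced above; the model handle below is an assumption derived from the model-card URL, and the notebook may additionally pin a specific checkpoint file.

```python
import kagglehub

# Assumption: the handle mirrors the model-card URL above; the notebook may
# also pass an explicit checkpoint file name as a second argument.
ckpt_dir = kagglehub.model_download("google/paligemma-2/jax/paligemma2-3b-pt-224")
print("Checkpoint downloaded to:", ckpt_dir)
```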
|
365 | 365 | "source": [ |
366 | 366 | "# Define model\n", |
367 | 367 | "\n", |
368 | | - "# IMPORTANT: Gemma-2 has a \"final_logits_softcap\" property, we set it to 0.0\n", |
| 368 | + "# IMPORTANT: Gemma-2 has a \"final_logits_softcap\" property. Set it to 0.0\n", |
369 | 369 | "# for better transfer results.\n", |
370 | 370 | "model_config = ml_collections.FrozenConfigDict({\n", |
371 | 371 | " \"llm\": {\"vocab_size\": 257_152, \"variant\": LLM_VARIANT, \"final_logits_softcap\": 0.0},\n", |
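
For context on why 0.0 disables the behavior: Gemma 2's final-logit soft cap squashes logits through a tanh, and a cap of zero means the transform is skipped. The snippet below only illustrates that transform; it is not code from the notebook.

```python
import jax.numpy as jnp

def softcap(logits, cap):
    # Gemma-2-style soft capping: bound logits to (-cap, cap) via tanh.
    # A cap of 0.0 means "no capping" (illustration only).
    return logits if cap == 0.0 else cap * jnp.tanh(logits / cap)
```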
|
434 | 434 | "\n", |
435 | 435 | "# Loading all params in simultaneous - albeit much faster and more succinct -\n", |
436 | 436 | "# requires more RAM than the T4 colab runtimes have by default.\n", |
437 | | - "# Instead we do it param by param.\n", |
| 437 | + "# Instead, do it param by param.\n", |
438 | 438 | "params, treedef = jax.tree.flatten(params)\n", |
439 | 439 | "sharding_leaves = jax.tree.leaves(params_sharding)\n", |
440 | 440 | "trainable_leaves = jax.tree.leaves(trainable_mask)\n", |
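
The rest of that loop is not shown in the hunk. A minimal sketch of the per-parameter approach, assuming plain `jax.device_put` rather than whatever resharding helper the notebook actually calls:

```python
# Sketch only: shard one leaf at a time so peak host RAM stays low, casting
# trainable leaves to float32 (assumption) and leaving frozen ones as loaded.
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
    leaf = jax.device_put(params[idx], sharding)
    if trainable:
        leaf = leaf.astype(jnp.float32)
    params[idx] = leaf
params = jax.tree.unflatten(treedef, params)
```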
|
708 | 708 | " targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])\n", |
709 | 709 | "\n", |
710 | 710 | " # Compute the loss per example, i.e. the mean of per-token pplx.\n",
711 | | - " # Since each example has a different number of tokens we normalize it.\n", |
| 711 | + " # Since each example has a different number of tokens, normalize it.\n", |
712 | 712 | " token_pplx = jnp.sum(logp * targets, axis=-1) # sum across vocab_size.\n", |
713 | 713 | " example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1) # sum across seq_len.\n", |
714 | 714 | " example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1) # weight by num of tokens.\n", |
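
To see what the mask normalization buys, here is a tiny standalone check (not from the notebook): two examples with different numbers of real tokens end up with comparable per-example losses.

```python
import jax.numpy as jnp

# Example 1 has 4 real tokens with NLL 1.0 each; example 2 has 2 real tokens
# (NLL 2.0 each) plus 2 padding positions that the mask zeroes out.
token_nll = jnp.array([[1.0, 1.0, 1.0, 1.0],
                       [2.0, 2.0, 0.0, 0.0]])
mask = jnp.array([[1.0, 1.0, 1.0, 1.0],
                  [1.0, 1.0, 0.0, 0.0]])
per_example = jnp.sum(token_nll * mask, -1) / jnp.clip(jnp.sum(mask, -1), 1)
print(per_example)  # [1. 2.] -- mean NLL over each example's real tokens.
```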
|