Commit 87e7cde

Updating PaliGemma notebooks (#543)
* Updating PaliGemma notebooks
* Notebook format updates.
1 parent ce40377 commit 87e7cde

File tree

2 files changed: +151 -146 lines changed
site/en/gemma/docs/paligemma/fine-tuning-paligemma.ipynb

Lines changed: 29 additions & 20 deletions
@@ -6,8 +6,8 @@
 "id": "G3MMAcssHTML"
 },
 "source": [
-"<link rel=\"stylesheet\" href=\"/site-assets/css/gemma.css\">\n",
-"<link rel=\"stylesheet\" href=\"https://fonts.googleapis.com/css2?family=Google+Symbols:opsz,wght,FILL,[email protected],100..700,0..1,-50..200\" />"
+"<link rel=\"stylesheet\" href=\"/site-assets/css/style.css\">\n",
+"<link rel=\"stylesheet\" href=\"/site-assets/css/gemma.css\">\n"
 ]
 },
 {
@@ -59,15 +59,8 @@
 "<td>\n",
 "<a target=\"_blank\" href=\"https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/paligemma/fine-tuning-paligemma.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
 "</td>\n",
-"</table>\n"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "wR53lePHuiP-"
-},
-"source": [
+"</table>\n",
+"\n",
 "This notebook shows how to fine-tune [PaliGemma](https://ai.google.dev/gemma/docs/paligemma) on a vision-language task with [JAX](https://jax.readthedocs.io/en/latest/index.html). *Fine-tuning* is a process that can improve your model's performance on specific tasks or help the model adhere to specific output requirements when instructions aren't sufficient and you have a set of examples that demonstrate the outputs you want. Gemma-based models like PaliGemma require fine-tuning to produce expected results.\n",
 "\n",
 "### What's in this notebook\n",
@@ -172,7 +165,11 @@
 "# vars as appropriate or make your credentials available in ~/.kaggle/kaggle.json\n",
 "\n",
 "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
-"os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')"
+"os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')\n",
+"\n",
+"# The T4 runtime is tight on memory to finetune this model. Preallocate\n",
+"# all memory ahead of time to avoid out-of-memory due to fragmentation.\n",
+"os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.0\""
 ]
 },
 {
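The new XLA_PYTHON_CLIENT_MEM_FRACTION line only takes effect if it is set before JAX initializes its GPU backend. A minimal sketch of how that setting is typically applied, written as plain Python rather than notebook JSON (the device printout is illustrative):

    # Set the allocator behaviour before any JAX computation touches the GPU;
    # XLA reads this env var when the backend is first initialized.
    import os
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"  # preallocate (almost) all GPU memory

    import jax  # imported after the env var so the first device query sees it

    # On a Colab T4 runtime this should report a single CUDA device.
    print(jax.devices())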
@@ -265,7 +262,7 @@
 "tf.config.set_visible_devices([], \"GPU\")\n",
 "tf.config.set_visible_devices([], \"TPU\")\n",
 "\n",
-"backend = jax.lib.xla_bridge.get_backend()\n",
+"backend = jax.extend.backend.get_backend()\n",
 "print(f\"JAX version: {jax.__version__}\")\n",
 "print(f\"JAX platform: {backend.platform}\")\n",
 "print(f\"JAX devices: {jax.device_count()}\")"
@@ -292,7 +289,7 @@
 "\n",
 "PaliGemma includes several model variations. For this tutorial, you'll use the base [JAX/FLAX PaliGemma 3B weight model](https://www.kaggle.com/models/google/paligemma/jax/paligemma-3b-pt-224).\n",
 "\n",
-"Download the `float16` version of the model checkpoint from Kaggle by running the following code. This process takes several minutes to complete."
+"Download the model checkpoint from Kaggle by running the following code. This process takes several minutes to complete."
 ]
 },
 {
@@ -306,12 +303,19 @@
 "import os\n",
 "import kagglehub\n",
 "\n",
-"MODEL_PATH = \"./pt_224_128.params.f16.npz\"\n",
+"# Use these for PaliGemma-2 3B 224px²\n",
+"LLM_VARIANT = \"gemma2_2b\"\n",
+"MODEL_PATH = \"./paligemma2-3b-pt-224.b16.npz\"\n",
+"KAGGLE_HANDLE = \"google/paligemma-2/jax/paligemma2-3b-pt-224\" # Path to fetch from Kaggle.\n",
+"\n",
+"# Use these for PaliGemma 1:\n",
+"# LLM_VARIANT = \"gemma_2b\"\n",
+"# MODEL_PATH = \"./paligemma-3b-pt-224.f16.npz\"\n",
+"# KAGGLE_HANDLE = \"google/paligemma/jax/paligemma-3b-pt-224\"\n",
+"\n",
 "if not os.path.exists(MODEL_PATH):\n",
 " print(\"Downloading the checkpoint from Kaggle, this could take a few minutes....\")\n",
-" # Note: kaggle archive contains the same checkpoint in multiple formats.\n",
-" # Download only the float16 model.\n",
-" MODEL_PATH = kagglehub.model_download('google/paligemma/jax/paligemma-3b-pt-224', 'paligemma-3b-pt-224.f16.npz')\n",
+" MODEL_PATH = kagglehub.model_download(KAGGLE_HANDLE, MODEL_PATH)\n",
 " print(f\"Model path: {MODEL_PATH}\")\n",
 "\n",
 "TOKENIZER_PATH = \"./paligemma_tokenizer.model\"\n",
@@ -360,8 +364,11 @@
 "outputs": [],
 "source": [
 "# Define model\n",
+"\n",
+"# IMPORTANT: Gemma-2 has a \"final_logits_softcap\" property, we set it to 0.0\n",
+"# for better transfer results.\n",
 "model_config = ml_collections.FrozenConfigDict({\n",
-" \"llm\": {\"vocab_size\": 257_152},\n",
+" \"llm\": {\"vocab_size\": 257_152, \"variant\": LLM_VARIANT, \"final_logits_softcap\": 0.0},\n",
 " \"img\": {\"variant\": \"So400m/14\", \"pool_type\": \"none\", \"scan\": True, \"dtype_mm\": \"float16\"}\n",
 "})\n",
 "model = paligemma.Model(**model_config)\n",
@@ -420,7 +427,9 @@
 "\n",
 "@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))\n",
 "def maybe_cast_to_f32(params, trainable):\n",
-" return jax.tree.map(lambda p, m: p.astype(jnp.float32) if m else p,\n",
+" # Cast others to float16, since some GPUs don't support bf16.\n",
+" return jax.tree.map(lambda p, m: p.astype(jnp.float32)\n",
+" if m else p.astype(jnp.float16),\n",
 " params, trainable)\n",
 "\n",
 "# Loading all params in simultaneous - albeit much faster and more succinct -\n",
