Skip to content

Commit 0716c28

Browse files
committed
Updating PaliGemma notebooks
1 parent ce40377 commit 0716c28

File tree

2 files changed

+215
-167
lines changed

2 files changed

+215
-167
lines changed

site/en/gemma/docs/paligemma/fine-tuning-paligemma.ipynb

Lines changed: 62 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
"id": "G3MMAcssHTML"
77
},
88
"source": [
9-
"<link rel=\"stylesheet\" href=\"/site-assets/css/gemma.css\">\n",
10-
"<link rel=\"stylesheet\" href=\"https://fonts.googleapis.com/css2?family=Google+Symbols:opsz,wght,FILL,[email protected],100..700,0..1,-50..200\" />"
9+
"<link rel=\"stylesheet\" href=\"/site-assets/css/style.css\">\n",
10+
"<link rel=\"stylesheet\" href=\"/site-assets/css/gemma.css\">\n"
1111
]
1212
},
1313
{
@@ -59,15 +59,8 @@
5959
"<td>\n",
6060
"<a target=\"_blank\" href=\"https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/paligemma/fine-tuning-paligemma.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
6161
"</td>\n",
62-
"</table>\n"
63-
]
64-
},
65-
{
66-
"cell_type": "markdown",
67-
"metadata": {
68-
"id": "wR53lePHuiP-"
69-
},
70-
"source": [
62+
"</table>\n",
63+
"\n",
7164
"This notebook shows how to fine-tune [PaliGemma](https://ai.google.dev/gemma/docs/paligemma) on a vision-language task with [JAX](https://jax.readthedocs.io/en/latest/index.html). *Fine-tuning* is a process that can improve your model's performance on specific tasks or help the model adhere to specific output requirements when instructions aren't sufficient and you have a set of examples that demonstrate the outputs you want. Gemma-based models like PaliGemma require fine-tuning to produce expected results.\n",
7265
"\n",
7366
"### What's in this notebook\n",
@@ -128,7 +121,8 @@
128121
"\n",
129122
"To generate a Kaggle API key, open your [**Settings** page in Kaggle](https://www.kaggle.com/settings) and click **Create New Token**. This triggers the download of a `kaggle.json` file containing your API credentials.\n",
130123
"\n",
131-
"Then, in Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`.\n"
124+
"Then, in Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`.\n",
125+
"\n"
132126
]
133127
},
134128
{
@@ -172,7 +166,11 @@
172166
"# vars as appropriate or make your credentials available in ~/.kaggle/kaggle.json\n",
173167
"\n",
174168
"os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
175-
"os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')"
169+
"os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')\n",
170+
"\n",
171+
"# The T4 runtime is tight on memory to finetune this model. Preallocate\n",
172+
"# all memory ahead of time to avoid out-of-memory due to fragmentation.\n",
173+
"os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.0\""
176174
]
177175
},
178176
{
@@ -265,7 +263,7 @@
265263
"tf.config.set_visible_devices([], \"GPU\")\n",
266264
"tf.config.set_visible_devices([], \"TPU\")\n",
267265
"\n",
268-
"backend = jax.lib.xla_bridge.get_backend()\n",
266+
"backend = jax.extend.backend.get_backend()\n",
269267
"print(f\"JAX version: {jax.__version__}\")\n",
270268
"print(f\"JAX platform: {backend.platform}\")\n",
271269
"print(f\"JAX devices: {jax.device_count()}\")"
@@ -292,7 +290,7 @@
292290
"\n",
293291
"PaliGemma includes several model variations. For this tutorial, you'll use the base [JAX/FLAX PaliGemma 3B weight model](https://www.kaggle.com/models/google/paligemma/jax/paligemma-3b-pt-224).\n",
294292
"\n",
295-
"Download the `float16` version of the model checkpoint from Kaggle by running the following code. This process takes several minutes to complete."
293+
"Download the model checkpoint from Kaggle by running the following code. This process takes several minutes to complete."
296294
]
297295
},
298296
{
@@ -306,12 +304,19 @@
306304
"import os\n",
307305
"import kagglehub\n",
308306
"\n",
309-
"MODEL_PATH = \"./pt_224_128.params.f16.npz\"\n",
307+
"# Use these for PaliGemma-2 3B 224px²\n",
308+
"LLM_VARIANT = \"gemma2_2b\"\n",
309+
"MODEL_PATH = \"./paligemma2-3b-pt-224.b16.npz\"\n",
310+
"KAGGLE_HANDLE = \"google/paligemma-2/jax/paligemma2-3b-pt-224\" # Path to fetch from Kaggle.\n",
311+
"\n",
312+
"# Use these for PaliGemma 1:\n",
313+
"# LLM_VARIANT = \"gemma_2b\"\n",
314+
"# MODEL_PATH = \"./paligemma-3b-pt-224.f16.npz\"\n",
315+
"# KAGGLE_HANDLE = \"google/paligemma/jax/paligemma-3b-pt-224\"\n",
316+
"\n",
310317
"if not os.path.exists(MODEL_PATH):\n",
311318
" print(\"Downloading the checkpoint from Kaggle, this could take a few minutes....\")\n",
312-
" # Note: kaggle archive contains the same checkpoint in multiple formats.\n",
313-
" # Download only the float16 model.\n",
314-
" MODEL_PATH = kagglehub.model_download('google/paligemma/jax/paligemma-3b-pt-224', 'paligemma-3b-pt-224.f16.npz')\n",
319+
" MODEL_PATH = kagglehub.model_download(KAGGLE_HANDLE, MODEL_PATH)\n",
315320
" print(f\"Model path: {MODEL_PATH}\")\n",
316321
"\n",
317322
"TOKENIZER_PATH = \"./paligemma_tokenizer.model\"\n",
@@ -360,8 +365,11 @@
360365
"outputs": [],
361366
"source": [
362367
"# Define model\n",
368+
"\n",
369+
"# IMPORTANT: Gemma-2 has a \"final_logits_softcap\" property, we set it to 0.0\n",
370+
"# for better transfer results.\n",
363371
"model_config = ml_collections.FrozenConfigDict({\n",
364-
" \"llm\": {\"vocab_size\": 257_152},\n",
372+
" \"llm\": {\"vocab_size\": 257_152, \"variant\": LLM_VARIANT, \"final_logits_softcap\": 0.0},\n",
365373
" \"img\": {\"variant\": \"So400m/14\", \"pool_type\": \"none\", \"scan\": True, \"dtype_mm\": \"float16\"}\n",
366374
"})\n",
367375
"model = paligemma.Model(**model_config)\n",
@@ -420,7 +428,9 @@
420428
"\n",
421429
"@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))\n",
422430
"def maybe_cast_to_f32(params, trainable):\n",
423-
" return jax.tree.map(lambda p, m: p.astype(jnp.float32) if m else p,\n",
431+
" # Cast others to float16, since some GPUs don't support bf16.\n",
432+
" return jax.tree.map(lambda p, m: p.astype(jnp.float32)\n",
433+
" if m else p.astype(jnp.float16),\n",
424434
" params, trainable)\n",
425435
"\n",
426436
"# Loading all params in simultaneous - albeit much faster and more succinct -\n",
@@ -492,7 +502,7 @@
492502
"\n",
493503
" image = tf.constant(image)\n",
494504
" image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)\n",
495-
" return image.numpy() / 127.5 - 1.0 # [0, 255]->[-1,1]\n",
505+
" return image.numpy() / 127.5 - 1.0 # [0, 255]-\u003e[-1,1]\n",
496506
"\n",
497507
"def preprocess_tokens(prefix, suffix=None, seqlen=None):\n",
498508
" # Model has been trained to handle tokenized text composed of a prefix with\n",
@@ -632,12 +642,12 @@
632642
" return f\"data:image/jpeg;base64,{image_b64}\"\n",
633643
"\n",
634644
"def render_example(image, caption):\n",
635-
" image = ((image + 1)/2 * 255).astype(np.uint8) # [-1,1] -> [0, 255]\n",
645+
" image = ((image + 1)/2 * 255).astype(np.uint8) # [-1,1] -\u003e [0, 255]\n",
636646
" return f\"\"\"\n",
637-
" <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
638-
" <img style=\"width:128px; height:128px;\" src=\"{render_inline(image, resize=(64,64))}\" />\n",
639-
" <p style=\"width:256px; margin:10px; font-size:small;\">{html.escape(caption)}</p>\n",
640-
" </div>\n",
647+
" \u003cdiv style=\"display: inline-flex; align-items: center; justify-content: center;\"\u003e\n",
648+
" \u003cimg style=\"width:128px; height:128px;\" src=\"{render_inline(image, resize=(64,64))}\" /\u003e\n",
649+
" \u003cp style=\"width:256px; margin:10px; font-size:small;\"\u003e{html.escape(caption)}\u003c/p\u003e\n",
650+
" \u003c/div\u003e\n",
641651
" \"\"\"\n",
642652
"\n",
643653
"html_out = \"\"\n",
@@ -754,7 +764,7 @@
754764
" # Append to html output.\n",
755765
" for example, response in zip(examples, responses):\n",
756766
" outputs.append((example[\"image\"], response))\n",
757-
" if num_examples and len(outputs) >= num_examples:\n",
767+
" if num_examples and len(outputs) \u003e= num_examples:\n",
758768
" return outputs"
759769
]
760770
},
@@ -862,14 +872,36 @@
862872
],
863873
"metadata": {
864874
"colab": {
865-
"name": "fine-tuning-paligemma.ipynb",
875+
"gpuType": "T4",
876+
"last_runtime": {
877+
"build_target": "//learning/grp/tools/ml_python:ml_notebook",
878+
"kind": "private"
879+
},
880+
"private_outputs": true,
881+
"provenance": [
882+
{
883+
"file_id": "17AiK8gRY7oiquQGkBH0d08PFQo3Kyx1I",
884+
"timestamp": 1715287187925
885+
},
886+
{
887+
"file_id": "1qZlJfPyfKRrNcz2shxQ93HnnE5Ge1LLn",
888+
"timestamp": 1715019972450
889+
},
890+
{
891+
"file_id": "1JFnlD2kSiTNexdPw_NYRtuW6uuSTI0kD",
892+
"timestamp": 1714585741026
893+
}
894+
],
866895
"toc_visible": true
867896
},
868897
"kernelspec": {
869898
"display_name": "Python 3",
870899
"name": "python3"
900+
},
901+
"language_info": {
902+
"name": "python"
871903
}
872904
},
873905
"nbformat": 4,
874906
"nbformat_minor": 0
875-
}
907+
}

0 commit comments

Comments (0)