|
6 | 6 | "id": "G3MMAcssHTML" |
7 | 7 | }, |
8 | 8 | "source": [ |
9 | | - "<link rel=\"stylesheet\" href=\"/site-assets/css/gemma.css\">\n", |
10 | | - "<link rel=\"stylesheet\" href=\"https://fonts.googleapis.com/css2?family=Google+Symbols:opsz,wght,FILL,[email protected],100..700,0..1,-50..200\" />" |
| 9 | + "<link rel=\"stylesheet\" href=\"/site-assets/css/style.css\">\n", |
| 10 | + "<link rel=\"stylesheet\" href=\"/site-assets/css/gemma.css\">\n" |
11 | 11 | ] |
12 | 12 | }, |
13 | 13 | { |
|
59 | 59 | "<td>\n", |
60 | 60 | "<a target=\"_blank\" href=\"https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/paligemma/fine-tuning-paligemma.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n", |
61 | 61 | "</td>\n", |
62 | | - "</table>\n" |
63 | | - ] |
64 | | - }, |
65 | | - { |
66 | | - "cell_type": "markdown", |
67 | | - "metadata": { |
68 | | - "id": "wR53lePHuiP-" |
69 | | - }, |
70 | | - "source": [ |
| 62 | + "</table>\n", |
| 63 | + "\n", |
71 | 64 | "This notebook shows how to fine-tune [PaliGemma](https://ai.google.dev/gemma/docs/paligemma) on a vision-language task with [JAX](https://jax.readthedocs.io/en/latest/index.html). *Fine-tuning* is a process that can improve your model's performance on specific tasks or help the model adhere to specific output requirements when instructions aren't sufficient and you have a set of examples that demonstrate the outputs you want. Gemma-based models like PaliGemma require fine-tuning to produce expected results.\n", |
72 | 65 | "\n", |
73 | 66 | "### What's in this notebook\n", |
|
172 | 165 | "# vars as appropriate or make your credentials available in ~/.kaggle/kaggle.json\n", |
173 | 166 | "\n", |
174 | 167 | "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n", |
175 | | - "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')" |
| 168 | + "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')\n", |
| 169 | + "\n", |
|  | 170 | + "# The T4 runtime is tight on memory for fine-tuning this model. Preallocate\n",
|  | 171 | + "# all memory ahead of time to avoid out-of-memory errors caused by fragmentation.\n",
| 172 | + "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.0\"" |
176 | 173 | ] |
177 | 174 | }, |
178 | 175 | { |
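For context on the new `XLA_PYTHON_CLIENT_MEM_FRACTION` line above: it is one of JAX/XLA's standard GPU memory environment variables, all of which must be set before JAX initializes. A minimal sketch of the related knobs (the alternative values shown are illustrative, not from this commit):

```python
import os

# Preallocate a fraction of total GPU memory up front (JAX's default is 0.75).
# The commit sets this to 1.0 so fragmentation can't cause later OOMs.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"

# Alternative: disable preallocation and grow allocations on demand.
# Easier to share the GPU, but more prone to fragmentation.
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Alternative: the platform allocator frees memory back to the driver.
# Slow, but useful when debugging OOMs.
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax  # import only after the variables are set
```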
|
265 | 262 | "tf.config.set_visible_devices([], \"GPU\")\n", |
266 | 263 | "tf.config.set_visible_devices([], \"TPU\")\n", |
267 | 264 | "\n", |
268 | | - "backend = jax.lib.xla_bridge.get_backend()\n", |
| 265 | + "backend = jax.extend.backend.get_backend()\n", |
269 | 266 | "print(f\"JAX version: {jax.__version__}\")\n", |
270 | 267 | "print(f\"JAX platform: {backend.platform}\")\n", |
271 | 268 | "print(f\"JAX devices: {jax.device_count()}\")" |
|
292 | 289 | "\n", |
293 | 290 | "PaliGemma includes several model variations. For this tutorial, you'll use the base [JAX/FLAX PaliGemma 3B weight model](https://www.kaggle.com/models/google/paligemma/jax/paligemma-3b-pt-224).\n", |
294 | 291 | "\n", |
295 | | - "Download the `float16` version of the model checkpoint from Kaggle by running the following code. This process takes several minutes to complete." |
| 292 | + "Download the model checkpoint from Kaggle by running the following code. This process takes several minutes to complete." |
296 | 293 | ] |
297 | 294 | }, |
298 | 295 | { |
|
306 | 303 | "import os\n", |
307 | 304 | "import kagglehub\n", |
308 | 305 | "\n", |
309 | | - "MODEL_PATH = \"./pt_224_128.params.f16.npz\"\n", |
|  | 306 | + "# Use these for PaliGemma-2 3B 224px²:\n",
| 307 | + "LLM_VARIANT = \"gemma2_2b\"\n", |
| 308 | + "MODEL_PATH = \"./paligemma2-3b-pt-224.b16.npz\"\n", |
| 309 | + "KAGGLE_HANDLE = \"google/paligemma-2/jax/paligemma2-3b-pt-224\" # Path to fetch from Kaggle.\n", |
| 310 | + "\n", |
| 311 | + "# Use these for PaliGemma 1:\n", |
| 312 | + "# LLM_VARIANT = \"gemma_2b\"\n", |
| 313 | + "# MODEL_PATH = \"./paligemma-3b-pt-224.f16.npz\"\n", |
| 314 | + "# KAGGLE_HANDLE = \"google/paligemma/jax/paligemma-3b-pt-224\"\n", |
| 315 | + "\n", |
310 | 316 | "if not os.path.exists(MODEL_PATH):\n", |
311 | 317 | "  print(\"Downloading the checkpoint from Kaggle; this could take a few minutes...\")\n",
312 | | - " # Note: kaggle archive contains the same checkpoint in multiple formats.\n", |
313 | | - " # Download only the float16 model.\n", |
314 | | - " MODEL_PATH = kagglehub.model_download('google/paligemma/jax/paligemma-3b-pt-224', 'paligemma-3b-pt-224.f16.npz')\n", |
| 318 | + " MODEL_PATH = kagglehub.model_download(KAGGLE_HANDLE, MODEL_PATH)\n", |
315 | 319 | " print(f\"Model path: {MODEL_PATH}\")\n", |
316 | 320 | "\n", |
317 | 321 | "TOKENIZER_PATH = \"./paligemma_tokenizer.model\"\n", |
|
360 | 364 | "outputs": [], |
361 | 365 | "source": [ |
362 | 366 | "# Define model\n", |
| 367 | + "\n", |
|  | 368 | + "# IMPORTANT: Gemma-2 has a \"final_logits_softcap\" property; we set it to 0.0\n",
| 369 | + "# for better transfer results.\n", |
363 | 370 | "model_config = ml_collections.FrozenConfigDict({\n", |
364 | | - " \"llm\": {\"vocab_size\": 257_152},\n", |
| 371 | + " \"llm\": {\"vocab_size\": 257_152, \"variant\": LLM_VARIANT, \"final_logits_softcap\": 0.0},\n", |
365 | 372 | " \"img\": {\"variant\": \"So400m/14\", \"pool_type\": \"none\", \"scan\": True, \"dtype_mm\": \"float16\"}\n", |
366 | 373 | "})\n", |
367 | 374 | "model = paligemma.Model(**model_config)\n", |
|
420 | 427 | "\n", |
421 | 428 | "@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))\n", |
422 | 429 | "def maybe_cast_to_f32(params, trainable):\n", |
423 | | - " return jax.tree.map(lambda p, m: p.astype(jnp.float32) if m else p,\n", |
|  | 430 | + "  # Cast non-trainable params to float16, since some GPUs don't support bf16.\n",
| 431 | + " return jax.tree.map(lambda p, m: p.astype(jnp.float32)\n", |
| 432 | + " if m else p.astype(jnp.float16),\n", |
424 | 433 | " params, trainable)\n", |
425 | 434 | "\n", |
426 | 435 | "# Loading all params in simultaneous - albeit much faster and more succinct -\n", |
|
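The `maybe_cast_to_f32` hunk relies on `jax.tree.map` walking two trees of identical structure: the params and a boolean "trainable" mask. A self-contained sketch of that pattern with toy shapes (the names here are hypothetical):

```python
import jax
import jax.numpy as jnp

params = {"llm": {"w": jnp.zeros((2, 2), jnp.float16)},
          "img": {"w": jnp.zeros((2, 2), jnp.float16)}}
# Train only the language model; keep the image tower frozen.
trainable = {"llm": {"w": True}, "img": {"w": False}}

# Trainable leaves go to float32 for stable updates; frozen leaves
# go to float16, since some GPUs don't support bf16.
casted = jax.tree.map(
    lambda p, m: p.astype(jnp.float32) if m else p.astype(jnp.float16),
    params, trainable)
print(jax.tree.map(lambda p: str(p.dtype), casted))
```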