Commit 5d19368

Fixing Gemma PyTorch docs (#601)

1 parent eee21a7 commit 5d19368

File tree

1 file changed: +52 -24 lines changed

site/en/gemma/docs/core/pytorch_gemma.ipynb

Lines changed: 52 additions & 24 deletions
@@ -50,6 +50,7 @@
 "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
 " <td>\n",
 " <a target=\"_blank\" href=\"https://ai.google.dev/gemma/docs/core/pytorch_gemma\"><img src=\"https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png\" height=\"32\" width=\"32\" />View on ai.google.dev</a>\n",
+" </td>\n",
 " <td>\n",
 " <a target=\"_blank\" href=\"https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/core/pytorch_gemma.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
 " </td>\n",
@@ -67,7 +68,10 @@
 "source": [
 "# Run Gemma using PyTorch\n",
 "\n",
-"This guide shows you how to run Gemma using the PyTorch framework, including how to use image data for prompting Gemma release 3 and later models. For more details on the Gemma PyTorch implementation, see the project repository [README](https://github.com/google/gemma_pytorch)."
+"This guide shows you how to run Gemma using the PyTorch framework, including how\n",
+"to use image data for prompting Gemma release 3 and later models. For more\n",
+"details on the Gemma PyTorch implementation, see the project repository\n",
+"[README](https://github.com/google/gemma_pytorch)."
 ]
 },
 {
@@ -78,7 +82,9 @@
 "source": [
 "## Setup\n",
 "\n",
-"The following sections explain how to set up your development environment, including how get access to Gemma models for downloading from Kaggle, setting authentication variables, installing dependencies, and importing packages."
+"The following sections explain how to set up your development environment,\n",
+"including how get access to Gemma models for downloading from Kaggle, setting\n",
+"authentication variables, installing dependencies, and importing packages."
 ]
 },
 {
@@ -89,7 +95,12 @@
 "source": [
 "### System requirements\n",
 "\n",
-"This Gemma Pytorch library requires GPU or TPU processors to run the Gemma model. The standard Colab CPU Python runtime and T4 GPU Python runtime are sufficient for running Gemma 1B, 2B, and 4B size models. For advanced use cases for other GPUs or TPU, please refer to [README](https://github.com/google/gemma_pytorch/blob/main/README.md) in the Gemma PyTorch repo."
+"This Gemma Pytorch library requires GPU or TPU processors to run the Gemma \n",
+"model. The standard Colab CPU Python runtime and T4 GPU Python runtime are\n",
+"sufficient for running Gemma 1B, 2B, and 4B size models. For advanced use cases\n",
+"for other GPUs or TPU, please refer to the\n",
+"[README](https://github.com/google/gemma_pytorch/blob/main/README.md) in the\n",
+"Gemma PyTorch repo."
 ]
 },
 {
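Note: the requirements text above assumes a GPU or TPU backed runtime. As a minimal sketch (not part of this commit or the notebook), one way to confirm which accelerator is available before choosing MACHINE_TYPE, using only standard PyTorch calls; the variable name suggested_machine_type is illustrative:

    import torch

    # Prefer a CUDA GPU when present; otherwise fall back to CPU.
    # (TPU runtimes need torch_xla and are covered by the repo README.)
    if torch.cuda.is_available():
        print('GPU detected:', torch.cuda.get_device_name(0))
        suggested_machine_type = 'cuda'
    else:
        print('No GPU detected; falling back to CPU.')
        suggested_machine_type = 'cpu'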
@@ -100,13 +111,16 @@
 "source": [
 "### Get access to Gemma on Kaggle\n",
 "\n",
-"To complete this tutorial, you first need to follow the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup), which show you how to do the following:\n",
+"To complete this tutorial, you first need to follow the setup instructions at\n",
+"[Gemma setup](https://ai.google.dev/gemma/docs/setup), which show you how to do\n",
+"the following:\n",
 "\n",
-"* Get access to Gemma on [kaggle.com](https://www.kaggle.com/models/google/gemma/).\n",
+"* Get access to Gemma on [Kaggle](https://www.kaggle.com/models/google/gemma/).\n",
 "* Select a Colab runtime with sufficient resources to run the Gemma model.\n",
 "* Generate and configure a Kaggle username and API key.\n",
 "\n",
-"After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment."
+"After you've completed the Gemma setup, move on to the next section, where\n",
+"you'll set environment variables for your Colab environment."
 ]
 },
 {
@@ -117,7 +131,8 @@
 "source": [
 "### Set environment variables\n",
 "\n",
-"Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`. When prompted with the \"Grant access?\" messages, agree to provide secret access."
+"Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`. When prompted\n",
+"with the \"Grant access?\" messages, agree to provide secret access."
 ]
 },
 {
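Note: the "Grant access?" prompt comes from Colab's Secrets panel. As a hedged sketch of the kind of cell this section introduces (the cell itself is not changed by this commit and not shown here), assuming the secrets are stored in Colab under the same KAGGLE_USERNAME and KAGGLE_KEY names:

    import os
    from google.colab import userdata  # Colab-only helper for stored secrets

    # Copy the Kaggle credentials from Colab secrets into environment variables
    # so the Kaggle download step later in the notebook can authenticate.
    os.environ['KAGGLE_USERNAME'] = userdata.get('KAGGLE_USERNAME')
    os.environ['KAGGLE_KEY'] = userdata.get('KAGGLE_KEY')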
@@ -192,10 +207,7 @@
 "# Choose variant and machine type\n",
 "VARIANT = '4b-it' #@param ['1b','1b-it','4b','4b-it','12b','12b-it','27b','27b-it']\n",
 "MACHINE_TYPE = 'cuda' #@param ['cuda', 'cpu']\n",
-"\n",
-"CONFIG = VARIANT[:2]\n",
-"if CONFIG == '4b':\n",
-"  CONFIG = '4b-v1'"
+"CONFIG = VARIANT.split('-')[0]"
 ]
 },
 {
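Note on the change above: the new one-liner derives the config name by dropping any '-it' suffix, which also handles the 12b and 27b variants that the old VARIANT[:2] slice truncated. A quick illustration in plain Python, for reference only:

    # How CONFIG = VARIANT.split('-')[0] behaves for each variant string:
    for variant in ['1b', '1b-it', '4b', '4b-it', '12b', '12b-it', '27b', '27b-it']:
        print(variant, '->', variant.split('-')[0])
    # '4b-it' -> '4b', '12b-it' -> '12b', and so on.
    # The old VARIANT[:2] would have turned '12b-it' into '12', hence the change.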
@@ -246,7 +258,8 @@
 "source": [
 "## Configure the run environment\n",
 "\n",
-"The following sections explain how to prepare an PyTorch environment for running Gemma."
+"The following sections explain how to prepare a PyTorch environment for running\n",
+"Gemma."
 ]
 },
 {
@@ -257,7 +270,8 @@
 "source": [
 "### Prepare the PyTorch run environment\n",
 "\n",
-"Prepare the PyTorch model execution environment by cloning the Gemma Pytorch repository."
+"Prepare the PyTorch model execution environment by cloning the Gemma Pytorch\n",
+"repository."
 ]
 },
 {
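Note: this markdown cell precedes the notebook cell that clones https://github.com/google/gemma_pytorch; only the prose is rewrapped by this commit. For readers running outside Colab, a rough sketch of an equivalent clone step (the local directory name is illustrative):

    import subprocess
    from pathlib import Path

    # Clone the Gemma PyTorch reference implementation if it is not already present.
    repo_dir = Path('gemma_pytorch')  # illustrative checkout location
    if not repo_dir.exists():
        subprocess.run(
            ['git', 'clone', 'https://github.com/google/gemma_pytorch.git', str(repo_dir)],
            check=True,
        )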
@@ -321,7 +335,8 @@
 "source": [
 "### Set the model configuration\n",
 "\n",
-"Before you run the model, you must set some configuration parameters, including the Gemma variant, tokenizer and quantization level."
+"Before you run the model, you must set some configuration parameters, including\n",
+"the Gemma variant, tokenizer and quantization level."
 ]
 },
 {
@@ -333,7 +348,7 @@
 "outputs": [],
 "source": [
 "# Set up model config.\n",
-"model_config = get_model_config(VARIANT)\n",
+"model_config = get_model_config(CONFIG)\n",
 "model_config.dtype = \"float32\" if MACHINE_TYPE == \"cpu\" else \"float16\"\n",
 "model_config.tokenizer = tokenizer_path"
 ]
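Note: the fix passes the base config name (CONFIG, e.g. '4b') to get_model_config instead of the full VARIANT string ('4b-it'). A minimal sketch of how the dtype line in that cell resolves for each machine type (plain Python only; get_model_config itself lives in the gemma_pytorch repo and is not reproduced here):

    # Precision chosen by the cell above, per machine type:
    for machine_type in ['cpu', 'cuda']:
        dtype = 'float32' if machine_type == 'cpu' else 'float16'
        print(machine_type, '->', dtype)
    # cpu  -> float32  (half precision is slow or unsupported on many CPUs)
    # cuda -> float16  (roughly halves model memory use on the GPU)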
@@ -418,8 +433,8 @@
 "- `<start_of_image>`: tag for image data input\n",
 "- `<end_of_turn><eos>`: end of dialog turn\n",
 "\n",
-"For more information, read about prompt formatting for instruction tuned Gemma models\n",
-"[here](https://ai.google.dev/gemma/core/prompt-structure.\n"
+"For more information, read about prompt formatting for instruction tuned Gemma\n",
+"models [here](https://ai.google.dev/gemma/core/prompt-structure).\n"
 ]
 },
 {
@@ -430,7 +445,9 @@
 "source": [
 "### Generate text with text\n",
 "\n",
-"The following is a sample code snippet demonstrating how to format a prompt for an instruction-tuned Gemma model using user and model chat templates in a multi-turn conversation."
+"The following is a sample code snippet demonstrating how to format a prompt for\n",
+"an instruction-tuned Gemma model using user and model chat templates in a\n",
+"multi-turn conversation."
 ]
 },
 {
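Note: the section above describes the multi-turn chat templates, but the cell that builds them is unchanged by this commit and not shown in the diff. A minimal sketch of how such a prompt string can be assembled from the turn tags documented earlier in the notebook (<start_of_turn>, <end_of_turn>); the template variable names and sample questions here are illustrative:

    # Chat templates built from the documented turn tags.
    USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
    MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'

    # A two-turn conversation, ending with an open model turn for the reply.
    prompt = (
        USER_CHAT_TEMPLATE.format(prompt='What is a good place for travel in the US?')
        + MODEL_CHAT_TEMPLATE.format(prompt='California.')
        + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
        + '<start_of_turn>model\n'
    )
    print(prompt)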
@@ -530,7 +547,8 @@
 "source": [
 "### Generate text with images\n",
 "\n",
-"With Gemma release 3 and later, you can use images with your prompt. The following example shows you how to include visual data with your prompt."
+"With Gemma release 3 and later, you can use images with your prompt. The\n",
+"following example shows you how to include visual data with your prompt."
 ]
 },
 {
@@ -551,13 +569,21 @@
 " contents = io.BytesIO(requests.get(url).content)\n",
 " return PIL.Image.open(contents)\n",
 "\n",
-"image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'\n",
-"image = read_image(image_url)\n",
+"image = read_image(\n",
+" 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'\n",
+")\n",
 "\n",
 "print(model.generate(\n",
-" [['<start_of_turn>user\\n',image, 'What animal is in this image?<end_of_turn>\\n', '<start_of_turn>model\\n']],\n",
+" [\n",
+" [\n",
+" '<start_of_turn>user\\n',\n",
+" image,\n",
+" 'What animal is in this image?<end_of_turn>\\n',\n",
+" '<start_of_turn>model\\n'\n",
+" ]\n",
+" ],\n",
 " device=device,\n",
-" output_len=OUTPUT_LEN,\n",
+" output_len=256,\n",
 "))"
 ]
 },
@@ -570,7 +596,9 @@
 "## Learn more\n",
 "\n",
 "Now that you have learned how to use Gemma in Pytorch, you can explore the many\n",
-"other things that Gemma can do in [ai.google.dev/gemma](https://ai.google.dev/gemma).\n",
+"other things that Gemma can do in\n",
+"[ai.google.dev/gemma](https://ai.google.dev/gemma). \n",
+"\n",
 "See also these other related resources:\n",
 "\n",
 "- [Gemma core models overview](https://ai.google.dev/gemma/docs/core)\n",
