Merged
76 changes: 52 additions & 24 deletions site/en/gemma/docs/core/pytorch_gemma.ipynb
@@ -50,6 +50,7 @@
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://ai.google.dev/gemma/docs/core/pytorch_gemma\"><img src=\"https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png\" height=\"32\" width=\"32\" />View on ai.google.dev</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/core/pytorch_gemma.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
@@ -67,7 +68,10 @@
"source": [
"# Run Gemma using PyTorch\n",
"\n",
"This guide shows you how to run Gemma using the PyTorch framework, including how to use image data for prompting Gemma release 3 and later models. For more details on the Gemma PyTorch implementation, see the project repository [README](https://github.com/google/gemma_pytorch)."
"This guide shows you how to run Gemma using the PyTorch framework, including how\n",
"to use image data for prompting Gemma release 3 and later models. For more\n",
"details on the Gemma PyTorch implementation, see the project repository\n",
"[README](https://github.com/google/gemma_pytorch)."
]
},
{
@@ -78,7 +82,9 @@
"source": [
"## Setup\n",
"\n",
"The following sections explain how to set up your development environment, including how to get access to Gemma models for downloading from Kaggle, setting authentication variables, installing dependencies, and importing packages."
"The following sections explain how to set up your development environment,\n",
"including how to get access to Gemma models for downloading from Kaggle,\n",
"setting authentication variables, installing dependencies, and importing\n",
"packages."
]
},
{
@@ -89,7 +95,12 @@
"source": [
"### System requirements\n",
"\n",
"This Gemma PyTorch library requires GPU or TPU processors to run the Gemma model. The standard Colab CPU Python runtime and T4 GPU Python runtime are sufficient for running Gemma 1B, 2B, and 4B size models. For advanced use cases with other GPUs or TPUs, refer to the [README](https://github.com/google/gemma_pytorch/blob/main/README.md) in the Gemma PyTorch repo."
"This Gemma PyTorch library requires GPU or TPU processors to run the Gemma\n",
"model. The standard Colab CPU Python runtime and T4 GPU Python runtime are\n",
"sufficient for running Gemma 1B, 2B, and 4B size models. For advanced use\n",
"cases with other GPUs or TPUs, refer to the\n",
"[README](https://github.com/google/gemma_pytorch/blob/main/README.md) in the\n",
"Gemma PyTorch repo."
]
},
{
@@ -100,13 +111,16 @@
"source": [
"### Get access to Gemma on Kaggle\n",
"\n",
"To complete this tutorial, you first need to follow the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup), which show you how to do the following:\n",
"To complete this tutorial, you first need to follow the setup instructions at\n",
"[Gemma setup](https://ai.google.dev/gemma/docs/setup), which show you how to do\n",
"the following:\n",
"\n",
"* Get access to Gemma on [kaggle.com](https://www.kaggle.com/models/google/gemma/).\n",
"* Get access to Gemma on [Kaggle](https://www.kaggle.com/models/google/gemma/).\n",
"* Select a Colab runtime with sufficient resources to run the Gemma model.\n",
"* Generate and configure a Kaggle username and API key.\n",
"\n",
"After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment."
"After you've completed the Gemma setup, move on to the next section, where\n",
"you'll set environment variables for your Colab environment."
]
},
{
@@ -117,7 +131,8 @@
"source": [
"### Set environment variables\n",
"\n",
"Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`. When prompted with the \"Grant access?\" message, agree to provide secret access."
"Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`. When prompted\n",
"with the \"Grant access?\" message, agree to provide secret access."
]
},
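For reference, the cell this hunk belongs to amounts to exporting the two Kaggle credentials as environment variables. A minimal sketch, using hypothetical placeholder values instead of the Colab secrets plumbing:

```python
import os

# Hypothetical placeholder values; in Colab these would typically be read
# from Colab secrets (google.colab.userdata) after granting secret access.
os.environ['KAGGLE_USERNAME'] = 'your-kaggle-username'
os.environ['KAGGLE_KEY'] = 'your-kaggle-api-key'

# The Kaggle client reads both variables from the process environment.
print(os.environ['KAGGLE_USERNAME'])
```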
{
@@ -192,10 +207,7 @@
"# Choose variant and machine type\n",
"VARIANT = '4b-it' #@param ['1b','1b-it','4b','4b-it','12b','12b-it','27b','27b-it']\n",
"MACHINE_TYPE = 'cuda' #@param ['cuda', 'cpu']\n",
"\n",
"CONFIG = VARIANT[:2]\n",
"if CONFIG == '4b':\n",
" CONFIG = '4b-v1'"
"CONFIG = VARIANT.split('-')[0]"
]
},
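The replacement line derives the config name by keeping everything before the first `-`, so instruction-tuned variants map to their base size name without the old slice-and-special-case logic. A quick sketch of the new behavior (the helper function is illustrative, not part of the notebook):

```python
# Sketch of the new CONFIG derivation: split on '-' and keep the first
# piece, so '4b-it' -> '4b' and bare sizes like '27b' pass through.
def to_config(variant: str) -> str:
    return variant.split('-')[0]

for v in ['1b', '4b-it', '12b-it', '27b']:
    print(v, '->', to_config(v))
```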
{
@@ -246,7 +258,8 @@
"source": [
"## Configure the run environment\n",
"\n",
"The following sections explain how to prepare an PyTorch environment for running Gemma."
"The following sections explain how to prepare a PyTorch environment for running\n",
"Gemma."
]
},
{
@@ -257,7 +270,8 @@
"source": [
"### Prepare the PyTorch run environment\n",
"\n",
"Prepare the PyTorch model execution environment by cloning the Gemma PyTorch repository."
"Prepare the PyTorch model execution environment by cloning the Gemma PyTorch\n",
"repository."
]
},
{
@@ -321,7 +335,8 @@
"source": [
"### Set the model configuration\n",
"\n",
"Before you run the model, you must set some configuration parameters, including the Gemma variant, tokenizer, and quantization level."
"Before you run the model, you must set some configuration parameters, including\n",
"the Gemma variant, tokenizer, and quantization level."
]
},
{
@@ -333,7 +348,7 @@
"outputs": [],
"source": [
"# Set up model config.\n",
"model_config = get_model_config(VARIANT)\n",
"model_config = get_model_config(CONFIG)\n",
"model_config.dtype = \"float32\" if MACHINE_TYPE == \"cpu\" else \"float16\"\n",
"model_config.tokenizer = tokenizer_path"
]
@@ -418,8 +433,8 @@
"- `<start_of_image>`: tag for image data input\n",
"- `<end_of_turn><eos>`: end of dialog turn\n",
"\n",
"For more information, read about prompt formatting for instruction tuned Gemma models\n",
"[here](https://ai.google.dev/gemma/core/prompt-structure.\n"
"For more information, read about prompt formatting for instruction-tuned Gemma\n",
"models [here](https://ai.google.dev/gemma/core/prompt-structure).\n"
]
},
{
@@ -430,7 +445,9 @@
"source": [
"### Generate text with text\n",
"\n",
"The following is a sample code snippet demonstrating how to format a prompt for an instruction-tuned Gemma model using user and model chat templates in a multi-turn conversation."
"The following is a sample code snippet demonstrating how to format a prompt for\n",
"an instruction-tuned Gemma model using user and model chat templates in a\n",
"multi-turn conversation."
]
},
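The chat-template formatting described above can be sketched as plain string assembly. The template constants below are illustrative (built from the control tokens the text lists), not part of the gemma_pytorch API:

```python
# Illustrative multi-turn prompt assembly using the control tokens
# described in the text; the trailing '<start_of_turn>model\n' cues the
# model to produce the next turn.
USER_TEMPLATE = '<start_of_turn>user\n{msg}<end_of_turn><eos>\n'
MODEL_TEMPLATE = '<start_of_turn>model\n{msg}<end_of_turn><eos>\n'

prompt = (
    USER_TEMPLATE.format(msg='What is a good place for travel in the US?')
    + MODEL_TEMPLATE.format(msg='California.')
    + USER_TEMPLATE.format(msg='What can I do in California?')
    + '<start_of_turn>model\n'
)
print(prompt)
```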
{
@@ -530,7 +547,8 @@
"source": [
"### Generate text with images\n",
"\n",
"With Gemma release 3 and later, you can use images with your prompt. The following example shows you how to include visual data with your prompt."
"With Gemma release 3 and later, you can use images with your prompt. The\n",
"following example shows you how to include visual data with your prompt."
]
},
{
@@ -551,13 +569,21 @@
" contents = io.BytesIO(requests.get(url).content)\n",
" return PIL.Image.open(contents)\n",
"\n",
"image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'\n",
"image = read_image(image_url)\n",
"image = read_image(\n",
" 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'\n",
")\n",
"\n",
"print(model.generate(\n",
" [['<start_of_turn>user\\n',image, 'What animal is in this image?<end_of_turn>\\n', '<start_of_turn>model\\n']],\n",
" [\n",
" [\n",
" '<start_of_turn>user\\n',\n",
" image,\n",
" 'What animal is in this image?<end_of_turn>\\n',\n",
" '<start_of_turn>model\\n'\n",
" ]\n",
" ],\n",
" device=device,\n",
" output_len=OUTPUT_LEN,\n",
" output_len=256,\n",
"))"
]
},
@@ -570,7 +596,9 @@
"## Learn more\n",
"\n",
"Now that you have learned how to use Gemma in Pytorch, you can explore the many\n",
"other things that Gemma can do in [ai.google.dev/gemma](https://ai.google.dev/gemma).\n",
"other things that Gemma can do in\n",
"[ai.google.dev/gemma](https://ai.google.dev/gemma).\n",
"\n",
"See also these other related resources:\n",
"\n",
"- [Gemma core models overview](https://ai.google.dev/gemma/docs/core)\n",