diff --git a/site/en/gemma/docs/core/pytorch_gemma.ipynb b/site/en/gemma/docs/core/pytorch_gemma.ipynb
index 84d2c0e4c..7249a5821 100644
--- a/site/en/gemma/docs/core/pytorch_gemma.ipynb
+++ b/site/en/gemma/docs/core/pytorch_gemma.ipynb
@@ -50,6 +50,7 @@
"\n",
" View on ai.google.dev\n",
+ " | \n",
" \n",
" Run in Google Colab\n",
" | \n",
@@ -67,7 +68,10 @@
"source": [
"# Run Gemma using PyTorch\n",
"\n",
- "This guide shows you how to run Gemma using the PyTorch framework, including how to use image data for prompting Gemma release 3 and later models. For more details on the Gemma PyTorch implementation, see the project repository [README](https://github.com/google/gemma_pytorch)."
+ "This guide shows you how to run Gemma using the PyTorch framework, including how\n",
+ "to use image data for prompting Gemma release 3 and later models. For more\n",
+ "details on the Gemma PyTorch implementation, see the project repository\n",
+ "[README](https://github.com/google/gemma_pytorch)."
]
},
{
@@ -78,7 +82,9 @@
"source": [
"## Setup\n",
"\n",
- "The following sections explain how to set up your development environment, including how get access to Gemma models for downloading from Kaggle, setting authentication variables, installing dependencies, and importing packages."
+ "The following sections explain how to set up your development environment,\n",
+ "including how get access to Gemma models for downloading from Kaggle, setting\n",
+ "authentication variables, installing dependencies, and importing packages."
]
},
{
@@ -89,7 +95,12 @@
"source": [
"### System requirements\n",
"\n",
- "This Gemma Pytorch library requires GPU or TPU processors to run the Gemma model. The standard Colab CPU Python runtime and T4 GPU Python runtime are sufficient for running Gemma 1B, 2B, and 4B size models. For advanced use cases for other GPUs or TPU, please refer to [README](https://github.com/google/gemma_pytorch/blob/main/README.md) in the Gemma PyTorch repo."
+ "This Gemma Pytorch library requires GPU or TPU processors to run the Gemma \n",
+ "model. The standard Colab CPU Python runtime and T4 GPU Python runtime are\n",
+ "sufficient for running Gemma 1B, 2B, and 4B size models. For advanced use cases\n",
+ "for other GPUs or TPU, please refer to the\n",
+ "[README](https://github.com/google/gemma_pytorch/blob/main/README.md) in the\n",
+ "Gemma PyTorch repo."
]
},
{
@@ -100,13 +111,16 @@
"source": [
"### Get access to Gemma on Kaggle\n",
"\n",
- "To complete this tutorial, you first need to follow the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup), which show you how to do the following:\n",
+ "To complete this tutorial, you first need to follow the setup instructions at\n",
+ "[Gemma setup](https://ai.google.dev/gemma/docs/setup), which show you how to do\n",
+ "the following:\n",
"\n",
- "* Get access to Gemma on [kaggle.com](https://www.kaggle.com/models/google/gemma/).\n",
+ "* Get access to Gemma on [Kaggle](https://www.kaggle.com/models/google/gemma/).\n",
"* Select a Colab runtime with sufficient resources to run the Gemma model.\n",
"* Generate and configure a Kaggle username and API key.\n",
"\n",
- "After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment."
+ "After you've completed the Gemma setup, move on to the next section, where\n",
+ "you'll set environment variables for your Colab environment."
]
},
{
@@ -117,7 +131,8 @@
"source": [
"### Set environment variables\n",
"\n",
- "Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`. When prompted with the \"Grant access?\" messages, agree to provide secret access."
+ "Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`. When prompted\n",
+ "with the \"Grant access?\" messages, agree to provide secret access."
]
},
{
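The cell that actually sets these variables is not part of this hunk. As a
minimal sketch of the step described above, assuming the credentials are stored
as Colab secrets named `KAGGLE_USERNAME` and `KAGGLE_KEY`:

```python
import os
from google.colab import userdata  # Colab-only helper for reading secrets

# Reading a secret triggers the "Grant access?" prompt mentioned above.
os.environ['KAGGLE_USERNAME'] = userdata.get('KAGGLE_USERNAME')
os.environ['KAGGLE_KEY'] = userdata.get('KAGGLE_KEY')
```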
@@ -192,10 +207,7 @@
"# Choose variant and machine type\n",
"VARIANT = '4b-it' #@param ['1b','1b-it','4b','4b-it','12b','12b-it','27b','27b-it']\n",
"MACHINE_TYPE = 'cuda' #@param ['cuda', 'cpu']\n",
- "\n",
- "CONFIG = VARIANT[:2]\n",
- "if CONFIG == '4b':\n",
- " CONFIG = '4b-v1'"
+ "CONFIG = VARIANT.split('-')[0]"
]
},
{
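A quick illustration of why the `CONFIG` derivation changed: the old
two-character slice maps the new two-digit variants such as `'12b-it'` to
`'12'`, which is not a valid config name, while splitting on `'-'` yields the
size prefix for every variant in the `#@param` list (a sketch, not a cell from
the notebook):

```python
for variant in ['1b', '1b-it', '4b-it', '12b-it', '27b-it']:
    old = variant[:2]            # '12b-it' -> '12' (wrong for two-digit sizes)
    new = variant.split('-')[0]  # '12b-it' -> '12b'
    print(f'{variant}: old={old!r} new={new!r}')
```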
@@ -246,7 +258,8 @@
"source": [
"## Configure the run environment\n",
"\n",
- "The following sections explain how to prepare an PyTorch environment for running Gemma."
+ "The following sections explain how to prepare a PyTorch environment for running\n",
+ "Gemma."
]
},
{
@@ -257,7 +270,8 @@
"source": [
"### Prepare the PyTorch run environment\n",
"\n",
- "Prepare the PyTorch model execution environment by cloning the Gemma Pytorch repository."
+ "Prepare the PyTorch model execution environment by cloning the Gemma Pytorch\n",
+ "repository."
]
},
{
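The clone cell itself is not included in this hunk. A rough sketch of the step
described above; the repository URL comes from the links earlier in this doc,
and the destination path on the Colab VM is an assumption:

```python
# Clone the Gemma PyTorch implementation and put it on the import path.
!git clone https://github.com/google/gemma_pytorch.git
import sys
sys.path.append('gemma_pytorch')  # assumed clone location
```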
@@ -321,7 +335,8 @@
"source": [
"### Set the model configuration\n",
"\n",
- "Before you run the model, you must set some configuration parameters, including the Gemma variant, tokenizer and quantization level."
+ "Before you run the model, you must set some configuration parameters, including\n",
+ "the Gemma variant, tokenizer and quantization level."
]
},
{
@@ -333,7 +348,7 @@
"outputs": [],
"source": [
"# Set up model config.\n",
- "model_config = get_model_config(VARIANT)\n",
+ "model_config = get_model_config(CONFIG)\n",
"model_config.dtype = \"float32\" if MACHINE_TYPE == \"cpu\" else \"float16\"\n",
"model_config.tokenizer = tokenizer_path"
]
@@ -418,8 +433,8 @@
"- `