|
4 | 4 | "cell_type": "markdown", |
5 | 5 | "metadata": {}, |
6 | 6 | "source": [ |
7 | | - "# GRPO Llama3.1-8B Demo: Direct Function Call\n", |
| 7 | + "# GRPO/GSPO Llama3.1-8B Demo\n", |
8 | 8 | "\n", |
9 | | - "This notebook demonstrates GRPO training by directly calling the `rl_train` function from `rl_trainer.py`.\n", |
| 9 | + "This notebook demonstrates GRPO (Group Relative Policy Optimization) training using the unified `rl_train` function or GSPO (Group Sequence Policy Optimization) - the change is in loss function which is a parameter\n", |
10 | 10 | "\n", |
11 | | - "## What is GRPO?\n", |
| 11 | + "## What is GRPO/GSPO?\n", |
12 | 12 | "\n", |
13 | | - "GRPO (Group Relative Policy Optimization) is an RL algorithm that enhances reasoning abilities of LLMs by:\n", |
| 13 | + "GRPO/GSPO is an RL algorithm that enhances reasoning abilities of LLMs by:\n", |
14 | 14 | "1. Generating multiple responses for each prompt\n", |
15 | 15 | "2. Evaluating responses using reward models \n", |
16 | 16 | "3. Calculating relative advantages to update the policy\n", |
17 | 17 | "\n", |
18 | | - "\n", |
19 | | - "This notebook imports and calls the `rl_train` function \n", |
| 18 | + "The difference is in the loss function - either it's optimizing each token (GRPO) or the whole sequence(GSPO).\n", |
20 | 19 | "\n", |
21 | 20 | "## Hardware Requirements\n", |
22 | 21 | "\n", |
|
28 | 27 | "cell_type": "markdown", |
29 | 28 | "metadata": {}, |
30 | 29 | "source": [ |
31 | | - "## Setup\n", |
| 30 | + "### Get Your Hugging Face Token\n", |
| 31 | + "\n", |
| 32 | + "To access model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.\n", |
| 33 | + "\n", |
| 34 | + "**Follow these steps to get your token:**\n", |
| 35 | + "\n", |
| 36 | + "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:\n", |
| 37 | + " * [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n", |
| 38 | + "\n", |
| 39 | + "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n", |
32 | 40 | "\n", |
33 | | - "Install dependencies and set up the environment:" |
| 41 | + "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n", |
| 42 | + "\n", |
| 43 | + "4. **Copy the generated token**. You will need to paste it in the next step.\n", |
| 44 | + "\n", |
| 45 | + "**Follow these steps to store your token:**\n", |
| 46 | + "\n", |
| 47 | + "Just put your token in the line below" |
34 | 48 | ] |
35 | 49 | }, |
36 | 50 | { |
|
39 | 53 | "metadata": {}, |
40 | 54 | "outputs": [], |
41 | 55 | "source": [ |
42 | | - "# Clone MaxText repository\n", |
43 | | - "!git clone https://github.com/AI-Hypercomputer/maxtext.git\n", |
44 | | - "%cd maxtext" |
| 56 | + "HF_TOKEN = \"\" # Set HF_TOKEN environment variable\n" |
45 | 57 | ] |
46 | 58 | }, |
47 | 59 | { |
48 | | - "cell_type": "code", |
49 | | - "execution_count": null, |
| 60 | + "cell_type": "markdown", |
50 | 61 | "metadata": {}, |
51 | | - "outputs": [], |
52 | 62 | "source": [ |
53 | | - "# Install dependencies\n", |
54 | | - "!chmod +x setup.sh\n", |
55 | | - "!./setup.sh\n", |
56 | | - "\n", |
57 | | - "# Install GRPO-specific dependencies\n", |
58 | | - "!./src/MaxText/examples/install_tunix_vllm_requirement.sh\n", |
59 | | - "\n", |
60 | | - "# Install additional requirements\n", |
61 | | - "%pip install --force-reinstall numpy==2.1.2\n", |
62 | | - "%pip install nest_asyncio\n", |
| 63 | + "## Setup\n", |
63 | 64 | "\n", |
64 | | - "import nest_asyncio\n", |
65 | | - "nest_asyncio.apply() # Fix for Colab event loop" |
| 65 | + "Install dependencies and set up the environment:\n", |
| 66 | + "https://maxtext.readthedocs.io/latest/tutorials/grpo.html#from-github" |
66 | 67 | ] |
67 | 68 | }, |
68 | 69 | { |
|
71 | 72 | "source": [ |
72 | 73 | "## Configuration\n", |
73 | 74 | "\n", |
74 | | - "Set up the training parameters:" |
| 75 | + "Set up the training parameters. We do not use Pathways and do use a single host. Defaults are hardcoded for Llama3.1-8B:" |
75 | 76 | ] |
76 | 77 | }, |
| 78 | + { |
| 79 | + "cell_type": "code", |
| 80 | + "execution_count": null, |
| 81 | + "metadata": {}, |
| 82 | + "outputs": [], |
| 83 | + "source": [ |
| 84 | + "!cd ~/maxtext/src/ # make sure we are in the right directory" |
| 85 | + ] |
| 86 | + }, |
77 | 92 | { |
78 | 93 | "cell_type": "code", |
79 | 94 | "execution_count": null, |
|
82 | 97 | "source": [ |
83 | 98 | "# Configuration for GRPO training\n", |
84 | 99 | "import os\n", |
| 100 | + "from re import M\n", |
| 101 | + "import MaxText\n", |
| 102 | + "\n", |
| 103 | + "# Set up paths (adjust if needed)\n", |
| 104 | + "MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)\n", |
| 105 | + "RUN_NAME=\"grpo_test\"\n", |
| 106 | + "# Hardcoded defaults for Llama3.1-8B\n", |
| 107 | + "MODEL_NAME = \"llama3.1-8b\"\n", |
| 108 | + "HF_REPO_ID = \"meta-llama/Llama-3.1-8B-Instruct\"\n", |
| 109 | + "CHAT_TEMPLATE_PATH = f\"{MAXTEXT_REPO_ROOT}/examples/chat_templates/gsm8k_rl.json\"\n", |
| 110 | + "LOSS_ALGO=\"grpo\" # or \"gspo-token\" if you want to use GSPO\n", |
| 111 | + "\n", |
| 112 | + "# Required: Set these before running\n", |
| 113 | + "MODEL_CHECKPOINT_PATH = \"\" # Update this!\n", |
| 114 | + "if not MODEL_CHECKPOINT_PATH:\n", |
| 115 | + " raise RuntimeError(\"MODEL_CHECKPOINT_PATH is not set\")\n", |
| 116 | + " \n", |
| 117 | + "OUTPUT_DIRECTORY = \"/tmp/gpo_output\" # Update this!\n", |
| 118 | + "os.environ[\"HF_TOKEN\"] = HF_TOKEN\n", |
85 | 119 | "\n", |
86 | | - "# Set up paths\n", |
87 | | - "MAXTEXT_REPO_ROOT = os.path.expanduser(\"~\") + \"/maxtext\"\n", |
88 | | - "print(f\"MaxText Home directory: {MAXTEXT_REPO_ROOT}\")\n", |
| 120 | + "if HF_TOKEN:\n", |
| 121 | + " login(token=HF_TOKEN)\n", |
| 122 | + " print(\"Authenticated with Hugging Face\")\n", |
| 123 | + "else:\n", |
| 124 | + " print(\"Authentication failed: Hugging Face token not set\")\n", |
89 | 125 | "\n", |
90 | | - "# Training configuration\n", |
91 | | - "MODEL_CHECKPOINT_PATH = \"gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items\"\n", |
92 | | - "OUTPUT_DIRECTORY = \"/tmp/grpo_output\"\n", |
| 126 | + "# Optional: Override training parameters\n", |
93 | 127 | "STEPS = 10 # Reduced for demo purposes\n", |
94 | | - "HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"your_hf_token_here\")\n", |
| 128 | + "PER_DEVICE_BATCH_SIZE = 1\n", |
| 129 | + "LEARNING_RATE = 3e-6\n", |
| 130 | + "NUM_GENERATIONS = 2\n", |
| 131 | + "GRPO_BETA = 0.08\n", |
| 132 | + "GRPO_EPSILON = 0.2\n", |
| 133 | + "CHIPS_PER_VM = 1\n", |
95 | 134 | "\n", |
96 | | - "print(f\"Model checkpoint: {MODEL_CHECKPOINT_PATH}\")\n", |
97 | | - "print(f\"Output directory: {OUTPUT_DIRECTORY}\")\n", |
98 | | - "print(f\"Training steps: {STEPS}\")" |
| 135 | + "print(f\"📁 MaxText Home: {MAXTEXT_REPO_ROOT}\")\n", |
| 136 | + "print(f\"🤖 Model: {MODEL_NAME}\")\n", |
| 137 | + "print(f\"📦 Checkpoint: {MODEL_CHECKPOINT_PATH}\")\n", |
| 138 | + "print(f\"💾 Output: {OUTPUT_DIRECTORY}\")\n", |
| 139 | + "print(f\"🔑 HF Token: {'✅ Set' if HF_TOKEN else '❌ Missing - set HF_TOKEN env var'}\")\n", |
| 140 | + "print(f\"📊 Steps: {STEPS}\")\n", |
| 141 | + "print(f\"Loss Algorithm : {LOSS_ALGO}\")" |
99 | 142 | ] |
100 | 143 | }, |
101 | 144 | { |
|
104 | 147 | "metadata": {}, |
105 | 148 | "outputs": [], |
106 | 149 | "source": [ |
107 | | - "# Import GRPO training function directly\n", |
108 | | - "import sys\n", |
| 150 | + "# Import required modules\n", |
109 | 151 | "import os\n", |
| 152 | + "import sys\n", |
110 | 153 | "from pathlib import Path\n", |
111 | 154 | "\n", |
112 | 155 | "# Add MaxText to Python path\n", |
113 | | - "maxtext_path = Path(MAXTEXT_REPO_ROOT) / \"src\" / \"MaxText\"\n", |
| 156 | + "maxtext_path = Path(MAXTEXT_REPO_ROOT) \n", |
114 | 157 | "sys.path.insert(0, str(maxtext_path))\n", |
115 | 158 | "\n", |
116 | | - "# Import required modules\n", |
117 | | - "from MaxText import pyconfig\n", |
118 | | - "from MaxText.train_rl import rl_train\n", |
| 159 | + "from MaxText import pyconfig, max_utils\n", |
| 160 | + "from MaxText.rl.train_rl import rl_train, setup_configs_and_devices\n", |
| 161 | + "import jax\n", |
119 | 162 | "\n", |
120 | | - "print(\"✅ Successfully imported GRPO training function\")\n", |
121 | | - "print(f\"📁 MaxText path: {maxtext_path}\")\n", |
122 | | - "print(\"\\n\" + \"=\"*80)\n", |
123 | | - "print(\"Starting GRPO Training...\")\n", |
124 | | - "print(\"=\"*80)" |
| 163 | + "# Initialize JAX and Pathways\n", |
| 164 | + "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"0\"\n", |
| 165 | + "os.environ[\"SKIP_JAX_PRECOMPILE\"] = \"1\" # Faster startup for vLLM\n", |
| 166 | + "\n", |
| 167 | + "print(\"✅ Successfully imported modules\")\n", |
| 168 | + "print(f\"📁 MaxText path: {maxtext_path}\")" |
125 | 169 | ] |
126 | 170 | }, |
127 | 171 | { |
|
131 | 175 | "outputs": [], |
132 | 176 | "source": [ |
133 | 177 | "# Build configuration for GRPO training\n", |
| 178 | + "config_file = os.path.join(MAXTEXT_REPO_ROOT, \"configs/rl.yml\")\n", |
| 179 | + "\n", |
| 180 | + "# Verify chat template exists\n", |
| 181 | + "if not os.path.exists(CHAT_TEMPLATE_PATH)):\n", |
| 182 | + " raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n", |
| 183 | + "\n", |
| 184 | + "# Build argv list for pyconfig.initialize()\n", |
134 | 185 | "config_argv = [\n", |
135 | | - " \"\", # Placeholder for argv[0]\n", |
136 | | - " \"src/MaxText/configs/grpo.yml\", # Base config\n", |
137 | | - " f\"model_name=llama3.1-8b\",\n", |
138 | | - " f\"tokenizer_path=meta-llama/Llama-3.1-8B-Instruct\",\n", |
| 186 | + " \"\", # argv[0] placeholder\n", |
| 187 | + " config_file,\n", |
| 188 | + " f\"model_name={MODEL_NAME}\",\n", |
| 189 | + " f\"tokenizer_path={HF_REPO_ID}\",\n", |
| 190 | + " f\"run_name={RUN_NAME}\",\n", |
| 191 | + " f\"chat_template_path={CHAT_TEMPLATE_PATH}\",\n", |
139 | 192 | " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}\",\n", |
140 | 193 | " f\"base_output_directory={OUTPUT_DIRECTORY}\",\n", |
141 | 194 | " f\"hf_access_token={HF_TOKEN}\",\n", |
142 | 195 | " f\"steps={STEPS}\",\n", |
143 | | - " \"per_device_batch_size=1\",\n", |
144 | | - " \"learning_rate=3e-6\",\n", |
145 | | - " \"num_generations=2\",\n", |
146 | | - " \"grpo_beta=0.08\",\n", |
147 | | - " \"grpo_epsilon=0.2\",\n", |
148 | | - " \"chips_per_vm=4\"\n", |
| 196 | + " f\"per_device_batch_size={PER_DEVICE_BATCH_SIZE}\",\n", |
| 197 | + " f\"learning_rate={LEARNING_RATE}\",\n", |
| 198 | + " f\"num_generations={NUM_GENERATIONS}\",\n", |
| 199 | + " f\"grpo_beta={GRPO_BETA}\",\n", |
| 200 | + " f\"grpo_epsilon={GRPO_EPSILON}\",\n", |
| 201 | + " f\"chips_per_vm={CHIPS_PER_VM}\",\n", |
| 202 | + " f\"loss_algo={LOSS_ALGO}\",\n", |
| 203 | + " \"use_pathways=False\"\n", |
149 | 204 | "]\n", |
150 | 205 | "\n", |
151 | | - "# Create configuration object\n", |
152 | | - "config = pyconfig.Config()\n", |
153 | | - "config.parse_flags(config_argv)\n", |
| 206 | + "# Initialize configuration\n", |
| 207 | + "print(f\"🔧 Initializing configuration from: {config_file}\")\n", |
| 208 | + "config = pyconfig.initialize(config_argv)\n", |
| 209 | + "max_utils.print_system_information()\n", |
154 | 210 | "\n", |
155 | | - "print(\"✅ Configuration created successfully\")\n", |
| 211 | + "print(\"\\n✅ Configuration initialized successfully\")\n", |
156 | 212 | "print(f\"📊 Training steps: {config.steps}\")\n", |
157 | 213 | "print(f\"📁 Output directory: {config.base_output_directory}\")\n", |
158 | 214 | "print(f\"🤖 Model: {config.model_name}\")" |
|
164 | 220 | "metadata": {}, |
165 | 221 | "outputs": [], |
166 | 222 | "source": [ |
167 | | - "# Execute GRPO training directly\n", |
| 223 | + "# Execute GRPO/GSPO training\n", |
| 224 | + "print(\"\\n\" + \"=\"*80)\n", |
| 225 | + "print(\"🚀 Starting Training...\")\n", |
| 226 | + "print(\"=\"*80)\n", |
| 227 | + "print(1)\n", |
168 | 228 | "try:\n", |
169 | | - " # Call the rl_train function\n", |
170 | | - " grpo_trainer, rl_cluster = rl_train(config)\n", |
| 229 | + " # Call the rl_train function (it handles everything internally)\n", |
| 230 | + " rl_train(config)\n", |
171 | 231 | " \n", |
172 | 232 | " print(\"\\n\" + \"=\"*80)\n", |
173 | | - " print(\"✅ GRPO Training Completed Successfully!\")\n", |
| 233 | + " print(\"✅ Training Completed Successfully!\")\n", |
174 | 234 | " print(\"=\"*80)\n", |
175 | | - " print(f\"📁 Checkpoints and logs saved to: {config.base_output_directory}\")\n", |
176 | | - " print(f\"🎯 Final model ready for inference!\")\n", |
| 235 | + " print(f\"📁 Checkpoints saved to: {config.checkpoint_dir}\")\n", |
| 236 | + " print(f\"📊 TensorBoard logs: {config.tensorboard_dir}\")\n", |
| 237 | + " print(f\"🎯 Model ready for inference!\")\n", |
177 | 238 | " \n", |
178 | 239 | "except Exception as e:\n", |
179 | 240 | " print(\"\\n\" + \"=\"*80)\n", |
180 | | - " print(\"❌ GRPO Training Failed!\")\n", |
| 241 | + " print(\"❌Training Failed!\")\n", |
181 | 242 | " print(\"=\"*80)\n", |
182 | 243 | " print(f\"Error: {str(e)}\")\n", |
183 | | - " print(\"\\nPlease check the error message and try again.\")" |
| 244 | + " import traceback\n", |
| 245 | + " traceback.print_exc()\n", |
| 246 | + " print(\"\\n💡 Common issues:\")\n", |
| 247 | + " print(\" - Check that MODEL_CHECKPOINT_PATH points to a valid checkpoint\")\n", |
| 248 | + " print(\" - Ensure HF_TOKEN environment variable is set\")\n", |
| 249 | + " print(\" - Verify OUTPUT_DIRECTORY is writable\")\n", |
| 250 | + " print(\" - Check hardware requirements (TPU/GPU availability)\")" |
184 | 251 | ] |
185 | 252 | }, |
186 | 253 | { |
187 | 254 | "cell_type": "markdown", |
188 | 255 | "metadata": {}, |
189 | 256 | "source": [ |
190 | | - "### 📚 **Learn More**\n", |
191 | | - "- See `src/MaxText/examples/grpo_runner.py` for CLI usage\n", |
192 | | - "- Check `src/MaxText/configs/grpo.yml` for configuration options\n", |
193 | | - "- Read `src/MaxText/examples/README.md` for more examples" |
| 257 | + "## 📚 Learn More\n", |
| 258 | + "\n", |
| 259 | + "- **CLI Usage**: Run `python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml --model_name=llama3.1-8b ...`\n", |
| 260 | + "- **Configuration**: See `src/MaxText/configs/rl.yml` for all available options\n", |
| 261 | + "- **Documentation**: Check `src/MaxText/rl/train_rl.py` for the `rl_train` function implementation\n", |
| 262 | + "- **Examples**: See other examples in `src/MaxText/examples/`" |
194 | 263 | ] |
195 | 264 | } |
196 | 265 | ], |
|