Skip to content

Commit 3291b8c

Browse files
committed
Fixing nits
Signed-off-by: Vladimir Suvorov <[email protected]>
1 parent cd25f17 commit 3291b8c

File tree

2 files changed

+18
-16
lines changed

2 files changed

+18
-16
lines changed

src/MaxText/examples/rl_llama3_demo.ipynb

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -87,18 +87,24 @@
8787
]
8888
},
8989
{
90-
"cell_type": "markdown",
90+
"cell_type": "code",
91+
"execution_count": null,
9192
"metadata": {},
92-
"source": []
93+
"outputs": [],
94+
"source": [
95+
"# Choose the loss algorithm: either GSPO or GRPO\n",
96+
"LOSS_ALGO=\"grpo\" # or \"gspo-token\" if you want to use GSPO"
97+
]
9398
},
9499
{
95100
"cell_type": "code",
96101
"execution_count": null,
97102
"metadata": {},
98103
"outputs": [],
99104
"source": [
100-
"# Configuration for GRPO training\n",
101105
"import os\n",
106+
"import sys\n",
107+
"from pathlib import Path\n",
102108
"import MaxText\n",
103109
"from huggingface_hub import login\n",
104110
"import jax\n",
@@ -110,7 +116,6 @@
110116
"MODEL_NAME = \"llama3.1-8b\"\n",
111117
"HF_REPO_ID = \"meta-llama/Llama-3.1-8B-Instruct\"\n",
112118
"CHAT_TEMPLATE_PATH = f\"{MAXTEXT_REPO_ROOT}/examples/chat_templates/gsm8k_rl.json\"\n",
113-
"LOSS_ALGO=\"grpo\" # or \"gspo-token\" if you want to use GSPO\n",
114119
"\n",
115120
"# Required: Set these before running\n",
116121
"MODEL_CHECKPOINT_PATH = \"\" # Update this!\n",
@@ -151,11 +156,6 @@
151156
"metadata": {},
152157
"outputs": [],
153158
"source": [
154-
"# Import required modules\n",
155-
"import os\n",
156-
"import sys\n",
157-
"from pathlib import Path\n",
158-
"\n",
159159
"# Add MaxText to Python path\n",
160160
"maxtext_path = Path(MAXTEXT_REPO_ROOT) \n",
161161
"sys.path.insert(0, str(maxtext_path))\n",

src/MaxText/examples/sft_llama3_demo.ipynb

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -136,14 +136,16 @@
136136
"source": [
137137
"## Set the model, checkpoint path and output directory\n",
138138
"MODEL_NAME = \"llama3.1-8b\"\n",
139-
"# Case 1: Set `MODEL_CHECKPOINT_PATH` to the path (local or gs://) that already has `Llama3.1-8B-Instruct` model checkpoint\n",
140-
"# Case 2: If you do not have the checkpoint, then do not update `MODEL_CHECKPOINT_PATH`\n",
141-
"# and this colab will download the checkpoint from HF and store at `\"{MAXTEXT_REPO_ROOT}/llama_checkpoint\"`\n",
142-
"MODEL_CHECKPOINT_PATH = f\"{MAXTEXT_REPO_ROOT}/llama_checkpoint\"\n",
139+
"# Set the path to the model checkpoint, or leave empty to download it from HuggingFace\n",
140+
"MODEL_CHECKPOINT_PATH = \"\"\n",
141+
"if not MODEL_CHECKPOINT_PATH:\n",
142+
" MODEL_CHECKPOINT_PATH = f\"{MAXTEXT_REPO_ROOT}/llama_checkpoint\"\n",
143+
" print(\"Model checkpoint will be downloaded from HuggingFace at: \", MODEL_CHECKPOINT_PATH)\n",
144+
" print(\"Set MODEL_CHECKPOINT_PATH if you do not wish to download the checkpoint.\")\n",
143145
"\n",
144-
"# This is the directory where the fine-tuned model will be saved\n",
145-
"# You can change it to any path you want (local or gs://)\n",
146-
"BASE_OUTPUT_DIRECTORY = \"/tmp/out/maxtext_llama3_8b\"\n"
146+
"BASE_OUTPUT_DIRECTORY = \"\"\n",
147+
"if not BASE_OUTPUT_DIRECTORY:\n",
148+
"    print(\"Please set BASE_OUTPUT_DIRECTORY to a path (local or gs://) where outputs and logs will be stored.\")\n"
147149
]
148150
},
149151
{

0 commit comments

Comments
 (0)