Skip to content

Commit 2761922

Browse files
committed
Update descriptions.
1 parent d20c245 commit 2761922

File tree

4 files changed

+46
-29
lines changed

4 files changed

+46
-29
lines changed

examples/dpo_demo_gemma3.ipynb

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,8 @@
745745
},
746746
"outputs": [],
747747
"source": [
748+
"# The first couple of training step might take up to 5 minutes to finish. Please be patient. If you experience long training steps, e.g. \u003e10 minutes per step, please open a bug. Really appreciated!\n",
749+
"\n",
748750
"if mesh is None:\n",
749751
" dpo_trainer.train(train_dataset)\n",
750752
"else:\n",

examples/grpo_demo.ipynb

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -242,9 +242,9 @@
242242
"## Data preprocessing\n",
243243
"\n",
244244
"First, let's define some special tokens. We instruct the model to first reason\n",
245-
"between the `<reasoning>` and `</reasoning>` tokens. After\n",
246-
"reasoning, we expect it to provide the answer between the `<answer>` and\n",
247-
"`</answer>` tokens."
245+
"between the `\u003creasoning\u003e` and `\u003c/reasoning\u003e` tokens. After\n",
246+
"reasoning, we expect it to provide the answer between the `\u003canswer\u003e` and\n",
247+
"`\u003c/answer\u003e` tokens."
248248
]
249249
},
250250
{
@@ -254,22 +254,22 @@
254254
"metadata": {},
255255
"outputs": [],
256256
"source": [
257-
"reasoning_start = \"<reasoning>\"\n",
258-
"reasoning_end = \"</reasoning>\"\n",
259-
"solution_start = \"<answer>\"\n",
260-
"solution_end = \"</answer>\"\n",
257+
"reasoning_start = \"\u003creasoning\u003e\"\n",
258+
"reasoning_end = \"\u003c/reasoning\u003e\"\n",
259+
"solution_start = \"\u003canswer\u003e\"\n",
260+
"solution_end = \"\u003c/answer\u003e\"\n",
261261
"\n",
262262
"\n",
263263
"SYSTEM_PROMPT = f\"\"\"You are given a problem. Think about the problem and \\\n",
264264
"provide your reasoning. Place it between {reasoning_start} and \\\n",
265265
"{reasoning_end}. Then, provide the final answer (i.e., just one numerical \\\n",
266266
"value) between {solution_start} and {solution_end}.\"\"\"\n",
267267
"\n",
268-
"TEMPLATE = \"\"\"<start_of_turn>user\n",
268+
"TEMPLATE = \"\"\"\u003cstart_of_turn\u003euser\n",
269269
"{system_prompt}\n",
270270
"\n",
271-
"{question}<end_of_turn>\n",
272-
"<start_of_turn>model\"\"\""
271+
"{question}\u003cend_of_turn\u003e\n",
272+
"\u003cstart_of_turn\u003emodel\"\"\""
273273
]
274274
},
275275
{
@@ -287,7 +287,7 @@
287287
"metadata": {},
288288
"outputs": [],
289289
"source": [
290-
"def extract_hash_answer(text: str) -> str | None:\n",
290+
"def extract_hash_answer(text: str) -\u003e str | None:\n",
291291
" if \"####\" not in text:\n",
292292
" return None\n",
293293
" return text.split(\"####\")[1].strip()\n",
@@ -315,7 +315,7 @@
315315
" return target_dir\n",
316316
"\n",
317317
"\n",
318-
"def get_dataset(data_dir, split=\"train\", source=\"tfds\") -> grain.MapDataset:\n",
318+
"def get_dataset(data_dir, split=\"train\", source=\"tfds\") -\u003e grain.MapDataset:\n",
319319
" # Download data\n",
320320
" if not os.path.exists(data_dir):\n",
321321
" os.makedirs(data_dir)\n",
@@ -508,6 +508,7 @@
508508
"outputs": [],
509509
"source": [
510510
"!rm /tmp/content/intermediate_ckpt/* -rf\n",
511+
"\n",
511512
"!rm /tmp/content/ckpts/* -rf\n",
512513
"\n",
513514
"if model_family == \"gemma2\":\n",
@@ -651,7 +652,7 @@
651652
"- reward if the format of the output approximately matches the instruction given\n",
652653
"in `TEMPLATE`;\n",
653654
"- reward if the answer is correct/partially correct;\n",
654-
"- Sometimes, the text between `<answer>`, `</answer>` might not be one\n",
655+
"- Sometimes, the text between `\u003canswer\u003e`, `\u003c/answer\u003e` might not be one\n",
655656
" number. So, we extract the number, and reward the model if the answer is correct.\n",
656657
"\n",
657658
"The reward functions are inspired from\n",
@@ -779,9 +780,9 @@
779780
" # Ie if the answer is within some range, reward it!\n",
780781
" try:\n",
781782
" ratio = float(guess) / float(true_answer)\n",
782-
" if ratio >= 0.9 and ratio <= 1.1:\n",
783+
" if ratio \u003e= 0.9 and ratio \u003c= 1.1:\n",
783784
" score += 0.5\n",
784-
" elif ratio >= 0.8 and ratio <= 1.2:\n",
785+
" elif ratio \u003e= 0.8 and ratio \u003c= 1.2:\n",
785786
" score += 0.25\n",
786787
" else:\n",
787788
" score -= 1.0 # Penalize wrong answers\n",
@@ -796,7 +797,7 @@
796797
"id": "nIpOVv78Tn1k",
797798
"metadata": {},
798799
"source": [
799-
"Sometimes, the text between `<answer>` and `</answer>` might not be one\n",
800+
"Sometimes, the text between `\u003canswer\u003e` and `\u003c/answer\u003e` might not be one\n",
800801
"number; it can be a sentence. So, we extract the number and compare the answer."
801802
]
802803
},
@@ -873,7 +874,7 @@
873874
"ratio lies between 0.9 and 1.1. \n",
874875
"* **Format Accuracy**: percentage of samples for which the model outputs the\n",
875876
"correct format, i.e., reasoning between the reasoning special tokens, and the\n",
876-
"final answer between the \\`\\<start\\_answer\\>\\`, \\`\\<end\\_answer\\>\\` tokens.\n",
877+
"final answer between the \\`\\\u003cstart\\_answer\\\u003e\\`, \\`\\\u003cend\\_answer\\\u003e\\` tokens.\n",
877878
"\n",
878879
"**Qualitative**\n",
879880
"\n",
@@ -995,7 +996,7 @@
995996
" corr_ctr_per_question += 1\n",
996997
"\n",
997998
" ratio = float(extracted_response.strip()) / float(answer.strip())\n",
998-
" if ratio >= 0.9 and ratio <= 1.1:\n",
999+
" if ratio \u003e= 0.9 and ratio \u003c= 1.1:\n",
9991000
" partially_corr_per_question += 1\n",
10001001
" except:\n",
10011002
" print(\"SKIPPED\")\n",
@@ -1005,28 +1006,28 @@
10051006
" corr_format_per_question += 1\n",
10061007
"\n",
10071008
" if (\n",
1008-
" corr_ctr_per_question > 0\n",
1009-
" and partially_corr_per_question > 0\n",
1010-
" and corr_format_per_question > 0\n",
1009+
" corr_ctr_per_question \u003e 0\n",
1010+
" and partially_corr_per_question \u003e 0\n",
1011+
" and corr_format_per_question \u003e 0\n",
10111012
" ):\n",
10121013
" break\n",
10131014
"\n",
1014-
" if corr_ctr_per_question > 0:\n",
1015+
" if corr_ctr_per_question \u003e 0:\n",
10151016
" corr += 1\n",
10161017
" if corr_lst and make_lst:\n",
10171018
" response_lst.append((question, answer, multiple_call_response))\n",
10181019
" else:\n",
10191020
" if not corr_lst and make_lst:\n",
10201021
" response_lst.append((question, answer, multiple_call_response))\n",
1021-
" if partially_corr_per_question > 0:\n",
1022+
" if partially_corr_per_question \u003e 0:\n",
10221023
" partially_corr += 1\n",
1023-
" if corr_format_per_question > 0:\n",
1024+
" if corr_format_per_question \u003e 0:\n",
10241025
" corr_format += 1\n",
10251026
"\n",
10261027
" total += 1\n",
10271028
" if total % 10 == 0:\n",
10281029
" print(\n",
1029-
" f\"===> {corr=}, {total=}, {corr / total * 100=}, \"\n",
1030+
" f\"===\u003e {corr=}, {total=}, {corr / total * 100=}, \"\n",
10301031
" f\"{partially_corr / total * 100=}, {corr_format / total * 100=}\"\n",
10311032
" )\n",
10321033
"\n",
@@ -1066,7 +1067,7 @@
10661067
"id": "UOAQe06DyVlQ",
10671068
"metadata": {},
10681069
"source": [
1069-
"Now let's see how the original model does on the test set. You can see the percentages of the mode outputs that are fully correct, partially correct and just correct in format. "
1070+
"Now let's see how the original model does on the test set. You can see the percentages of the mode outputs that are fully correct, partially correct and just correct in format. The following step might take couple of minutes to finish."
10701071
]
10711072
},
10721073
{
@@ -1076,6 +1077,8 @@
10761077
"metadata": {},
10771078
"outputs": [],
10781079
"source": [
1080+
"# The evaluation might take up to couple of minutes to finish. Please be patient.\n",
1081+
"\n",
10791082
"(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(\n",
10801083
" test_dataset,\n",
10811084
" sampler,\n",
@@ -1211,11 +1214,11 @@
12111214
"\n",
12121215
"We then create a `GRPOLearner`, the specialized trainer that uses a list of **reward functions** to evaluate and optimize the model's output, completing the RL training setup.\n",
12131216
"\n",
1214-
"Tunix trainers are integrated with [Weights & Biases](https://wandb.ai/) to help you visualize the training progress. You can choose how you want to use it:\n",
1217+
"Tunix trainers are integrated with [Weights \u0026 Biases](https://wandb.ai/) to help you visualize the training progress. You can choose how you want to use it:\n",
12151218
"\n",
12161219
"**Option 1 (Type 1)**: If you're running a quick experiment or just testing things out, choose this. It creates a temporary, private dashboard right in your browser without requiring you to log in or create an account.\n",
12171220
"\n",
1218-
"**Option 2 (Type 2)**: If you have an existing W&B account and want to save your project's history to your personal dashboard, choose this. You'll be prompted to enter your API key or log in."
1221+
"**Option 2 (Type 2)**: If you have an existing W\u0026B account and want to save your project's history to your personal dashboard, choose this. You'll be prompted to enter your API key or log in."
12191222
]
12201223
},
12211224
{
@@ -1246,6 +1249,14 @@
12461249
")"
12471250
]
12481251
},
1252+
{
1253+
"cell_type": "markdown",
1254+
"id": "e8b71ed5",
1255+
"metadata": {},
1256+
"source": [
1257+
"The first couple of training step might take up to 5 minutes to finish. Please be patient. If you experience long training steps, e.g. \u003e10 minutes per step, please open a bug. Really appreciated!"
1258+
]
1259+
},
12491260
{
12501261
"cell_type": "code",
12511262
"execution_count": null,
@@ -1323,6 +1334,7 @@
13231334
"metadata": {},
13241335
"outputs": [],
13251336
"source": [
1337+
"# The evaluation might take up to couple of minutes to finish. Please be patient.\n",
13261338
"(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(\n",
13271339
" test_dataset,\n",
13281340
" sampler,\n",

examples/logit_distillation.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@
392392
" training_config=config,\n",
393393
").with_gen_model_input_fn(gen_model_input_fn)\n",
394394
"\n",
395-
"# 5. Run training within the mesh context\n",
395+
"# 5. Run training within the mesh context, the first couple of training step might take up to 5 minutes to finish. Please be patient. If you experience long training steps, e.g. \u003e10 minutes per, please open a bug. Really appreciated!\n",
396396
"print(\"Starting distillation training...\")\n",
397397
"with mesh:\n",
398398
" trainer.train(train_ds, validation_ds)\n",

examples/qlora_demo.ipynb

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,7 @@
607607
")\n",
608608
"trainer = trainer.with_gen_model_input_fn(gen_model_input_fn)\n",
609609
"\n",
610+
"# The first couple of training step might take up to 5 minutes to finish. Please be patient. If you experience long training steps, e.g. \u003e10 minutes per step, please open a bug. Really appreciated!\n",
610611
"with jax.profiler.trace(os.path.join(PROFILING_DIR, \"full_training\")):\n",
611612
" with mesh:\n",
612613
" trainer.train(train_ds, validation_ds)"
@@ -641,6 +642,7 @@
641642
" lora_model, optax.adamw(1e-3), training_config\n",
642643
").with_gen_model_input_fn(gen_model_input_fn)\n",
643644
"\n",
645+
"# The first couple of training step might take up to 5 minutes to finish. Please be patient. If you experience long training steps, e.g. \u003e10 minutes per step, please open a bug. Really appreciated!\n",
644646
"with jax.profiler.trace(os.path.join(PROFILING_DIR, \"peft with LoRA\")):\n",
645647
" with mesh:\n",
646648
" lora_trainer.train(train_ds, validation_ds)"
@@ -663,6 +665,7 @@
663665
" qlora_model, optax.adamw(1e-3), training_config\n",
664666
").with_gen_model_input_fn(gen_model_input_fn)\n",
665667
"\n",
668+
"# The first couple of training step might take up to 5 minutes to finish. Please be patient. If you experience long training steps, e.g. \u003e10 minutes per step, please open a bug. Really appreciated!\n",
666669
"with jax.profiler.trace(os.path.join(PROFILING_DIR, \"peft with QLoRA\")):\n",
667670
" with mesh:\n",
668671
" qlora_trainer.train(train_ds, validation_ds)"

0 commit comments

Comments
 (0)