Skip to content

Commit 6036dea

Browse files
authored
Add a new DAG for SFT post-training from Jupyter notebooks (#1164)
Add an Airflow DAG for automating Llama3.1-8B SFT training from Jupyter notebook. This DAG automates the sft_llama3_demo.ipynb notebook, executing SFT training on single-host TPU VMs.
1 parent 8bb2f2e commit 6036dea

File tree

4 files changed

+174
-81
lines changed

4 files changed

+174
-81
lines changed
Lines changed: 18 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Airflow DAG for automating Llama3.1-8B RL training from Jupyter notebook.
2+
Airflow DAG for automating Llama3.1-8B RL training from Jupyter notebooks.
33
44
This DAG automates the rl_llama3_demo.ipynb notebook, executing GRPO/GSPO
55
training on single-host TPU VMs.
@@ -11,20 +11,11 @@
1111

1212
from dags import composer_env
1313
from dags.common import test_owner
14-
from dags.common.vm_resource import (
15-
Project,
16-
RuntimeVersion,
17-
TpuVersion,
18-
V6E_GCE_NETWORK,
19-
V6E_GCE_SUBNETWORK,
20-
Zone,
21-
)
2214
from dags.post_training.util import notebook_util, test_config_util
23-
from xlml.apis import gcp_config, metric_config, task, test_config
15+
2416

2517
SCHEDULE = "0 21 * * *" if composer_env.is_prod_env() else None
2618
DAG_TEST_NAME = "maxtext_rl_notebook"
27-
DEFAULT_BUCKET = "gs://rl-automation"
2819

2920
with models.DAG(
3021
dag_id=DAG_TEST_NAME,
@@ -42,7 +33,7 @@
4233
"v6e-8",
4334
"nightly",
4435
],
45-
description="Automated Llama3.1-8B RL training from Jupyter notebook.",
36+
description="Automated Llama3.1-8B RL from Jupyter notebooks.",
4637
doc_md="""
4738
# Llama3.1-8B RL Training (Notebook Automation)
4839
@@ -55,12 +46,12 @@
5546
### Prerequisites
5647
- MaxText checkpoint for Llama3.1-8B-Instruct model
5748
- HuggingFace access token with read permissions
58-
- Single-host TPU VM (v6e-8 or v5p-8)
49+
- Single-host TPU VM (v6e-8)
5950
6051
### Execution Flow
6152
1. **TPU Creation:** Create TPU VM with required specifications
6253
2. **Environment Setup:** Clone MaxText, install dependencies
63-
3. **RL Training:** Execute GRPO/GSPO training with reward model
54+
3. **RL Training:** Execute RL (GRPO/GSPO) training with reward model
6455
4. **Log Validation:** Verify training completion signals
6556
5. **Cleanup:** Delete TPU resources
6657
@@ -73,82 +64,29 @@
7364
""",
7465
concurrency=1,
7566
) as dag:
76-
# Test configuration
77-
notebook_config = test_config_util.RLTestConfig(
78-
cluster=None, # Not used for TPU VM tests
79-
accelerator="v6e-8",
80-
slices=[1],
81-
model_name="llama3.1-8b",
82-
base_dir=f"{DEFAULT_BUCKET}/llama3.1-8b-Instruct/outputs",
83-
tokenizer_path="meta-llama/Llama-3.1-8B-Instruct",
84-
load_parameters_path=(
85-
f"{DEFAULT_BUCKET}/llama3.1-8b-Instruct/scanned-pathways/0/items"
86-
),
87-
loss_algos=[
88-
test_config_util.LossAlgo.GRPO,
89-
test_config_util.LossAlgo.GSPO,
90-
],
91-
)
92-
9367
# HF token retrieved from Airflow Variables
9468
HF_TOKEN_LLAMA31 = models.Variable.get("HF_TOKEN_CIENET", None)
9569

70+
loss_algos = [
71+
test_config_util.LossAlgo.GRPO,
72+
test_config_util.LossAlgo.GSPO,
73+
]
9674
# Test configuration
9775
test_run_name = "llama31_rl_notebook"
9876
current_datetime = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
9977

10078
# Setup commands for MaxText environment
10179
setup_script = notebook_util.build_maxtext_setup_script()
10280

103-
# Path to the RL demo notebook
104-
notebook_path = "src/MaxText/examples/rl_llama3_demo.ipynb"
105-
10681
# Test both GRPO and GSPO algorithms
107-
for loss_algo in notebook_config.loss_algos:
108-
run_name = f"{loss_algo.value}-{current_datetime}"
109-
110-
# Parameters to inject into notebook
111-
notebook_params = {
112-
"MODEL_CHECKPOINT_PATH": notebook_config.load_parameters_path,
113-
"OUTPUT_DIRECTORY": notebook_config.base_dir,
114-
"LOSS_ALGO": loss_algo.loss_name,
115-
}
116-
117-
# Build notebook execution command
118-
notebook_execution = notebook_util.build_notebook_execution_command(
119-
notebook_path=notebook_path,
120-
parameters=notebook_params,
121-
maxtext_path="maxtext",
122-
venv_path="maxtext_venv",
123-
env_params={"HF_TOKEN": HF_TOKEN_LLAMA31},
124-
)
125-
126-
# Create TPU VM test configuration
127-
rl_notebook_test = test_config.TpuVmTest(
128-
test_config.Tpu(
129-
version=TpuVersion.TRILLIUM,
130-
cores=8,
131-
runtime_version=RuntimeVersion.V2_ALPHA_TPUV6.value,
132-
reserved=False,
133-
network=V6E_GCE_NETWORK,
134-
subnetwork=V6E_GCE_SUBNETWORK,
135-
),
136-
test_name=f"{DAG_TEST_NAME}_{loss_algo.value}",
137-
set_up_cmds=[setup_script],
138-
run_model_cmds=[notebook_execution],
139-
timeout=datetime.timedelta(minutes=180),
140-
task_owner=test_owner.JACKY_F,
141-
num_slices=1,
142-
gcs_subfolder=f"{DEFAULT_BUCKET}/{DAG_TEST_NAME}",
82+
for loss_algo in loss_algos:
83+
rl_notebook_test = notebook_util.initialize_notebook_test(
84+
test_name=f"{DAG_TEST_NAME}_rl_{loss_algo.value}",
85+
dag_name=DAG_TEST_NAME,
86+
notebook_path="src/MaxText/examples/rl_llama3_demo.ipynb",
87+
set_up_script=setup_script,
88+
parameters={"LOSS_ALGO": loss_algo.loss_name},
89+
task_owner=test_owner.DEPP_L,
14390
)
14491

145-
# Run the training task
146-
training_task = task.run_queued_resource_test(
147-
task_test_config=rl_notebook_test,
148-
task_gcp_config=gcp_config.GCPConfig(
149-
project_name=Project.CLOUD_ML_AUTO_SOLUTIONS.value,
150-
zone=Zone.EUROPE_WEST4_A.value,
151-
dataset_name=metric_config.DatasetOption.XLML_DATASET,
152-
),
153-
skip_post_process=True,
154-
)
92+
notebook_util.run_training(rl_notebook_test, HF_TOKEN_LLAMA31)
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""
2+
Airflow DAG for automating Llama3.1-8B SFT training from Jupyter notebook.
3+
4+
This DAG automates the sft_llama3_demo.ipynb notebook, executing SFT
5+
training on single-host TPU VMs.
6+
"""
7+
8+
import datetime
9+
10+
from airflow import models
11+
12+
from dags import composer_env
13+
from dags.common import test_owner
14+
from dags.post_training.util import notebook_util
15+
16+
17+
SCHEDULE = "0 23 * * *" if composer_env.is_prod_env() else None
18+
DAG_TEST_NAME = "maxtext_sft_notebook"
19+
20+
with models.DAG(
21+
dag_id=DAG_TEST_NAME,
22+
start_date=datetime.datetime(2026, 1, 9),
23+
schedule_interval=SCHEDULE,
24+
catchup=False,
25+
tags=[
26+
"maxtext",
27+
"post-training",
28+
"sft",
29+
"notebook",
30+
"TPU",
31+
"v6e-8",
32+
"nightly",
33+
],
34+
description="Automated Llama3.1-8B SFT training from Jupyter notebook.",
35+
doc_md="""
36+
# Llama3.1-8B SFT Training (Notebook Automation)
37+
38+
### Overview
39+
This DAG automates the `sft_llama3_demo.ipynb` notebook, which
40+
demonstrates Supervised Fine-Tuning (SFT) on Llama3.1-8B-Instruct.
41+
It executes SFT training on single-host TPU VMs.
42+
43+
### Prerequisites
44+
- MaxText checkpoint for Llama3.1-8B-Instruct model
45+
- HuggingFace access token with read permissions
46+
- Single-host TPU VM (v6e-8)
47+
48+
### Execution Flow
49+
1. **TPU Creation:** Create TPU VM with required specifications
50+
2. **Environment Setup:** Clone MaxText, install dependencies
51+
3. **SFT Training:** Execute SFT training notebook
52+
4. **Log Validation:** Verify training completion signals
53+
5. **Cleanup:** Delete TPU resources
54+
55+
### Success Criteria
56+
The test passes when:
57+
1. TPU VM is created successfully
58+
2. Training completes without errors
59+
3. "SFT Training Completed Successfully" appears in logs
60+
4. Checkpoints are saved to output directory
61+
""",
62+
concurrency=1,
63+
) as dag:
64+
# HF token retrieved from Airflow Variables
65+
HF_TOKEN_LLAMA31 = models.Variable.get("HF_TOKEN_CIENET", None)
66+
67+
# Test configuration
68+
test_run_name = "llama31_sft_notebook"
69+
current_datetime = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
70+
71+
# Setup commands for MaxText environment
72+
setup_script = notebook_util.build_maxtext_setup_script()
73+
74+
# Test SFT training
75+
sft_notebook_test = notebook_util.initialize_notebook_test(
76+
test_name=f"{DAG_TEST_NAME}_sft",
77+
dag_name=DAG_TEST_NAME,
78+
notebook_path="src/MaxText/examples/sft_llama3_demo.ipynb",
79+
set_up_script=setup_script,
80+
parameters={},
81+
task_owner=test_owner.DEPP_L,
82+
)
83+
84+
notebook_util.run_training(sft_notebook_test, HF_TOKEN_LLAMA31)

dags/post_training/util/notebook_util.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,20 @@
11
"""Utility functions for automating Jupyter notebooks in Airflow."""
22

3+
import datetime
34
import inspect
45
import textwrap
6+
from airflow.models.taskmixin import DAGNode
7+
8+
from dags.common.vm_resource import (
9+
Project,
10+
RuntimeVersion,
11+
TpuVersion,
12+
V6E_GCE_NETWORK,
13+
V6E_GCE_SUBNETWORK,
14+
Zone,
15+
)
16+
from dags.post_training.util import test_config_util
17+
from xlml.apis import gcp_config, metric_config, task, test_config
518

619

720
def build_maxtext_setup_script() -> str:
@@ -66,6 +79,10 @@ def build_maxtext_setup_script() -> str:
6679
uv pip install -e .
6780
cd ..
6881
82+
uv pip install --no-deps qwix==0.1.4
83+
uv pip install --no-deps protobuf==5.29.5
84+
python3 -m pip freeze
85+
6986
# =======================================================================
7087
# Notebook Automation Tools
7188
# =======================================================================
@@ -253,3 +270,50 @@ def build_notebook_execution_command(
253270
notebook_run_script=notebook_run_script,
254271
verification_script=verification_script,
255272
)
273+
274+
275+
def initialize_notebook_test(
276+
test_name: str,
277+
dag_name: str,
278+
notebook_path: str,
279+
set_up_script: str,
280+
parameters: dict[str, any],
281+
task_owner: str,
282+
) -> test_config.TpuVmTest:
283+
"""Creates a TpuVmTest configuration for notebook execution."""
284+
notebook_execution = build_notebook_execution_command(
285+
notebook_path=notebook_path,
286+
parameters=parameters,
287+
maxtext_path="maxtext",
288+
venv_path="maxtext_venv",
289+
)
290+
return test_config.TpuVmTest(
291+
test_config.Tpu(
292+
version=TpuVersion.TRILLIUM,
293+
cores=8,
294+
runtime_version=RuntimeVersion.V2_ALPHA_TPUV6.value,
295+
reserved=False,
296+
network=V6E_GCE_NETWORK,
297+
subnetwork=V6E_GCE_SUBNETWORK,
298+
),
299+
test_name=test_name,
300+
set_up_cmds=[set_up_script],
301+
run_model_cmds=[notebook_execution],
302+
timeout=datetime.timedelta(minutes=180),
303+
task_owner=task_owner,
304+
num_slices=1,
305+
gcs_subfolder=f"{test_config_util.DEFAULT_BUCKET}/{dag_name}",
306+
)
307+
308+
309+
def run_training(config: test_config.TpuVmTest, hf_token: str) -> DAGNode:
310+
return task.run_queued_resource_test(
311+
task_test_config=config,
312+
task_gcp_config=gcp_config.GCPConfig(
313+
project_name=Project.CLOUD_ML_AUTO_SOLUTIONS.value,
314+
zone=Zone.EUROPE_WEST4_A.value,
315+
dataset_name=metric_config.DatasetOption.XLML_DATASET,
316+
),
317+
skip_post_process=True,
318+
custom_env={"HF_TOKEN": hf_token},
319+
)

xlml/apis/task.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def run_queued_resource_test(
6767
tpu_name_env_var: bool = False,
6868
all_workers: bool = True,
6969
skip_post_process: bool = False,
70+
custom_env: dict[str, str] = {},
7071
):
7172
"""This is a class to set up tasks for TPU provisioned by Queued Resource.
7273
@@ -86,6 +87,7 @@ def run_queued_resource_test(
8687
all_workers: The flag to define if run commands on all workers or worker 0
8788
only.
8889
skip_post_process: If True, the post processing step will be skipped.
90+
custom_env: Extra environment variables.
8991
9092
Returns:
9193
A task group with the following tasks chained: provision, run_model,
@@ -137,7 +139,12 @@ def run_queued_resource_test(
137139
task_test_config.test_script,
138140
ssh_keys,
139141
all_workers,
140-
env={metric_config.SshEnvVars.GCS_OUTPUT.name: output_location},
142+
# We purposely put `custom_env` last to allow overriding values.
143+
# For example, `GCS_OUTPUT` can be overridden if needed.
144+
env={
145+
metric_config.SshEnvVars.GCS_OUTPUT.name: output_location,
146+
**custom_env,
147+
},
141148
)
142149

143150
clean_up = tpu.delete_queued_resource.override(group_id="clean_up")(

0 commit comments

Comments
 (0)