Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 18 additions & 80 deletions dags/post_training/maxtext_rl_notebook.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Airflow DAG for automating Llama3.1-8B RL training from Jupyter notebook.
Airflow DAG for automating Llama3.1-8B RL training from Jupyter notebooks.

This DAG automates the rl_llama3_demo.ipynb notebook, executing GRPO/GSPO
training on single-host TPU VMs.
Expand All @@ -11,20 +11,11 @@

from dags import composer_env
from dags.common import test_owner
from dags.common.vm_resource import (
Project,
RuntimeVersion,
TpuVersion,
V6E_GCE_NETWORK,
V6E_GCE_SUBNETWORK,
Zone,
)
from dags.post_training.util import notebook_util, test_config_util
from xlml.apis import gcp_config, metric_config, task, test_config


SCHEDULE = "0 21 * * *" if composer_env.is_prod_env() else None
DAG_TEST_NAME = "maxtext_rl_notebook"
DEFAULT_BUCKET = "gs://rl-automation"

with models.DAG(
dag_id=DAG_TEST_NAME,
Expand All @@ -42,7 +33,7 @@
"v6e-8",
"nightly",
],
description="Automated Llama3.1-8B RL training from Jupyter notebook.",
description="Automated Llama3.1-8B RL from Jupyter notebooks.",
doc_md="""
# Llama3.1-8B RL Training (Notebook Automation)

Expand All @@ -55,12 +46,12 @@
### Prerequisites
- MaxText checkpoint for Llama3.1-8B-Instruct model
- HuggingFace access token with read permissions
- Single-host TPU VM (v6e-8 or v5p-8)
- Single-host TPU VM (v6e-8)

### Execution Flow
1. **TPU Creation:** Create TPU VM with required specifications
2. **Environment Setup:** Clone MaxText, install dependencies
3. **RL Training:** Execute GRPO/GSPO training with reward model
3. **RL Training:** Execute RL (GRPO/GSPO) training with reward model
4. **Log Validation:** Verify training completion signals
5. **Cleanup:** Delete TPU resources

Expand All @@ -73,82 +64,29 @@
""",
concurrency=1,
) as dag:
# Test configuration
notebook_config = test_config_util.RLTestConfig(
cluster=None, # Not used for TPU VM tests
accelerator="v6e-8",
slices=[1],
model_name="llama3.1-8b",
base_dir=f"{DEFAULT_BUCKET}/llama3.1-8b-Instruct/outputs",
tokenizer_path="meta-llama/Llama-3.1-8B-Instruct",
load_parameters_path=(
f"{DEFAULT_BUCKET}/llama3.1-8b-Instruct/scanned-pathways/0/items"
),
loss_algos=[
test_config_util.LossAlgo.GRPO,
test_config_util.LossAlgo.GSPO,
],
)

# HF token retrieved from Airflow Variables
HF_TOKEN_LLAMA31 = models.Variable.get("HF_TOKEN_CIENET", None)

loss_algos = [
test_config_util.LossAlgo.GRPO,
test_config_util.LossAlgo.GSPO,
]
# Test configuration
test_run_name = "llama31_rl_notebook"
current_datetime = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

# Setup commands for MaxText environment
setup_script = notebook_util.build_maxtext_setup_script()

# Path to the RL demo notebook
notebook_path = "src/MaxText/examples/rl_llama3_demo.ipynb"

# Test both GRPO and GSPO algorithms
for loss_algo in notebook_config.loss_algos:
run_name = f"{loss_algo.value}-{current_datetime}"

# Parameters to inject into notebook
notebook_params = {
"MODEL_CHECKPOINT_PATH": notebook_config.load_parameters_path,
"OUTPUT_DIRECTORY": notebook_config.base_dir,
"LOSS_ALGO": loss_algo.loss_name,
}

# Build notebook execution command
notebook_execution = notebook_util.build_notebook_execution_command(
notebook_path=notebook_path,
parameters=notebook_params,
maxtext_path="maxtext",
venv_path="maxtext_venv",
env_params={"HF_TOKEN": HF_TOKEN_LLAMA31},
)

# Create TPU VM test configuration
rl_notebook_test = test_config.TpuVmTest(
test_config.Tpu(
version=TpuVersion.TRILLIUM,
cores=8,
runtime_version=RuntimeVersion.V2_ALPHA_TPUV6.value,
reserved=False,
network=V6E_GCE_NETWORK,
subnetwork=V6E_GCE_SUBNETWORK,
),
test_name=f"{DAG_TEST_NAME}_{loss_algo.value}",
set_up_cmds=[setup_script],
run_model_cmds=[notebook_execution],
timeout=datetime.timedelta(minutes=180),
task_owner=test_owner.JACKY_F,
num_slices=1,
gcs_subfolder=f"{DEFAULT_BUCKET}/{DAG_TEST_NAME}",
for loss_algo in loss_algos:
rl_notebook_test = notebook_util.initialize_notebook_test(
test_name=f"{DAG_TEST_NAME}_rl_{loss_algo.value}",
dag_name=DAG_TEST_NAME,
notebook_path="src/MaxText/examples/rl_llama3_demo.ipynb",
set_up_script=setup_script,
parameters={"LOSS_ALGO": loss_algo.loss_name},
task_owner=test_owner.DEPP_L,
)

# Run the training task
training_task = task.run_queued_resource_test(
task_test_config=rl_notebook_test,
task_gcp_config=gcp_config.GCPConfig(
project_name=Project.CLOUD_ML_AUTO_SOLUTIONS.value,
zone=Zone.EUROPE_WEST4_A.value,
dataset_name=metric_config.DatasetOption.XLML_DATASET,
),
skip_post_process=True,
)
notebook_util.run_training(rl_notebook_test, HF_TOKEN_LLAMA31)
84 changes: 84 additions & 0 deletions dags/post_training/maxtext_sft_notebook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""
Airflow DAG for automating Llama3.1-8B SFT training from Jupyter notebook.

This DAG automates the sft_llama3_demo.ipynb notebook, executing SFT
training on single-host TPU VMs.
"""

import datetime

from airflow import models

from dags import composer_env
from dags.common import test_owner
from dags.post_training.util import notebook_util


SCHEDULE = "0 23 * * *" if composer_env.is_prod_env() else None
DAG_TEST_NAME = "maxtext_sft_notebook"

with models.DAG(
dag_id=DAG_TEST_NAME,
start_date=datetime.datetime(2026, 1, 9),
schedule_interval=SCHEDULE,
catchup=False,
tags=[
"maxtext",
"post-training",
"sft",
"notebook",
"TPU",
"v6e-8",
"nightly",
],
description="Automated Llama3.1-8B SFT training from Jupyter notebook.",
doc_md="""
# Llama3.1-8B SFT Training (Notebook Automation)

### Overview
This DAG automates the `sft_llama3_demo.ipynb` notebook, which
demonstrates Supervised Fine-Tuning (SFT) on Llama3.1-8B-Instruct.
It executes SFT training on single-host TPU VMs.

### Prerequisites
- MaxText checkpoint for Llama3.1-8B-Instruct model
- HuggingFace access token with read permissions
- Single-host TPU VM (v6e-8)

### Execution Flow
1. **TPU Creation:** Create TPU VM with required specifications
2. **Environment Setup:** Clone MaxText, install dependencies
3. **SFT Training:** Execute SFT training notebook
4. **Log Validation:** Verify training completion signals
5. **Cleanup:** Delete TPU resources

### Success Criteria
The test passes when:
1. TPU VM is created successfully
2. Training completes without errors
3. "SFT Training Completed Successfully" appears in logs
4. Checkpoints are saved to output directory
""",
concurrency=1,
) as dag:
# HF token retrieved from Airflow Variables
HF_TOKEN_LLAMA31 = models.Variable.get("HF_TOKEN_CIENET", None)

# Test configuration
test_run_name = "llama31_rl_notebook"
current_datetime = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

# Setup commands for MaxText environment
setup_script = notebook_util.build_maxtext_setup_script()

# Test SFT training
sft_notebook_test = notebook_util.initialize_notebook_test(
test_name=f"{DAG_TEST_NAME}_sft",
dag_name=DAG_TEST_NAME,
notebook_path="src/MaxText/examples/sft_llama3_demo.ipynb",
set_up_script=setup_script,
parameters={},
task_owner=test_owner.DEPP_L,
)

notebook_util.run_training(sft_notebook_test, HF_TOKEN_LLAMA31)
64 changes: 64 additions & 0 deletions dags/post_training/util/notebook_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
"""Utility functions for automating Jupyter notebooks in Airflow."""

import datetime
import inspect
import textwrap
from typing import Any

from airflow.models.taskmixin import DAGNode

from dags.common.vm_resource import (
    Project,
    RuntimeVersion,
    TpuVersion,
    V6E_GCE_NETWORK,
    V6E_GCE_SUBNETWORK,
    Zone,
)
from dags.post_training.util import test_config_util
from xlml.apis import gcp_config, metric_config, task, test_config


def build_maxtext_setup_script() -> str:
Expand Down Expand Up @@ -66,6 +79,10 @@ def build_maxtext_setup_script() -> str:
uv pip install -e .
cd ..

uv pip install --no-deps qwix==0.1.4
uv pip install --no-deps protobuf==5.29.5
python3 -m pip freeze

# =======================================================================
# Notebook Automation Tools
# =======================================================================
Expand Down Expand Up @@ -253,3 +270,50 @@ def build_notebook_execution_command(
notebook_run_script=notebook_run_script,
verification_script=verification_script,
)


def initialize_notebook_test(
    test_name: str,
    dag_name: str,
    notebook_path: str,
    set_up_script: str,
    parameters: dict[str, Any],
    task_owner: str,
) -> test_config.TpuVmTest:
  """Creates a single-host TPU VM test that executes a Jupyter notebook.

  Args:
    test_name: Unique test name, used to identify the run.
    dag_name: DAG name; used as the GCS subfolder under the default bucket.
    notebook_path: Path to the notebook within the cloned MaxText repo.
    set_up_script: Shell script that prepares the MaxText environment.
    parameters: Parameters injected into the notebook before execution.
    task_owner: Owner recorded for the test.

  Returns:
    A `test_config.TpuVmTest` describing a v6e-8 (Trillium) TPU VM run
    with a 180-minute timeout.
  """
  # Build the shell command that runs the notebook inside the venv created
  # by `set_up_script`.
  notebook_execution = build_notebook_execution_command(
      notebook_path=notebook_path,
      parameters=parameters,
      maxtext_path="maxtext",
      venv_path="maxtext_venv",
  )
  return test_config.TpuVmTest(
      test_config.Tpu(
          version=TpuVersion.TRILLIUM,
          cores=8,
          runtime_version=RuntimeVersion.V2_ALPHA_TPUV6.value,
          reserved=False,
          network=V6E_GCE_NETWORK,
          subnetwork=V6E_GCE_SUBNETWORK,
      ),
      test_name=test_name,
      set_up_cmds=[set_up_script],
      run_model_cmds=[notebook_execution],
      timeout=datetime.timedelta(minutes=180),
      task_owner=task_owner,
      num_slices=1,
      gcs_subfolder=f"{test_config_util.DEFAULT_BUCKET}/{dag_name}",
  )


def run_training(config: test_config.TpuVmTest, hf_token: str) -> DAGNode:
  """Schedules the given notebook test on a queued-resource TPU.

  Args:
    config: TPU VM test configuration to execute.
    hf_token: HuggingFace access token, exported as `HF_TOKEN` on the VM
      so the notebook can download gated models.

  Returns:
    The Airflow task group node for the provision/run/cleanup chain.
  """
  project_config = gcp_config.GCPConfig(
      project_name=Project.CLOUD_ML_AUTO_SOLUTIONS.value,
      zone=Zone.EUROPE_WEST4_A.value,
      dataset_name=metric_config.DatasetOption.XLML_DATASET,
  )
  return task.run_queued_resource_test(
      task_test_config=config,
      task_gcp_config=project_config,
      skip_post_process=True,
      custom_env={"HF_TOKEN": hf_token},
  )
9 changes: 8 additions & 1 deletion xlml/apis/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def run_queued_resource_test(
tpu_name_env_var: bool = False,
all_workers: bool = True,
skip_post_process: bool = False,
custom_env: dict[str, str] = {},
):
"""This is a class to set up tasks for TPU provisioned by Queued Resource.

Expand All @@ -86,6 +87,7 @@ def run_queued_resource_test(
all_workers: The flag to define if run commands on all workers or worker 0
only.
skip_post_process: If True, the post processing step will be skipped.
custom_env: Extra environment variables.

Returns:
A task group with the following tasks chained: provision, run_model,
Expand Down Expand Up @@ -137,7 +139,12 @@ def run_queued_resource_test(
task_test_config.test_script,
ssh_keys,
all_workers,
env={metric_config.SshEnvVars.GCS_OUTPUT.name: output_location},
# We purposely put `custom_env` last to allow overriding values.
# For example, `GCS_OUTPUT` can be overridden if needed.
env={
metric_config.SshEnvVars.GCS_OUTPUT.name: output_location,
**custom_env,
},
)

clean_up = tpu.delete_queued_resource.override(group_id="clean_up")(
Expand Down
Loading