|
1 | 1 | """ |
2 | | -Airflow DAG for automating Llama3.1-8B RL training from Jupyter notebook. |
| 2 | +Airflow DAG for automating Llama3.1-8B RL training from Jupyter notebooks. |
3 | 3 |
|
4 | 4 | This DAG automates the rl_llama3_demo.ipynb notebook, executing GRPO/GSPO |
5 | 5 | training on single-host TPU VMs. |
|
11 | 11 |
|
12 | 12 | from dags import composer_env |
13 | 13 | from dags.common import test_owner |
14 | | -from dags.common.vm_resource import ( |
15 | | - Project, |
16 | | - RuntimeVersion, |
17 | | - TpuVersion, |
18 | | - V6E_GCE_NETWORK, |
19 | | - V6E_GCE_SUBNETWORK, |
20 | | - Zone, |
21 | | -) |
22 | 14 | from dags.post_training.util import notebook_util, test_config_util |
23 | | -from xlml.apis import gcp_config, metric_config, task, test_config |
| 15 | + |
24 | 16 |
|
25 | 17 | SCHEDULE = "0 21 * * *" if composer_env.is_prod_env() else None |
26 | 18 | DAG_TEST_NAME = "maxtext_rl_notebook" |
27 | | -DEFAULT_BUCKET = "gs://rl-automation" |
28 | 19 |
|
29 | 20 | with models.DAG( |
30 | 21 | dag_id=DAG_TEST_NAME, |
|
42 | 33 | "v6e-8", |
43 | 34 | "nightly", |
44 | 35 | ], |
45 | | - description="Automated Llama3.1-8B RL training from Jupyter notebook.", |
| 36 | + description="Automated Llama3.1-8B RL from Jupyter notebooks.", |
46 | 37 | doc_md=""" |
47 | 38 | # Llama3.1-8B RL Training (Notebook Automation) |
48 | 39 |
|
|
55 | 46 | ### Prerequisites |
56 | 47 | - MaxText checkpoint for Llama3.1-8B-Instruct model |
57 | 48 | - HuggingFace access token with read permissions |
58 | | - - Single-host TPU VM (v6e-8 or v5p-8) |
| 49 | + - Single-host TPU VM (v6e-8) |
59 | 50 |
|
60 | 51 | ### Execution Flow |
61 | 52 | 1. **TPU Creation:** Create TPU VM with required specifications |
62 | 53 | 2. **Environment Setup:** Clone MaxText, install dependencies |
63 | | - 3. **RL Training:** Execute GRPO/GSPO training with reward model |
| 54 | + 3. **RL Training:** Execute RL (GRPO/GSPO) training with reward model |
64 | 55 | 4. **Log Validation:** Verify training completion signals |
65 | 56 | 5. **Cleanup:** Delete TPU resources |
66 | 57 |
|
|
73 | 64 | """, |
74 | 65 | concurrency=1, |
75 | 66 | ) as dag: |
76 | | - # Test configuration |
77 | | - notebook_config = test_config_util.RLTestConfig( |
78 | | - cluster=None, # Not used for TPU VM tests |
79 | | - accelerator="v6e-8", |
80 | | - slices=[1], |
81 | | - model_name="llama3.1-8b", |
82 | | - base_dir=f"{DEFAULT_BUCKET}/llama3.1-8b-Instruct/outputs", |
83 | | - tokenizer_path="meta-llama/Llama-3.1-8B-Instruct", |
84 | | - load_parameters_path=( |
85 | | - f"{DEFAULT_BUCKET}/llama3.1-8b-Instruct/scanned-pathways/0/items" |
86 | | - ), |
87 | | - loss_algos=[ |
88 | | - test_config_util.LossAlgo.GRPO, |
89 | | - test_config_util.LossAlgo.GSPO, |
90 | | - ], |
91 | | - ) |
92 | | - |
93 | 67 | # HF token read from the Airflow Variable "HF_TOKEN_CIENET"; None if the Variable is not set
94 | 68 | HF_TOKEN_LLAMA31 = models.Variable.get("HF_TOKEN_CIENET", None) |
95 | 69 |
|
| 70 | + loss_algos = [ |
| 71 | + test_config_util.LossAlgo.GRPO, |
| 72 | + test_config_util.LossAlgo.GSPO, |
| 73 | + ] |
96 | 74 | # Test configuration |
97 | 75 | test_run_name = "llama31_rl_notebook" |
98 | 76 | current_datetime = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") |
99 | 77 |
|
100 | 78 | # Setup commands for MaxText environment |
101 | 79 | setup_script = notebook_util.build_maxtext_setup_script() |
102 | 80 |
|
103 | | - # Path to the RL demo notebook |
104 | | - notebook_path = "src/MaxText/examples/rl_llama3_demo.ipynb" |
105 | | - |
106 | 81 | # Build and run one notebook test per loss algorithm (GRPO and GSPO)
107 | | - for loss_algo in notebook_config.loss_algos: |
108 | | - run_name = f"{loss_algo.value}-{current_datetime}" |
109 | | - |
110 | | - # Parameters to inject into notebook |
111 | | - notebook_params = { |
112 | | - "MODEL_CHECKPOINT_PATH": notebook_config.load_parameters_path, |
113 | | - "OUTPUT_DIRECTORY": notebook_config.base_dir, |
114 | | - "LOSS_ALGO": loss_algo.loss_name, |
115 | | - } |
116 | | - |
117 | | - # Build notebook execution command |
118 | | - notebook_execution = notebook_util.build_notebook_execution_command( |
119 | | - notebook_path=notebook_path, |
120 | | - parameters=notebook_params, |
121 | | - maxtext_path="maxtext", |
122 | | - venv_path="maxtext_venv", |
123 | | - env_params={"HF_TOKEN": HF_TOKEN_LLAMA31}, |
124 | | - ) |
125 | | - |
126 | | - # Create TPU VM test configuration |
127 | | - rl_notebook_test = test_config.TpuVmTest( |
128 | | - test_config.Tpu( |
129 | | - version=TpuVersion.TRILLIUM, |
130 | | - cores=8, |
131 | | - runtime_version=RuntimeVersion.V2_ALPHA_TPUV6.value, |
132 | | - reserved=False, |
133 | | - network=V6E_GCE_NETWORK, |
134 | | - subnetwork=V6E_GCE_SUBNETWORK, |
135 | | - ), |
136 | | - test_name=f"{DAG_TEST_NAME}_{loss_algo.value}", |
137 | | - set_up_cmds=[setup_script], |
138 | | - run_model_cmds=[notebook_execution], |
139 | | - timeout=datetime.timedelta(minutes=180), |
140 | | - task_owner=test_owner.JACKY_F, |
141 | | - num_slices=1, |
142 | | - gcs_subfolder=f"{DEFAULT_BUCKET}/{DAG_TEST_NAME}", |
| 82 | + for loss_algo in loss_algos: |
| 83 | + rl_notebook_test = notebook_util.initialize_notebook_test( |
| 84 | + test_name=f"{DAG_TEST_NAME}_rl_{loss_algo.value}", |
| 85 | + dag_name=DAG_TEST_NAME, |
| 86 | + notebook_path="src/MaxText/examples/rl_llama3_demo.ipynb", |
| 87 | + set_up_script=setup_script, |
| 88 | + parameters={"LOSS_ALGO": loss_algo.loss_name}, |
| 89 | + task_owner=test_owner.DEPP_L, |
143 | 90 | ) |
144 | 91 |
|
145 | | - # Run the training task |
146 | | - training_task = task.run_queued_resource_test( |
147 | | - task_test_config=rl_notebook_test, |
148 | | - task_gcp_config=gcp_config.GCPConfig( |
149 | | - project_name=Project.CLOUD_ML_AUTO_SOLUTIONS.value, |
150 | | - zone=Zone.EUROPE_WEST4_A.value, |
151 | | - dataset_name=metric_config.DatasetOption.XLML_DATASET, |
152 | | - ), |
153 | | - skip_post_process=True, |
154 | | - ) |
| 92 | + notebook_util.run_training(rl_notebook_test, HF_TOKEN_LLAMA31) |
0 commit comments