
Commit bc35383

[CI] add rl case (#1482)

* add rl cases
* run cases
* run cases
* ready to PR
* remove push

1 parent: 72bac53

File tree: 4 files changed (+178, −27 lines)

- autotest/cluster/clusterx.py
- autotest/config.yaml
- autotest/module/train.py
- autotest/utils/check_metric.py

autotest/cluster/clusterx.py (8 additions, 1 deletion)

```diff
@@ -58,11 +58,18 @@ def execute_task(self, task_config: Dict[str, Any]):
             raise RuntimeError(f"clusterx job {job_name} start fail, task config is {task_config}, exception is: {e}")
 
         start_time = time.time()
+        run_start_time = None
 
         while True:
             status = self.get_task_status(job_schema.job_id)
+            if status in [JobStatus.RUNNING] and run_start_time is None:
+                run_start_time = time.time()
             if status in [JobStatus.SUCCEEDED]:
-                return True, "Task succeeded"
+                run_time = time.time() - run_start_time
+                if run_time >= timeout:
+                    return False, f"Task succeeded, but run time is {run_time}, exceeding timeout {timeout}"
+                else:
+                    return True, "Task succeeded"
             elif status in [JobStatus.FAILED, JobStatus.STOPPED]:
                 if status in [JobStatus.FAILED]:
                     time.sleep(10)
```
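This change separates queue wait from execution time: the run-time clock starts only when the job is first observed in `RUNNING`, so time spent pending in the cluster queue no longer counts against `timeout`. Below is a minimal, self-contained sketch of that polling pattern; `fake_status`, the status strings, and the timings are invented stand-ins for ClusterX's `get_task_status` and `JobStatus`, and the sketch adds a guard for jobs that finish before `RUNNING` is ever observed (a corner case the committed code does not cover).

```python
import time

# Invented stand-in for get_task_status(): two polls of queueing,
# then RUNNING twice, then SUCCEEDED.
_statuses = iter(["QUEUED", "QUEUED", "RUNNING", "RUNNING", "SUCCEEDED"])

def fake_status():
    return next(_statuses)

def wait_for_job(timeout=60.0, poll_interval=0.01):
    run_start_time = None  # set the first time the job is seen RUNNING
    while True:
        status = fake_status()
        if status == "RUNNING" and run_start_time is None:
            run_start_time = time.time()  # queue wait is excluded from here on
        if status == "SUCCEEDED":
            # Guard for jobs that finish between polls without RUNNING being seen.
            started = run_start_time if run_start_time is not None else time.time()
            run_time = time.time() - started
            if run_time >= timeout:
                return False, f"Task succeeded, but run time {run_time:.2f}s exceeds timeout {timeout}s"
            return True, "Task succeeded"
        if status in ("FAILED", "STOPPED"):
            return False, f"Task ended with status {status}"
        time.sleep(poll_interval)

print(wait_for_job())  # (True, 'Task succeeded')
```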

autotest/config.yaml (43 additions, 0 deletions)

```diff
@@ -208,3 +208,46 @@ case:
           runtime_info/tgs: 0.05
           runtime_info/text_tokens: 0
       timeout: 10800
+
+  qwen3-rl-lmdeploy:
+    -
+      type: rl
+      parameters:
+        config: autotest/config/rl_qwen3_gsk8k_grpo.py
+        infer_backend: lmdeploy
+        output_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/test_output
+      resource:
+        envs:
+          - MODEL_PATH=/mnt/shared-storage-user/llmrazor-share/model/Qwen3-8B
+          - DATA_PATH=/mnt/shared-storage-user/llmrazor-share/data/gsm8k/train-mini.jsonl
+          - EVAL_DATA_PATH=/mnt/shared-storage-user/llmrazor-share/data/gsm8k/test.jsonl
+          - XTUNER_DETERMINISTIC=true
+      assert_info:
+        base_metric: qwen3-rl-lmdeploy/20260203/tracker.jsonl
+        check_metrics:
+          -
+            metric: eval/accuracy
+            threshold: 0.05
+            method: absolute
+            operator: <
+          -
+            metric: response/rewards/mean
+            threshold: 0.1
+            method: absolute
+            operator: <
+          -
+            metric: mismatch/mismatch_k3_kl
+            threshold: 0.0001
+            method: absolute
+            operator: <=
+          -
+            metric: response/response_len/mean
+            threshold: 0.12
+            method: relative
+            operator: <
+          -
+            metric: time/step
+            threshold: 10
+            method: absolute
+            operator: <
+      timeout: 2460
```
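Each `check_metrics` entry pins one tracked metric against the baseline run named in `base_metric`: `method` selects absolute or baseline-normalized error, and `operator` selects strict or non-strict comparison against `threshold`. Here is a small illustration of how one entry is evaluated; the `entry` values are taken from the case above, but the metric values are made up:

```python
# One check_metrics entry from the qwen3-rl-lmdeploy case above.
entry = {"metric": "response/response_len/mean", "threshold": 0.12,
         "method": "relative", "operator": "<"}

base_val, cur_val = 512.0, 550.0  # invented baseline vs. current step values

if entry["method"] == "absolute":
    error = abs(cur_val - base_val)
else:  # "relative": normalize the error by the baseline magnitude
    error = abs(cur_val - base_val) / abs(base_val)

# Operator "<" demands strict inequality; "<=" would also accept equality.
passed = error < entry["threshold"] if entry["operator"] == "<" else error <= entry["threshold"]
print(f"error={error:.4f}, passed={passed}")  # error=0.0742, passed=True
```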

autotest/module/train.py (43 additions, 25 deletions)

```diff
@@ -1,6 +1,6 @@
 import os
 
-from utils.check_metric import check_result
+from utils.check_metric import check_result, check_rl_result
 from utils.run_cmd import run_cmd
 
 
@@ -25,28 +25,37 @@ def get_cmd(config):
         ]
     )
 
-    command = (
-        f"cd {current_dir}; pwd; pip install -e .[all]; pip install more-itertools; export GITHUB_RUN_ID={config.get('run_id')}; "
-        + f"torchrun --nproc-per-node {nproc_per_node} --master_addr=${{MASTER_ADDR}} --master_port=${{MASTER_PORT}} --nnodes=${{WORLD_SIZE}} --node_rank=${{RANK}} "
-        + f"xtuner/v1/train/cli/{train_type}.py"
-    )
-    if config_path:
-        output_path = model_config = config.get("parameters", {}).get("output_path", ".")
-        if output_path == ".":
-            command += f" --config {config_path}; mkdir -p {work_dir}; mv {output_path}/.xtuner {work_dir}; mv {output_path}/202* {work_dir}"
+    if train_type == "sft":
+        command = (
+            f"cd {current_dir}; pwd; pip install -e .[all]; pip install more-itertools; export GITHUB_RUN_ID={config.get('run_id')}; "
+            + f"torchrun --nproc-per-node {nproc_per_node} --master_addr=${{MASTER_ADDR}} --master_port=${{MASTER_PORT}} --nnodes=${{WORLD_SIZE}} --node_rank=${{RANK}} "
+            + f"xtuner/v1/train/cli/{train_type}.py"
+        )
+        if config_path:
+            output_path = model_config = config.get("parameters", {}).get("output_path", ".")
+            if output_path == ".":
+                command += f" --config {config_path}; mkdir -p {work_dir}; mv {output_path}/.xtuner {work_dir}; mv {output_path}/202* {work_dir}"
+            else:
+                command += f" --config {config_path}"
         else:
-            command += f" --config {config_path}"
-    else:
-        if model_config:
-            command += f" --model-cfg {model_config}"
-        if chat_template:
-            command += f" --chat_template {chat_template}"
-        if dataset_path:
-            command += f" --dataset {dataset_path}"
-        command += f" --work_dir {work_dir}"
+            if model_config:
+                command += f" --model-cfg {model_config}"
+            if chat_template:
+                command += f" --chat_template {chat_template}"
+            if dataset_path:
+                command += f" --dataset {dataset_path}"
+            command += f" --work_dir {work_dir}"
 
-    config["work_dir"] = work_dir
-    return command, config
+        config["work_dir"] = work_dir
+        return command, config
+    elif train_type == "rl":
+        infer_type = config.get("parameters", {}).get("infer_backend", "lmdeploy")
+        config["work_dir"] = work_dir
+        command = (
+            f"cd {current_dir}; pwd; pip install -e .[all]; export GITHUB_RUN_ID={config.get('run_id')}; export WORK_DIR={work_dir}; "
+            + f"bash -x examples/v1/scripts/run_rl.sh {config_path} {infer_type} ${{MODEL_PATH}} ${{DATA_PATH}} ${{EVAL_DATA_PATH}}"
+        )
+        return command, config
     else:
         return "", config
```
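For `type: rl`, `get_cmd` no longer launches `torchrun` directly; it delegates to `examples/v1/scripts/run_rl.sh`, passing the case config, the inference backend, and the `MODEL_PATH`/`DATA_PATH`/`EVAL_DATA_PATH` environment variables supplied via the case's `resource.envs`. Roughly, for the `qwen3-rl-lmdeploy` case this assembles a command like the sketch below; `current_dir` and `work_dir` are invented placeholder paths:

```python
# Approximation of the rl branch of get_cmd() for the new case; paths invented.
config = {
    "type": "rl",
    "run_id": "12345",
    "parameters": {
        "config": "autotest/config/rl_qwen3_gsk8k_grpo.py",
        "infer_backend": "lmdeploy",
    },
}
current_dir, work_dir = "/workspace/xtuner", "/tmp/work_dir"  # placeholders
config_path = config["parameters"]["config"]
infer_type = config["parameters"].get("infer_backend", "lmdeploy")

command = (
    f"cd {current_dir}; pwd; pip install -e .[all]; export GITHUB_RUN_ID={config['run_id']}; export WORK_DIR={work_dir}; "
    + f"bash -x examples/v1/scripts/run_rl.sh {config_path} {infer_type} ${{MODEL_PATH}} ${{DATA_PATH}} ${{EVAL_DATA_PATH}}"
)
print(command)
# cd /workspace/xtuner; pwd; pip install -e .[all]; export GITHUB_RUN_ID=12345;
# export WORK_DIR=/tmp/work_dir; bash -x examples/v1/scripts/run_rl.sh
# autotest/config/rl_qwen3_gsk8k_grpo.py lmdeploy ${MODEL_PATH} ${DATA_PATH} ${EVAL_DATA_PATH}
```

The `validate()` and `get_latest_subdir()` changes in the same file follow.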
```diff
@@ -55,9 +64,18 @@ def validate(config):
     base_path = os.path.join(
         config.get("base_path").get("base_baseline_path"), config.get("assert_info", {}).get("base_metric", None)
     )
-    cur_path = os.path.join(get_latest_subdir(work_dir), "logs/exp_tracking/rank0/tracker.jsonl")
-    check_metrics = config.get("assert_info", {}).get("check_metrics", {})
-    return check_result(config["case_name"], base_path, cur_path, check_metrics)
+    train_type = config.get("type")
+    if train_type == 'sft':
+        cur_path = os.path.join(get_latest_subdir(work_dir), "logs/exp_tracking/rank0/tracker.jsonl")
+        check_metrics = config.get("assert_info", {}).get("check_metrics", {})
+        return check_result(config["case_name"], base_path, cur_path, check_metrics)
+    elif train_type == 'rl':
+        cur_path = os.path.join(get_latest_subdir(work_dir), "exp_tracking/tracker.jsonl")
+        check_metrics = config.get("assert_info", {})
+        return check_rl_result(config["case_name"], base_path, cur_path, check_metrics)
+    else:
+        print(f"Unknown type: {train_type}")
+        return False
 
 def pre_action(config=None):
     action_info = config.get("pre_action", None)
@@ -71,7 +89,7 @@ def post_action(config=None):
 
 
 def get_latest_subdir(work_dir):
-    dirs = [d for d in os.listdir(work_dir) if os.path.isdir(os.path.join(work_dir, d))]
+    dirs = [d for d in os.listdir(work_dir) if os.path.isdir(os.path.join(work_dir, d)) and len(d) == 14 and d.isdigit()]
 
     if not dirs:
         return None
```
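`get_latest_subdir` previously picked up any subdirectory, including non-run folders; the tightened filter keeps only 14-digit names, i.e. `YYYYMMDDHHMMSS` run timestamps, for which lexicographic order matches chronological order. A quick sketch with invented directory names (the latest-run selection via `max()` is an assumption about the part of the function the hunk does not show):

```python
# Invented work-dir contents: two timestamped runs plus unrelated folders.
names = ["20260203123456", "20260204091011", ".xtuner", "logs", "ckpt-100"]

# Keep only 14-digit, all-numeric names (YYYYMMDDHHMMSS run timestamps).
dirs = [d for d in names if len(d) == 14 and d.isdigit()]

# Zero-padded timestamps sort chronologically, so max() is the latest run.
print(max(dirs))  # 20260204091011
```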

autotest/utils/check_metric.py (84 additions, 1 deletion)

```diff
@@ -20,7 +20,8 @@ def extract_value(file, metrics):
         for line in f:
             line = json.loads(line)
             for metric in metrics:
-                metric_all[metric].append(line[metric])
+                if metric in line:
+                    metric_all[metric].append(line[metric])
             total_step += 1
 
     return total_step, metric_all
```
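The new guard matters for RL trackers because not every logged line carries every key; `eval/accuracy`, for instance, would plausibly appear only on eval steps, so the old `line[metric]` lookup would raise `KeyError`. A self-contained illustration with invented tracker lines:

```python
import json
from collections import defaultdict

# Invented tracker.jsonl lines: only the second one carries eval/accuracy.
lines = [
    '{"time/step": 8.1, "response/rewards/mean": 0.32}',
    '{"time/step": 7.9, "response/rewards/mean": 0.35, "eval/accuracy": 0.61}',
]

metric_all = defaultdict(list)
for raw in lines:
    record = json.loads(raw)
    for metric in ["time/step", "eval/accuracy"]:
        if metric in record:  # without this, the first line raises KeyError
            metric_all[metric].append(record[metric])

print(dict(metric_all))  # {'time/step': [8.1, 7.9], 'eval/accuracy': [0.61]}
```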
```diff
@@ -154,5 +155,87 @@ def check_result(case_name, base_path, cur_path, check_metric):
     result = not fail_metric
     return result, f"Some metric check failed,{fail_metric}"
 
+def check_rl_result(case_name, base_path, cur_path, assert_info):
+    fail_metric = {}
+    check_metrics_list = assert_info["check_metrics"]
+
+    metric_list = [item["metric"] for item in check_metrics_list]
+
+    base_steps, base_metrics = extract_value(base_path, metric_list)
+    cur_steps, cur_metrics = extract_value(cur_path, metric_list)
+
+    assert (
+        cur_steps == base_steps
+    ), f"current steps is not equal to base steps, current steps: {cur_steps}, base steps: {base_steps}"
+
+    output_path = Path(f"../{os.environ.get('GITHUB_RUN_ID','0')}")
+    output_path.mkdir(parents=True, exist_ok=True)
+
+    check_metric_dict = {item["metric"]: item["threshold"] for item in check_metrics_list}
+    plot_all(case_name, check_metric_dict, base_metrics, cur_metrics, output_path)
+
+    shutil.copytree(output_path, f"./{os.environ['GITHUB_RUN_ID']}", dirs_exist_ok=True)
+    write_to_summary(case_name, base_path, cur_path)
+
+    for config in check_metrics_list:
+        metric = config["metric"]
+        threshold = config["threshold"]
+        method = config["method"]  # 'absolute' or 'relative'
+        operator = config["operator"]  # '<' or '<='
+
+        max_error = 0.0
+        max_error_idx = 0
+        check_flag = True
+
+        for idx, (base_val, cur_val) in enumerate(
+            zip(base_metrics[metric], cur_metrics[metric])
+        ):
+            if method == "absolute":
+                error = round(abs(cur_val - base_val), 5)
+            elif method == "relative":
+                if abs(base_val) < 1e-10:
+                    error = float("inf") if abs(cur_val) > 1e-10 else 0.0
+                else:
+                    error = round(abs(cur_val - base_val) / abs(base_val), 5)
+            else:
+                raise ValueError(f"Unknown method: {method}")
+
+            if error > max_error:
+                max_error = error
+                max_error_idx = idx
+
+            if operator == "<":
+                if not (error < threshold):
+                    fail_metric[metric] = (
+                        f"{metric} error {error:.6f} not less than threshold {threshold} "
+                        f"(method: {method}, operator: {operator}) at step {idx}, "
+                        f"baseline: {base_val:.6f}, current: {cur_val:.6f}"
+                    )
+                    check_flag = False
+                    break
+            elif operator == "<=":
+                if not (error <= threshold):
+                    fail_metric[metric] = (
+                        f"{metric} error {error:.6f} not less than or equal to threshold {threshold} "
+                        f"(method: {method}, operator: {operator}) at step {idx}, "
+                        f"baseline: {base_val:.6f}, current: {cur_val:.6f}"
+                    )
+                    check_flag = False
+                    break
+            else:
+                raise ValueError(f"Unknown operator: {operator}")
+
+        if check_flag:
+            logger.info(
+                f"✓ {metric} check passed, max error is {max_error:.6f} at step {max_error_idx} "
+                f"(method: {method}, operator: {operator})"
+            )
+
+    result = not bool(fail_metric)
+    if result:
+        return result, "All metrics check passed."
+    else:
+        return result, f"Some metric check failed: {fail_metric}"
+
 if __name__ == "__main__":
     print(check_result("qwen3-sft", "./base/tracker.jsonl", "./current/tracker.jsonl",{"grad_norm":0.000001,"loss/reduced_llm_loss":0.000001,"lr":0,"memory/max_memory_GB":0.2,"runtime_info/tgs":0.05,"runtime_info/text_tokens":0}))
```
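`check_rl_result` compares the baseline and current traces step by step, tracks the worst error per metric, and fails fast on the first step that violates the threshold. The standalone walk-through below reproduces just that comparison loop for the `mismatch/mismatch_k3_kl` entry from the config; the step values are invented and the plotting/summary side effects are omitted:

```python
# Per-metric comparison loop from check_rl_result, in isolation.
base = {"mismatch/mismatch_k3_kl": [0.0, 1e-5, 2e-5]}    # invented baseline trace
cur  = {"mismatch/mismatch_k3_kl": [0.0, 1e-5, 1.4e-4]}  # invented current trace

metric, threshold, operator = "mismatch/mismatch_k3_kl", 0.0001, "<="
fail = None
for idx, (b, c) in enumerate(zip(base[metric], cur[metric])):
    error = round(abs(c - b), 5)  # method: absolute
    ok = error <= threshold if operator == "<=" else error < threshold
    if not ok:
        fail = f"{metric} error {error} exceeds {threshold} at step {idx}"
        break  # fail fast on the first violating step

print(fail)  # mismatch/mismatch_k3_kl error 0.00012 exceeds 0.0001 at step 2
```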
