Skip to content

Commit 3b3fae6

Browse files
authored
CI: display image about metric (#1411)
* update * run cases * run all * run all cases * ready to PR * ready to PR
1 parent b917021 commit 3b3fae6

File tree

4 files changed

+108
-4
lines changed

4 files changed

+108
-4
lines changed

.github/workflows/e2e_test.yaml

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
name: ete_test
2+
3+
permissions:
4+
contents: write
5+
pages: write
6+
27
on:
38
workflow_dispatch:
49
inputs:
@@ -34,3 +39,21 @@ jobs:
3439
conda env list
3540
unset HTTP_PROXY;unset HTTPS_PROXY;unset http_proxy;unset https_proxy;
3641
pytest autotest/test_all.py -m all -n 1 -vv --run_id ${{ github.run_id }}
42+
43+
- name: Upload Artifacts
44+
if: ${{ !cancelled() }}
45+
uses: actions/upload-artifact@v4
46+
with:
47+
path: ${{ github.workspace }}/${{ github.run_id }}
48+
if-no-files-found: ignore
49+
retention-days: 7
50+
name: xtuner-e2e-${{ github.run_id }}
51+
52+
- name: Deploy to GitHub Pages
53+
if: ${{ !cancelled() }}
54+
uses: JamesIves/github-pages-deploy-action@v4
55+
with:
56+
token: ${{ github.token }}
57+
branch: gh-pages
58+
folder: ./${{ github.run_id }}
59+
target-folder: ${{ github.run_id }}

.github/workflows/unit_test.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ on:
88
- "docs/**"
99
- "**.md"
1010
- "autotest/**"
11+
- ".github/workflows/e2e_test.yaml"
12+
- ".github/workflows/lint.yml"
1113
env:
1214
WORKSPACE_PREFIX: $(echo $GITHUB_WORKSPACE |cut -d '/' -f 1-5)
1315
WORKSPACE_PREFIX_SHORT: $(echo $GITHUB_WORKSPACE |cut -d '/' -f 1-3)

autotest/module/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def validate(config):
5656
)
5757
cur_path = os.path.join(get_latest_subdir(work_dir), "logs/exp_tracking/rank0/tracker.jsonl")
5858
check_metrics = config.get("assert_info", {}).get("check_metrics", {})
59-
return check_result(base_path, cur_path, check_metrics)
59+
return check_result(config["case_name"], base_path, cur_path, check_metrics)
6060

6161
def pre_action(config=None):
6262
action_info = config.get("pre_action", None)

autotest/utils/check_metric.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
import json
22
import logging
3+
import os
4+
import shutil
5+
import matplotlib.pyplot as plt
6+
import numpy as np
7+
from pathlib import Path
38
from statistics import mean
49

5-
610
logging.basicConfig(
711
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
812
)
@@ -21,8 +25,77 @@ def extract_value(file, metrics):
2125

2226
return total_step, metric_all
2327

28+
def plot_all(case_name, check_metric, base_metrics, cur_metrics, output_root: Path):
    """Render a near-square grid of base-vs-current line charts, one subplot per metric.

    Args:
        case_name: test-case identifier, used in the figure title and file name.
        check_metric: dict whose keys are the metric names to plot.
        base_metrics: mapping of metric name -> baseline value sequence.
        cur_metrics: mapping of metric name -> current-run value sequence.
        output_root: directory where ``{case_name}_comparison.png`` is saved.
    """
    names = list(check_metric.keys())
    total = len(names)
    # Near-square grid: cols = ceil(sqrt(n)), rows = ceil(n / cols).
    cols = int(np.ceil(np.sqrt(total)))
    rows = int(np.ceil(total / cols))
    fig, axis_grid = plt.subplots(rows, cols, figsize=(cols * 4, rows * 3))
    # np.array(...).flatten() normalises the single-Axes and 1-D cases to a flat array.
    flat_axes = np.array(axis_grid).flatten()

    for idx, ax in enumerate(flat_axes):
        if idx >= total:
            # Spare grid cell with no metric assigned — hide it.
            ax.axis("off")
            continue
        name = names[idx]
        base_series = base_metrics[name]
        cur_series = cur_metrics[name]
        ax.plot(
            np.arange(len(base_series)),
            base_series,
            "r--",
            label="Base",
            marker="x",
            markersize=4,
        )
        ax.plot(
            np.arange(len(cur_series)),
            cur_series,
            "b-",
            label="Current",
            marker="o",
            markersize=4,
        )
        # '/' appears in metric names (e.g. "loss/reduced_llm_loss"); normalise for the title.
        ax.set_title(f"{name.replace('/', '_')}_comparison")
        ax.set_xlabel("Step")
        ax.set_ylabel("Value")
        ax.legend()
        ax.grid(True, linestyle="--", alpha=0.7)

    fig.suptitle(f"{case_name}_metrics_comparison", fontsize=16)
    plt.tight_layout()
    plt.savefig(output_root / f"{case_name}_comparison.png")
    plt.close()
67+
68+
69+
def write_to_summary(case_name, base_jsonl, cur_jsonl):
    """Append a comparison section for *case_name* to the GitHub step summary.

    The section embeds the comparison image published to GitHub Pages and a
    collapsible block containing the raw baseline / current ``tracker.jsonl``
    contents (blank lines stripped).

    Args:
        case_name: test-case identifier; used in headings and the image name.
        base_jsonl: path to the baseline tracker.jsonl file.
        cur_jsonl: path to the current-run tracker.jsonl file.
    """
    summary_file = os.environ.get('GITHUB_STEP_SUMMARY', './tmp.md')
    repo_owner = os.environ.get('GITHUB_REPOSITORY_OWNER', 'internlm')
    run_id = os.environ.get('GITHUB_RUN_ID', '0')

    # Assemble the whole section first, then append it with one write instead
    # of reopening the summary file three times.
    parts = [
        f"## {case_name}指标比较图\n",
        '<div align="center">\n',
        f'<img src="https://{repo_owner}.github.io/xtuner/{run_id}/{case_name}_comparison.png"\n',
        ' style="max-width: 90%; border: 1px solid #ddd; border-radius: 8px;">\n',
        '</div>\n<div align=center>\n',
        f'<details>\n<summary><strong style="text-align: left;">📊 点击查看用例{case_name}指标数据,依次为基线、当前版本数据</strong></summary>\n\n',
    ]

    # One fenced ```json block per file: baseline first, then current run.
    for jsonl_path in (base_jsonl, cur_jsonl):
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            lines = [line.strip() for line in f if line.strip()]
        body = ''.join(f'{line}\n' for line in lines)
        parts.append(f'```json\n{body}```\n\n')

    parts.append('</details>\n\n')

    # encoding='utf-8' is required: the section contains non-ASCII text and the
    # platform default encoding is not guaranteed to handle it (the original
    # code opened two of its appends without an explicit encoding, which can
    # raise UnicodeEncodeError on non-UTF-8 locales).
    with open(summary_file, 'a', encoding='utf-8') as f:
        f.write(''.join(parts))
96+
97+
98+
def check_result(case_name, base_path, cur_path, check_metric):
2699
fail_metric = {}
27100
check_metric = check_metric
28101
metric_list = list(check_metric.keys())
@@ -32,6 +105,12 @@ def check_result(base_path, cur_path, check_metric):
32105
f"current steps is not equal to base steps, current steps: {cur_steps}, base steps: {base_steps}"
33106
)
34107

108+
output_path = Path(f"../{os.environ.get('GITHUB_RUN_ID','0')}")
109+
output_path.mkdir(parents=True, exist_ok=True)
110+
plot_all(case_name, check_metric, base_metrics, cur_metrics, output_path)
111+
shutil.copytree(output_path, f"./{os.environ['GITHUB_RUN_ID']}", dirs_exist_ok=True)
112+
write_to_summary(case_name, base_path, cur_path)
113+
35114
for metric, threshold in check_metric.items():
36115
max_error = 0.0
37116
max_error_idx = 0
@@ -75,4 +154,4 @@ def check_result(base_path, cur_path, check_metric):
75154
return result, f"Some metric check failed,{fail_metric}"
76155

77156
if __name__ == "__main__":
78-
print(check_result("./base//tracker.jsonl","./current/tracker.jsonl",{"grad_norm":0.000001,"loss/reduced_llm_loss":0.000001,"lr":0,"memory/max_memory_GB":0.2,"runtime_info/tgs":0.05,"runtime_info/text_tokens":0}))
157+
print(check_result("qwen3-sft", "./base/tracker.jsonl", "./current/tracker.jsonl",{"grad_norm":0.000001,"loss/reduced_llm_loss":0.000001,"lr":0,"memory/max_memory_GB":0.2,"runtime_info/tgs":0.05,"runtime_info/text_tokens":0}))

0 commit comments

Comments
 (0)