Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions .github/workflows/e2e_test.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
name: ete_test

permissions:
contents: write
pages: write

on:
workflow_dispatch:
inputs:
Expand Down Expand Up @@ -34,3 +39,21 @@ jobs:
conda env list
unset HTTP_PROXY;unset HTTPS_PROXY;unset http_proxy;unset https_proxy;
pytest autotest/test_all.py -m all -n 1 -vv --run_id ${{ github.run_id }}
- name: Upload Artifacts
if: ${{ !cancelled() }}
uses: actions/upload-artifact@v4
with:
path: ${{ github.workspace }}/${{ github.run_id }}
if-no-files-found: ignore
retention-days: 7
name: xtuner-e2e-${{ github.run_id }}

- name: Deploy to GitHub Pages
if: ${{ !cancelled() }}
uses: JamesIves/github-pages-deploy-action@v4
with:
token: ${{ github.token }}
branch: gh-pages
folder: ./${{ github.run_id }}
target-folder: ${{ github.run_id }}
2 changes: 2 additions & 0 deletions .github/workflows/unit_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ on:
- "docs/**"
- "**.md"
- "autotest/**"
- ".github/workflows/e2e_test.yaml"
- ".github/workflows/lint.yml"
env:
WORKSPACE_PREFIX: $(echo $GITHUB_WORKSPACE |cut -d '/' -f 1-5)
WORKSPACE_PREFIX_SHORT: $(echo $GITHUB_WORKSPACE |cut -d '/' -f 1-3)
Expand Down
2 changes: 1 addition & 1 deletion autotest/module/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def validate(config):
)
cur_path = os.path.join(get_latest_subdir(work_dir), "logs/exp_tracking/rank0/tracker.jsonl")
check_metrics = config.get("assert_info", {}).get("check_metrics", {})
return check_result(base_path, cur_path, check_metrics)
return check_result(config["case_name"], base_path, cur_path, check_metrics)

def pre_action(config=None):
action_info = config.get("pre_action", None)
Expand Down
85 changes: 82 additions & 3 deletions autotest/utils/check_metric.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import json
import logging
import os
import shutil
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from statistics import mean


logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
Expand All @@ -21,8 +25,77 @@ def extract_value(file, metrics):

return total_step, metric_all

def plot_all(case_name, check_metric, base_metrics, cur_metrics, output_root: Path):
    """Plot baseline vs. current curves for every checked metric into one PNG.

    One subplot is drawn per key of ``check_metric``; the subplots are laid
    out on a near-square grid and unused grid cells are hidden. The figure is
    saved as ``<output_root>/<case_name>_comparison.png``.

    Args:
        case_name: Name of the test case; used in the figure title and in the
            output file name.
        check_metric: Mapping whose keys are the metric names to plot (the
            values are not used here).
        base_metrics: Mapping from metric name to the sequence of baseline
            values, one per step.
        cur_metrics: Mapping from metric name to the sequence of values from
            the current run, one per step.
        output_root: Directory the PNG is written to.

    Returns:
        None. When ``check_metric`` is empty nothing is plotted or written.
    """
    metric_list = list(check_metric)
    n_plots = len(metric_list)
    if n_plots == 0:
        # An empty metric set would give a 0-column grid and make the row
        # computation below divide by zero; there is nothing to draw anyway.
        return

    n_cols = int(np.ceil(np.sqrt(n_plots)))
    n_rows = int(np.ceil(n_plots / n_cols))
    # squeeze=False guarantees a 2-D array of Axes even for a 1x1 grid.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3), squeeze=False
    )
    axes = axes.ravel()

    for ax, metric in zip(axes, metric_list):
        base_values = base_metrics[metric]
        cur_values = cur_metrics[metric]
        ax.plot(
            np.arange(len(base_values)),
            base_values,
            "r--",
            label="Base",
            marker="x",
            markersize=4,
        )
        ax.plot(
            np.arange(len(cur_values)),
            cur_values,
            "b-",
            label="Current",
            marker="o",
            markersize=4,
        )
        ax.set_title(f"{metric.replace('/', '_')}_comparison")
        ax.set_xlabel("Step")
        ax.set_ylabel("Value")
        ax.legend()
        ax.grid(True, linestyle="--", alpha=0.7)

    # Hide the grid cells left over when n_plots is not a perfect rectangle.
    for ax in axes[n_plots:]:
        ax.axis("off")

    fig.suptitle(f"{case_name}_metrics_comparison", fontsize=16)
    fig.tight_layout()
    fig.savefig(output_root / f"{case_name}_comparison.png")
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)


def write_to_summary(case_name, base_jsonl, cur_jsonl):
    """Append a comparison section for one test case to the job summary.

    The section contains an ``<img>`` tag pointing at the case's comparison
    PNG on the GitHub Pages site, followed by a collapsible ``<details>``
    block holding the non-blank lines of the baseline file and then of the
    current file, each in its own fenced ``json`` code block.

    The summary path, repository owner and run id are read from the
    ``GITHUB_STEP_SUMMARY``, ``GITHUB_REPOSITORY_OWNER`` and ``GITHUB_RUN_ID``
    environment variables, falling back to ``./tmp.md``, ``internlm`` and
    ``0`` respectively.

    Args:
        case_name: Name of the test case; used in the heading and image URL.
        base_jsonl: Path of the baseline tracker file (read as UTF-8 text).
        cur_jsonl: Path of the current run's tracker file (read as UTF-8 text).
    """
    summary_file = os.environ.get('GITHUB_STEP_SUMMARY', './tmp.md')
    repo_owner = os.environ.get('GITHUB_REPOSITORY_OWNER', 'internlm')
    run_id = os.environ.get('GITHUB_RUN_ID', '0')

    parts = [
        f"## {case_name}指标比较图\n",
        '<div align="center">\n',
        f'<img src="https://{repo_owner}.github.io/xtuner/{run_id}/{case_name}_comparison.png"\n',
        ' style="max-width: 90%; border: 1px solid #ddd; border-radius: 8px;">\n',
        '</div>\n<div align=center>\n',
        f'<details>\n<summary><strong style="text-align: left;">📊 点击查看用例{case_name}指标数据,依次为基线、当前版本数据</strong></summary>\n\n',
    ]

    # Baseline first, then current, matching the order promised in the summary text.
    for jsonl_path in (base_jsonl, cur_jsonl):
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            records = [line.strip() for line in f if line.strip()]
        parts.append('```json\n')
        parts.extend(f'{record}\n' for record in records)
        parts.append('```\n\n')

    # Close the second <div> as well, so sections appended for later cases
    # are not nested inside this one.
    parts.append('</details>\n</div>\n\n')

    # Single append with an explicit encoding: the section contains non-ASCII
    # text, which must not depend on the platform's default encoding.
    with open(summary_file, 'a', encoding='utf-8') as f:
        f.writelines(parts)


def check_result(case_name, base_path, cur_path, check_metric):
fail_metric = {}
check_metric = check_metric
metric_list = list(check_metric.keys())
Expand All @@ -32,6 +105,12 @@ def check_result(base_path, cur_path, check_metric):
f"current steps is not equal to base steps, current steps: {cur_steps}, base steps: {base_steps}"
)

output_path = Path(f"../{os.environ.get('GITHUB_RUN_ID','0')}")
output_path.mkdir(parents=True, exist_ok=True)
plot_all(case_name, check_metric, base_metrics, cur_metrics, output_path)
shutil.copytree(output_path, f"./{os.environ['GITHUB_RUN_ID']}", dirs_exist_ok=True)
write_to_summary(case_name, base_path, cur_path)

for metric, threshold in check_metric.items():
max_error = 0.0
max_error_idx = 0
Expand Down Expand Up @@ -75,4 +154,4 @@ def check_result(base_path, cur_path, check_metric):
return result, f"Some metric check failed,{fail_metric}"

if __name__ == "__main__":
print(check_result("./base//tracker.jsonl","./current/tracker.jsonl",{"grad_norm":0.000001,"loss/reduced_llm_loss":0.000001,"lr":0,"memory/max_memory_GB":0.2,"runtime_info/tgs":0.05,"runtime_info/text_tokens":0}))
print(check_result("qwen3-sft", "./base/tracker.jsonl", "./current/tracker.jsonl",{"grad_norm":0.000001,"loss/reduced_llm_loss":0.000001,"lr":0,"memory/max_memory_GB":0.2,"runtime_info/tgs":0.05,"runtime_info/text_tokens":0}))