11import json
22import logging
3+ import os
4+ import shutil
5+ import matplotlib .pyplot as plt
6+ import numpy as np
7+ from pathlib import Path
38from statistics import mean
49
5-
# Configure the root logger once at import time: INFO level, timestamped records.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
@@ -21,8 +25,77 @@ def extract_value(file, metrics):
2125
2226 return total_step , metric_all
2327
def plot_all(case_name, check_metric, base_metrics, cur_metrics, output_root: Path):
    """Draw baseline-vs-current curves for every checked metric into one PNG.

    One subplot is created per key of ``check_metric``; subplots are laid out
    on a near-square grid and unused grid cells are hidden. The figure is
    saved as ``<output_root>/<case_name>_comparison.png``.

    Args:
        case_name: Test case name; used in the figure title and file name.
        check_metric: Mapping whose keys are the metric names to plot
            (the values are not used here).
        base_metrics: Mapping from metric name to the sequence of baseline
            values, one per step.
        cur_metrics: Mapping from metric name to the sequence of current
            values, one per step.
        output_root: Existing directory the PNG is written into.

    Returns:
        None. When ``check_metric`` is empty nothing is drawn or written.

    Raises:
        KeyError: If a metric in ``check_metric`` is missing from
            ``base_metrics`` or ``cur_metrics``.
    """
    metric_list = list(check_metric)
    n_plots = len(metric_list)
    if n_plots == 0:
        # A zero-column grid would divide by zero below; there is nothing to plot.
        return

    n_cols = int(np.ceil(np.sqrt(n_plots)))
    n_rows = int(np.ceil(n_plots / n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3))
    try:
        # subplots() returns a bare Axes for a 1x1 grid; normalise to a flat array.
        axes = np.array(axes).flatten()

        for ax, metric in zip(axes, metric_list):
            base_values = base_metrics[metric]
            cur_values = cur_metrics[metric]
            ax.plot(
                np.arange(len(base_values)),
                base_values,
                "r--",
                label="Base",
                marker="x",
                markersize=4,
            )
            ax.plot(
                np.arange(len(cur_values)),
                cur_values,
                "b-",
                label="Current",
                marker="o",
                markersize=4,
            )
            ax.set_title(f"{metric.replace('/', '_')}_comparison")
            ax.set_xlabel("Step")
            ax.set_ylabel("Value")
            ax.legend()
            ax.grid(True, linestyle="--", alpha=0.7)

        # Hide the grid cells left over when n_plots is not a full rectangle.
        for ax in axes[n_plots:]:
            ax.axis("off")

        fig.suptitle(f"{case_name}_metrics_comparison", fontsize=16)
        fig.tight_layout()
        fig.savefig(output_root / f"{case_name}_comparison.png")
    finally:
        # Close this specific figure even if plotting or saving fails, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)
67+
68+
def write_to_summary(case_name, base_jsonl, cur_jsonl):
    """Append a Markdown/HTML report for one test case to the step summary.

    The report consists of a heading, an ``<img>`` tag pointing at
    ``https://<owner>.github.io/xtuner/<run_id>/<case_name>_comparison.png``,
    and a collapsible ``<details>`` section holding the non-blank lines of
    ``base_jsonl`` and then ``cur_jsonl``, each in its own fenced ``json``
    code block.

    Environment variables read (with fallbacks):
        GITHUB_STEP_SUMMARY: file to append to (default ``./tmp.md``).
        GITHUB_REPOSITORY_OWNER: owner part of the image URL (default ``internlm``).
        GITHUB_RUN_ID: run directory in the image URL (default ``0``).

    Args:
        case_name: Test case name used in the heading, image URL and summary text.
        base_jsonl: Path of the baseline JSONL file.
        cur_jsonl: Path of the current-run JSONL file.

    Raises:
        OSError: If either JSONL file cannot be read or the summary file
            cannot be written. Both inputs are read before anything is
            written, so a read failure leaves the summary file untouched.
    """
    summary_file = os.environ.get("GITHUB_STEP_SUMMARY", "./tmp.md")
    repo_owner = os.environ.get("GITHUB_REPOSITORY_OWNER", "internlm")
    run_id = os.environ.get("GITHUB_RUN_ID", "0")

    parts = [
        # Heading text: "<case_name> metric comparison chart".
        f"## {case_name}指标比较图\n",
        '<div align="center">\n',
        f'<img src="https://{repo_owner}.github.io/xtuner/{run_id}/{case_name}_comparison.png"\n',
        ' style="max-width: 90%; border: 1px solid #ddd; border-radius: 8px;">\n',
        "</div>\n<div align=center>\n",
        # Summary text: "click to view the metric data of case <case_name>,
        # baseline first, then the current version".
        f'<details>\n<summary><strong style="text-align: left;">📊 点击查看用例{case_name}指标数据,依次为基线、当前版本数据</strong></summary>\n\n',
    ]

    for jsonl_path in (base_jsonl, cur_jsonl):
        with open(jsonl_path, "r", encoding="utf-8") as src:
            # Drop blank lines and surrounding whitespace from each record.
            records = [stripped for stripped in (raw.strip() for raw in src) if stripped]
        parts.append("```json\n")
        parts.extend(f"{record}\n" for record in records)
        parts.append("```\n\n")

    # Close the second <div> as well so later summary content is not nested in it.
    parts.append("</details>\n</div>\n\n")

    # The report contains non-ASCII text, so the encoding must be explicit
    # rather than taken from the platform locale.
    with open(summary_file, "a", encoding="utf-8") as out:
        out.write("".join(parts))
96+
97+
98+ def check_result (case_name , base_path , cur_path , check_metric ):
2699 fail_metric = {}
27100 check_metric = check_metric
28101 metric_list = list (check_metric .keys ())
@@ -32,6 +105,12 @@ def check_result(base_path, cur_path, check_metric):
32105 f"current steps is not equal to base steps, current steps: { cur_steps } , base steps: { base_steps } "
33106 )
34107
# These statements belong to the body of check_result.
# Read the run id once, with the same fallback for every use. Indexing
# os.environ directly raises KeyError when GITHUB_RUN_ID is unset, e.g.
# on a local run outside GitHub Actions.
run_id = os.environ.get("GITHUB_RUN_ID", "0")
output_path = Path(f"../{run_id}")
output_path.mkdir(parents=True, exist_ok=True)
plot_all(case_name, check_metric, base_metrics, cur_metrics, output_path)
# Mirror the generated plots into ./<run_id> under the current directory.
shutil.copytree(output_path, f"./{run_id}", dirs_exist_ok=True)
write_to_summary(case_name, base_path, cur_path)
113+
35114 for metric , threshold in check_metric .items ():
36115 max_error = 0.0
37116 max_error_idx = 0
@@ -75,4 +154,4 @@ def check_result(base_path, cur_path, check_metric):
75154 return result , f"Some metric check failed,{ fail_metric } "
76155
if __name__ == "__main__":
    # Manual smoke run: compare ./current against ./base for one case.
    # Keys are metric names, values are the per-metric thresholds handed to
    # check_result (NOTE(review): exact threshold semantics are defined in
    # check_result — confirm there before changing these numbers).
    thresholds = {
        "grad_norm": 0.000001,
        "loss/reduced_llm_loss": 0.000001,
        "lr": 0,
        "memory/max_memory_GB": 0.2,
        "runtime_info/tgs": 0.05,
        "runtime_info/text_tokens": 0,
    }
    # The base path must not carry a leading space, or it cannot resolve.
    print(
        check_result(
            "qwen3-sft",
            "./base/tracker.jsonl",
            "./current/tracker.jsonl",
            thresholds,
        )
    )
0 commit comments