|
| 1 | +import argparse |
1 | 2 | import os |
2 | 3 | import sys |
3 | 4 | from src.benchmark.utils import read_metrics, to_markdown_table |
4 | 5 |
|
5 | | -if __name__ == "__main__": |
6 | | - # Generate statistics report |
7 | | - statistics_path = sys.argv[1] |
8 | | - metrics = read_metrics(statistics_path, metric="accuracy") |
| 6 | + |
| 7 | +def parse_args(): |
| 8 | + parser = argparse.ArgumentParser() |
| 9 | + parser.add_argument("--path", type=str, required=True, help="Report path.") |
| 10 | + parser.add_argument("--write-gh-job-summary", action="store_true", help="Write to GitHub job summary.") |
| 11 | + parser.add_argument("--update-readme", action="store_true", help="Update statistics report in README.md.") |
| 12 | + return parser.parse_args() |
| 13 | + |
| 14 | + |
| 15 | +def generate_report(path: str): |
| 16 | + metrics = read_metrics(path, metric="accuracy") |
9 | 17 | html_table = to_markdown_table(metrics) |
| 18 | + return html_table |
10 | 19 |
|
11 | | - # Write to workflow job summary |
| 20 | + |
| 21 | +def write_job_summary(report): |
12 | 22 | summary_path = os.environ["GITHUB_STEP_SUMMARY"] |
13 | 23 | with open(summary_path, "a") as f: |
14 | 24 | f.write("## Torchbenchmark statistics report\n") |
15 | | - f.write(html_table) |
| 25 | + f.write(report) |
| 26 | + |
| 27 | + |
| 28 | +def update_readme(report): |
| 29 | + project_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) |
| 30 | + readme_path = os.path.join(project_path, "README.md") |
| 31 | + print(readme_path) |
| 32 | + with open(readme_path, "r") as f: |
| 33 | + readme_content = f.read() |
| 34 | + |
| 35 | + start_marker = "<!-- Torchbenchmark start -->" |
| 36 | + end_marker = "<!-- Torchbenchmark end -->" |
| 37 | + start_index = readme_content.find(start_marker) |
| 38 | + end_index = readme_content.find(end_marker) |
| 39 | + assert start_index != -1 |
| 40 | + assert end_index != -1 |
| 41 | + |
| 42 | + start_index += len(start_marker) |
| 43 | + new_readme_content = ( |
| 44 | + readme_content[:start_index] + "\n\n" + |
| 45 | + report + "\n\n" + |
| 46 | + readme_content[end_index:] |
| 47 | + ) |
| 48 | + with open(readme_path, "w") as f: |
| 49 | + f.write(new_readme_content) |
| 50 | + |
| 51 | + |
| 52 | +if __name__ == "__main__": |
| 53 | + args = parse_args() |
| 54 | + |
| 55 | + # Generate statistics report |
| 56 | + report = generate_report(args.path) |
| 57 | + |
| 58 | + # Write to workflow job summary |
| 59 | + if args.write_gh_job_summary: |
| 60 | + write_job_summary(report) |
| 61 | + |
| 62 | + # Update README.md |
| 63 | + if args.update_readme: |
| 64 | + update_readme(report) |
0 commit comments