diff --git a/.github/workflows/ci_e2e_llm_model_test.yml b/.github/workflows/ci_e2e_llm_model_test.yml
index 81f7309d3e8..d9f778e8228 100644
--- a/.github/workflows/ci_e2e_llm_model_test.yml
+++ b/.github/workflows/ci_e2e_llm_model_test.yml
@@ -8,6 +8,10 @@ name: CI - E2E Model Tests
 
 on:
   workflow_dispatch:
+  pull_request:
+  push:
+    branches:
+      - main
   schedule:
     - cron: "0 12 * * *"
 
@@ -107,11 +111,59 @@ jobs:
             output_artifacts/output_*/consolidated_benchmark.json
             output_artifacts/output_*/*.log
             output_artifacts/version.txt
+
+      - name: Update gold values in models.json
+        if: always()
+        run: python sharktank/tests/e2e/automate_pr_for_gold.py
+
       - name: Cleanup output artifacts
         if: always()
         run: |
-          rm -rf output_artifacts
-          test ! -d output_artifacts && echo "Output artifacts are removed"
+          rm -rf output_artifacts
+          test ! -d output_artifacts && echo "Output artifacts are removed"
+
+      - name: Check for changes
+        id: changes_check
+        if: always()
+        run: |
+          if git diff --exit-code sharktank/tests/e2e/configs/models.json > /dev/null; then
+            echo "No changes detected."
+            echo "CHANGED_JSON=false" >> $GITHUB_OUTPUT
+            exit 0
+          else
+            echo "Changes detected."
+            echo "CHANGED_JSON=true" >> $GITHUB_OUTPUT
+          fi
+
+      - uses: actions/create-github-app-token@df432ceedc7162793a195dd1713ff69aefc7379e # v2.0.6
+        # if: ${{ env.CREATE_PULL_REQUEST_TOKEN_APP_ID != '' && env.CREATE_PULL_REQUEST_TOKEN_APP_PRIVATE_KEY != '' }}
+        if: always()
+        id: generate-token
+        with:
+          app-id: ${{ secrets.CREATE_PULL_REQUEST_TOKEN_APP_ID }}
+          private-key: ${{ secrets.CREATE_PULL_REQUEST_TOKEN_APP_PRIVATE_KEY }}
+        env:
+          CREATE_PULL_REQUEST_TOKEN_APP_ID: ${{ secrets.CREATE_PULL_REQUEST_TOKEN_APP_ID }}
+          CREATE_PULL_REQUEST_TOKEN_APP_PRIVATE_KEY: ${{ secrets.CREATE_PULL_REQUEST_TOKEN_APP_PRIVATE_KEY }}
+
+      - name: Create Pull Request
+        id: cpr
+        # if: steps.changes_check.outputs.CHANGED_JSON == 'true'
+        if: always()
+        uses: peter-evans/create-pull-request@271a8d0340265f705b14b6d32b9829c1cb33d45e # v7.0.8
+        with:
+          token: ${{ steps.generate-token.outputs.token || secrets.GITHUB_TOKEN }}
+          commit-message: Update gold values in sharktank/tests/e2e/configs/models.json
+          title: '[Auto-Update] Update gold values for E2E tests'
+          body: |
+            This PR updates the gold values for models in sharktank/tests/e2e/configs/models.json
+          branch: update-gold-values
+          delete-branch: true
+          author: shark-pr-automator[bot] <41898282+github-actions[bot]@users.noreply.github.com>
+          signoff: true
+          add-paths: |
+            sharktank/tests/e2e/configs/models.json
+          base: main
 
   # New job to push logs to shark-ai-reports repository
   push_logs:
diff --git a/sharktank/tests/e2e/automate_pr_for_gold.py b/sharktank/tests/e2e/automate_pr_for_gold.py
new file mode 100644
index 00000000000..53bde91036b
--- /dev/null
+++ b/sharktank/tests/e2e/automate_pr_for_gold.py
@@ -0,0 +1,110 @@
+# Copyright 2025 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import json
+import os
+
+
+def normalize_ascii(obj):
+    if isinstance(obj, str):
+        return obj.replace("–", "-").replace("—", "-")
+    elif isinstance(obj, list):
+        return [normalize_ascii(x) for x in obj]
+    elif isinstance(obj, dict):
+        return {k: normalize_ascii(v) for k, v in obj.items()}
+    return obj
+
+
+def parse_log(log_file_path):
+    gold_prefill_time = current_prefill_time = None
+    gold_decode_time = current_decode_time = None
+    with open(log_file_path, "r") as file:
+        for line in file:
+            # The gold and current times are the 4th and 5th colon-separated fields.
+            if "GOLD PREFILL_TIME" in line:
+                gold_prefill_time = float(
+                    line.split(":")[3].strip().split(" ")[0].replace(" ", "")
+                )
+                current_prefill_time = float(
+                    line.split(":")[4].strip().split(" ")[0].replace(" ", "")
+                )
+            if "GOLD DECODE_TIME" in line:
+                gold_decode_time = float(
+                    line.split(":")[3].strip().split(" ")[0].replace(" ", "")
+                )
+                current_decode_time = float(
+                    line.split(":")[4].strip().split(" ")[0].replace(" ", "")
+                )
+
+    if None in (
+        gold_prefill_time,
+        current_prefill_time,
+        gold_decode_time,
+        current_decode_time,
+    ):
+        return None
+    return (
+        gold_prefill_time,
+        current_prefill_time,
+        gold_decode_time,
+        current_decode_time,
+    )
+
+
+def update_json_for_conditions(json_file_path, log_path):
+    with open(json_file_path, "r") as f:
+        data = json.load(f)
+
+    updated = False
+    for model, details in data.items():
+        log_file_path = os.path.join(
+            log_path, f"output_{model}/e2e_testing_log_file.log"
+        )
+
+        if not os.path.exists(log_file_path):
+            print(f"Skipping {model} — log file not found.")
+            continue
+
+        parsed = parse_log(log_file_path)
+        if parsed is None:
+            print(f"Skipping {model} — GOLD PREFILL/DECODE values missing in log.")
+            continue
+
+        gold_prefill, current_prefill, gold_decode, current_decode = parsed
+
+        gold_prefill_mi325x = details.get("prefill_gold_mi325x")
+        gold_decode_mi325x = details.get("decode_gold_mi325x")
+        if gold_prefill_mi325x is not None and gold_decode_mi325x is not None:
+            # Refresh a gold value only when the new time beats it by a clear
+            # margin: 3% for prefill, 6% for decode.
+            if current_prefill < gold_prefill_mi325x * (1 - 0.03):
+                print(
+                    f"Updating PREFILL gold for {model}: {gold_prefill_mi325x} -> {current_prefill}"
+                )
+                details["prefill_gold_mi325x"] = round(current_prefill, 3)
+                updated = True
+            if current_decode < gold_decode_mi325x * (1 - 0.06):
+                print(
+                    f"Updating DECODE gold for {model}: {gold_decode_mi325x} -> {current_decode}"
+                )
+                details["decode_gold_mi325x"] = round(current_decode, 3)
+                updated = True
+
+    if updated:
+        with open(json_file_path, "w") as f:
+            json.dump(normalize_ascii(data), f, indent=2, ensure_ascii=False)
+            f.write("\n")
+        print(
+            "[IMPROVEMENT SEEN] Gold values updated in the JSON file. Creating a PR..."
+        )
+    else:
+        print("No updates made — all models within tolerance.")
+
+
+if __name__ == "__main__":
+    json_file_path = "sharktank/tests/e2e/configs/models.json"
+    log_path = "output_artifacts"
+    update_json_for_conditions(json_file_path, log_path)
diff --git a/sharktank/tests/e2e/configs/models.json b/sharktank/tests/e2e/configs/models.json
index e3a488c987d..e87d8b3a6cf 100644
--- a/sharktank/tests/e2e/configs/models.json
+++ b/sharktank/tests/e2e/configs/models.json
@@ -41,9 +41,9 @@
     ],
     "output_dir": "../shark-ai/output_artifacts/",
     "prefill_gold_mi325x": 175.0,
-    "decode_gold_mi325x": 8.50,
-    "prefill_gold_mi300x": 200.20,
-    "decode_gold_mi300x": 9.60,
+    "decode_gold_mi325x": 8.5,
+    "prefill_gold_mi300x": 200.2,
+    "decode_gold_mi300x": 9.6,
     "extra_compile_flags_list": [
       "--iree-dispatch-creation-propagate-collapse-across-expands=true",
       "--iree-hip-specialize-dispatches",
@@ -105,7 +105,7 @@
     "prefill_gold_mi325x": 263.0,
     "decode_gold_mi325x": 8.23,
     "prefill_gold_mi300x": 320.0,
-    "decode_gold_mi300x": 9.90,
+    "decode_gold_mi300x": 9.9,
     "extra_compile_flags_list": [
       "--iree-dispatch-creation-propagate-collapse-across-expands=true",
       "--iree-hip-specialize-dispatches",
@@ -166,8 +166,8 @@
     "extra_export_flags_list": [],
     "output_dir": "../shark-ai/output_artifacts/",
     "prefill_gold_mi325x": 1790.079,
-    "decode_gold_mi325x": 49.90,
-    "prefill_gold_mi300x": 2237.90,
+    "decode_gold_mi325x": 49.9,
+    "prefill_gold_mi300x": 2237.9,
     "decode_gold_mi300x": 57.0,
     "extra_compile_flags_list": [
       "--iree-dispatch-creation-propagate-collapse-across-expands=true",
@@ -235,10 +235,10 @@
       "--kv-cache-dtype=float8_e4m3fnuz"
     ],
     "output_dir": "../shark-ai/output_artifacts/",
-    "prefill_gold_mi325x": 1530.90,
-    "decode_gold_mi325x": 43.20,
+    "prefill_gold_mi325x": 1530.9,
+    "decode_gold_mi325x": 43.2,
     "prefill_gold_mi300x": 1579.0,
-    "decode_gold_mi300x": 48.20,
+    "decode_gold_mi300x": 48.2,
     "extra_compile_flags_list": [
       "--iree-dispatch-creation-propagate-collapse-across-expands=true",
       "--iree-hip-specialize-dispatches",
@@ -306,7 +306,7 @@
     "output_dir": "../shark-ai/output_artifacts/",
     "prefill_gold_mi325x": 143.0,
     "decode_gold_mi325x": 16.7,
-    "prefill_gold_mi300x": 160.50,
+    "prefill_gold_mi300x": 160.5,
     "decode_gold_mi300x": 16.853,
     "extra_compile_flags_list": [
       "--iree-dispatch-creation-propagate-collapse-across-expands=true",
@@ -462,7 +462,7 @@
       1
     ],
     "extra_benchmark_flags_list": [
-      "-—benchmark_min_warmup_time=10.0"
+      "--benchmark_min_warmup_time=10.0"
    ],
     "isl": 19988,
     "prefill_bs_for_time_check": 4,
@@ -538,7 +538,7 @@
       1
     ],
     "extra_benchmark_flags_list": [
-      "-—benchmark_min_warmup_time=10.0"
+      "--benchmark_min_warmup_time=10.0"
     ],
     "isl": 19988,
     "prefill_bs_for_time_check": 4,
@@ -632,7 +632,7 @@
       7
     ],
     "extra_benchmark_flags_list": [
-      "-—benchmark_min_warmup_time=10.0"
+      "--benchmark_min_warmup_time=10.0"
     ],
     "isl": 19988,
     "prefill_bs_for_time_check": 4,
@@ -726,7 +726,7 @@
       7
     ],
     "extra_benchmark_flags_list": [
-      "-—benchmark_min_warmup_time=10.0"
+      "--benchmark_min_warmup_time=10.0"
     ],
     "isl": 19988,
     "prefill_bs_for_time_check": 4,
@@ -801,7 +801,7 @@
       0
     ],
     "extra_benchmark_flags_list": [
-      "-—benchmark_min_warmup_time=10.0"
+      "--benchmark_min_warmup_time=10.0"
     ],
     "isl": 2500,
     "prefill_bs_for_time_check": 4,
@@ -875,7 +875,7 @@
       0
     ],
     "extra_benchmark_flags_list": [
-      "-—benchmark_min_warmup_time=10.0"
+      "--benchmark_min_warmup_time=10.0"
     ],
     "isl": 2500,
     "prefill_bs_for_time_check": 4,
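Reviewer note on the update thresholds above (an illustrative sketch only, not part of the patch; the gold value 175.0 is taken from models.json and the 3%/6% margins from update_json_for_conditions, while the current measurement is hypothetical):

    # Reproduces the gold-refresh condition from update_json_for_conditions
    # for a single hypothetical prefill measurement.
    gold_prefill = 175.0     # prefill_gold_mi325x for the first model in models.json
    current_prefill = 168.0  # hypothetical new benchmark result

    # Gold is rewritten only when the new time beats it by more than 3%,
    # i.e. current_prefill < 175.0 * 0.97 = 169.75 (decode uses a 6% margin).
    if current_prefill < gold_prefill * (1 - 0.03):
        print(f"update gold: {gold_prefill} -> {round(current_prefill, 3)}")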