diff --git a/.ci/scripts/gather_benchmark_configs.py b/.ci/scripts/gather_benchmark_configs.py index ae7b78ecbb7..fcd2c5ba7dd 100755 --- a/.ci/scripts/gather_benchmark_configs.py +++ b/.ci/scripts/gather_benchmark_configs.py @@ -135,12 +135,11 @@ def generate_compatible_configs(model_name: str, target_os=None) -> List[str]: # etLLM recipes for Llama repo_name = model_name.split("meta-llama/")[1] if "qlora" in repo_name.lower(): - configs.append("llama3_qlora") + configs = ["llama3_qlora"] elif "spinquant" in repo_name.lower(): - configs.append("llama3_spinquant") + configs = ["llama3_spinquant"] else: - configs.append("llama3_fb16") - configs.append("et_xnnpack_custom_spda_kv_cache_8da4w") + configs.extend(["llama3_fb16", "et_xnnpack_custom_spda_kv_cache_8da4w"]) configs.extend( [ config diff --git a/.ci/scripts/tests/test_gather_benchmark_configs.py b/.ci/scripts/tests/test_gather_benchmark_configs.py index 41bb7528b3e..8f422a1c391 100644 --- a/.ci/scripts/tests/test_gather_benchmark_configs.py +++ b/.ci/scripts/tests/test_gather_benchmark_configs.py @@ -112,15 +112,24 @@ def test_generate_compatible_configs_llama_model(self): result = self.gather_benchmark_configs.generate_compatible_configs( model_name, target_os ) - expected = ["llama3_fb16", "llama3_coreml_ane"] - self.assertEqual(result, expected) + expected = [ + "llama3_fb16", + "llama3_coreml_ane", + "et_xnnpack_custom_spda_kv_cache_8da4w", + "hf_xnnpack_custom_spda_kv_cache_8da4w", + ] + self.assertCountEqual(result, expected) target_os = "android" result = self.gather_benchmark_configs.generate_compatible_configs( model_name, target_os ) - expected = ["llama3_fb16"] - self.assertEqual(result, expected) + expected = [ + "llama3_fb16", + "et_xnnpack_custom_spda_kv_cache_8da4w", + "hf_xnnpack_custom_spda_kv_cache_8da4w", + ] + self.assertCountEqual(result, expected) def test_generate_compatible_configs_quantized_llama_model(self): model_name = "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8" diff --git a/.github/workflows/android-perf.yml b/.github/workflows/android-perf.yml index 34744268ff5..b5c5da57833 100644 --- a/.github/workflows/android-perf.yml +++ b/.github/workflows/android-perf.yml @@ -6,12 +6,14 @@ on: pull_request: paths: - .github/workflows/android-perf.yml + - .ci/scripts/gather_benchmark_configs.py - extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2 push: branches: - main paths: - .github/workflows/android-perf.yml + - .ci/scripts/gather_benchmark_configs.py - extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2 # Note: GitHub has an upper limit of 10 inputs workflow_dispatch: diff --git a/.github/workflows/apple-perf.yml b/.github/workflows/apple-perf.yml index e2f2cc2fcc3..5945d7b2d29 100644 --- a/.github/workflows/apple-perf.yml +++ b/.github/workflows/apple-perf.yml @@ -6,12 +6,14 @@ on: pull_request: paths: - .github/workflows/apple-perf.yml + - .ci/scripts/gather_benchmark_configs.py - extension/benchmark/apple/Benchmark/default-ios-device-farm-appium-test-spec.yml.j2 push: branches: - main paths: - .github/workflows/apple-perf.yml + - .ci/scripts/gather_benchmark_configs.py - extension/benchmark/apple/Benchmark/default-ios-device-farm-appium-test-spec.yml.j2 # Note: GitHub has an upper limit of 10 inputs workflow_dispatch: