Commit c4a2c5d

Skip disabled configs for benchmarking

Author: Github Executorch (committed)
1 parent c1528cb · commit c4a2c5d

File tree: 2 files changed (+77, −21 lines)

  .ci/scripts/gather_benchmark_configs.py
  .ci/scripts/tests/test_gather_benchmark_configs.py


.ci/scripts/gather_benchmark_configs.py

Lines changed: 21 additions & 0 deletions
@@ -46,6 +46,24 @@
    ],
}

+DISABLED_CONFIGS = {
+    "resnet50": [
+        "qnn_q8",
+    ],
+    "w2l": [
+        "qnn_q8",
+    ],
+    "mobilebert": [
+        "mps",
+    ],
+    "edsr": [
+        "mps",
+    ],
+    "llama": [
+        "mps",
+    ],
+}
+

def extract_all_configs(data, target_os=None):
    if isinstance(data, dict):

@@ -117,6 +135,9 @@ def generate_compatible_configs(model_name: str, target_os=None) -> List[str]:
        # Skip unknown models with a warning
        logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")

+    # Remove disabled configs for the given model
+    disabled_configs = DISABLED_CONFIGS.get(model_name, [])
+    configs = [config for config in configs if config not in disabled_configs]
    return configs
.ci/scripts/tests/test_gather_benchmark_configs.py

Lines changed: 56 additions & 21 deletions
@@ -7,30 +7,34 @@

import pytest

-# Dynamically import the script
-script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
-spec = importlib.util.spec_from_file_location("gather_benchmark_configs", script_path)
-gather_benchmark_configs = importlib.util.module_from_spec(spec)
-spec.loader.exec_module(gather_benchmark_configs)
-

@pytest.mark.skipif(
    sys.platform != "linux", reason="The script under test runs on Linux runners only"
)
class TestGatehrBenchmarkConfigs(unittest.TestCase):

+    @classmethod
+    def setUpClass(cls):
+        # Dynamically import the script
+        script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
+        spec = importlib.util.spec_from_file_location(
+            "gather_benchmark_configs", script_path
+        )
+        cls.gather_benchmark_configs = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(cls.gather_benchmark_configs)
+
    def test_extract_all_configs_android(self):
-        android_configs = gather_benchmark_configs.extract_all_configs(
-            gather_benchmark_configs.BENCHMARK_CONFIGS, "android"
+        android_configs = self.gather_benchmark_configs.extract_all_configs(
+            self.gather_benchmark_configs.BENCHMARK_CONFIGS, "android"
        )
        self.assertIn("xnnpack_q8", android_configs)
        self.assertIn("qnn_q8", android_configs)
        self.assertIn("llama3_spinquant", android_configs)
        self.assertIn("llama3_qlora", android_configs)

    def test_extract_all_configs_ios(self):
-        ios_configs = gather_benchmark_configs.extract_all_configs(
-            gather_benchmark_configs.BENCHMARK_CONFIGS, "ios"
+        ios_configs = self.gather_benchmark_configs.extract_all_configs(
+            self.gather_benchmark_configs.BENCHMARK_CONFIGS, "ios"
        )

        self.assertIn("xnnpack_q8", ios_configs)

@@ -40,51 +44,82 @@ def test_extract_all_configs_ios(self):
        self.assertIn("llama3_spinquant", ios_configs)
        self.assertIn("llama3_qlora", ios_configs)

+    def test_skip_disabled_configs(self):
+        # Use patch as a context manager to avoid modifying DISABLED_CONFIGS and BENCHMARK_CONFIGS
+        with patch.dict(
+            self.gather_benchmark_configs.DISABLED_CONFIGS,
+            {"mv3": ["disabled_config1", "disabled_config2"]},
+        ), patch.dict(
+            self.gather_benchmark_configs.BENCHMARK_CONFIGS,
+            {
+                "ios": [
+                    "disabled_config1",
+                    "disabled_config2",
+                    "enabled_config1",
+                    "enabled_config2",
+                ]
+            },
+        ):
+            result = self.gather_benchmark_configs.generate_compatible_configs(
+                "mv3", target_os="ios"
+            )
+
+            # Assert that disabled configs are excluded
+            self.assertNotIn("disabled_config1", result)
+            self.assertNotIn("disabled_config2", result)
+            # Assert enabled configs are included
+            self.assertIn("enabled_config1", result)
+            self.assertIn("enabled_config2", result)
+
    def test_generate_compatible_configs_llama_model(self):
        model_name = "meta-llama/Llama-3.2-1B"
        target_os = "ios"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
            model_name, target_os
        )
        expected = ["llama3_fb16", "llama3_coreml_ane"]
        self.assertEqual(result, expected)

        target_os = "android"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
            model_name, target_os
        )
        expected = ["llama3_fb16"]
        self.assertEqual(result, expected)

    def test_generate_compatible_configs_quantized_llama_model(self):
        model_name = "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8"
-        result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
+        result = self.gather_benchmark_configs.generate_compatible_configs(
+            model_name, None
+        )
        expected = ["llama3_spinquant"]
        self.assertEqual(result, expected)

        model_name = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"
-        result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
+        result = self.gather_benchmark_configs.generate_compatible_configs(
+            model_name, None
+        )
        expected = ["llama3_qlora"]
        self.assertEqual(result, expected)

    def test_generate_compatible_configs_non_genai_model(self):
        model_name = "mv2"
        target_os = "xplat"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
            model_name, target_os
        )
        expected = ["xnnpack_q8"]
        self.assertEqual(result, expected)

        target_os = "android"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
            model_name, target_os
        )
        expected = ["xnnpack_q8", "qnn_q8"]
        self.assertEqual(result, expected)

        target_os = "ios"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
            model_name, target_os
        )
        expected = ["xnnpack_q8", "coreml_fp16", "mps"]

@@ -93,22 +128,22 @@ def test_generate_compatible_configs_non_genai_model(self):
    def test_generate_compatible_configs_unknown_model(self):
        model_name = "unknown_model"
        target_os = "ios"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
            model_name, target_os
        )
        self.assertEqual(result, [])

    def test_is_valid_huggingface_model_id_valid(self):
        valid_model = "meta-llama/Llama-3.2-1B"
        self.assertTrue(
-            gather_benchmark_configs.is_valid_huggingface_model_id(valid_model)
+            self.gather_benchmark_configs.is_valid_huggingface_model_id(valid_model)
        )

    @patch("builtins.open", new_callable=mock_open)
    @patch("os.getenv", return_value=None)
    def test_set_output_no_github_env(self, mock_getenv, mock_file):
        with patch("builtins.print") as mock_print:
-            gather_benchmark_configs.set_output("test_name", "test_value")
+            self.gather_benchmark_configs.set_output("test_name", "test_value")
            mock_print.assert_called_with("::set-output name=test_name::test_value")

    def test_device_pools_contains_all_devices(self):

@@ -120,7 +155,7 @@ def test_device_pools_contains_all_devices(self):
            "google_pixel_8_pro",
        ]
        for device in expected_devices:
-            self.assertIn(device, gather_benchmark_configs.DEVICE_POOLS)
+            self.assertIn(device, self.gather_benchmark_configs.DEVICE_POOLS)

    def test_gather_benchmark_configs_cli(self):
        args = {
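The new test_skip_disabled_configs uses unittest.mock.patch.dict as a context manager, so the temporary entries added to DISABLED_CONFIGS and BENCHMARK_CONFIGS are rolled back when the block exits. Below is a minimal standalone sketch of that pattern; the dictionary contents are illustrative, not taken from the test file.

from unittest.mock import patch

CONFIGS = {"ios": ["xnnpack_q8"]}

with patch.dict(CONFIGS, {"ios": ["xnnpack_q8", "temporary_config"]}):
    # Inside the block the patched value is visible...
    assert CONFIGS["ios"] == ["xnnpack_q8", "temporary_config"]

# ...and the original mapping is restored once the block exits.
assert CONFIGS["ios"] == ["xnnpack_q8"]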