
Commit 9b0b138

Github Executorch committed
Support config filtering in ondemand benchmark flow
1 parent 57a09f4 · commit 9b0b138

File tree

6 files changed: +301 −45 lines

.ci/scripts/__init__.py

Whitespace-only changes.

.ci/scripts/gather_benchmark_configs.py

Lines changed: 93 additions & 37 deletions
@@ -9,9 +9,9 @@
 import logging
 import os
 import re
-from typing import Any, Dict
+from typing import Any, Dict, List
 
-from examples.models import MODEL_NAME_TO_MODEL
+from executorch.examples.models import MODEL_NAME_TO_MODEL
 
 
 # Device pools for AWS Device Farm
@@ -45,6 +45,79 @@
 }
 
 
+def extract_all_configs(data, target_os=None):
+    if isinstance(data, dict):
+        # If target_os is specified, include "xplat" and the specified branch
+        include_branches = {"xplat", target_os} if target_os else data.keys()
+        return [
+            v
+            for key, value in data.items()
+            if key in include_branches
+            for v in extract_all_configs(value, target_os)
+        ]
+    elif isinstance(data, list):
+        return [v for item in data for v in extract_all_configs(item, target_os)]
+    else:
+        return [data]
+
+
+def generate_compatible_configs(model_name: str, target_os=None) -> List[str]:
+    """
+    Generate a list of compatible benchmark configurations for a given model name and target OS.
+
+    Args:
+        model_name (str): The name of the model to generate configurations for.
+        target_os (Optional[str]): The target operating system (e.g., 'android', 'ios').
+
+    Returns:
+        List[str]: A list of compatible benchmark configurations.
+
+    Raises:
+        None
+
+    Example:
+        generate_compatible_configs('meta-llama/Llama-3.2-1B', 'ios') -> ['llama3_fb16', 'llama3_coreml_ane']
+    """
+    configs = []
+    if is_valid_huggingface_model_id(model_name):
+        if model_name.startswith("meta-llama/"):
+            # LLaMA models
+            repo_name = model_name.split("meta-llama/")[1]
+            if "qlora" in repo_name.lower():
+                configs.append("llama3_qlora")
+            elif "spinquant" in repo_name.lower():
+                configs.append("llama3_spinquant")
+            else:
+                configs.append("llama3_fb16")
+                configs.extend(
+                    [
+                        config
+                        for config in BENCHMARK_CONFIGS.get(target_os, [])
+                        if config.startswith("llama")
+                    ]
+                )
+        else:
+            # Non-LLaMA models
+            configs.append("hf_xnnpack_fp32")
+    elif model_name in MODEL_NAME_TO_MODEL:
+        # ExecuTorch in-tree non-GenAI models
+        configs.append("xnnpack_q8")
+        if target_os != "xplat":
+            # Add OS-specific configs
+            configs.extend(
+                [
+                    config
+                    for config in BENCHMARK_CONFIGS.get(target_os, [])
+                    if not config.startswith("llama")
+                ]
+            )
+    else:
+        # Skip unknown models with a warning
+        logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
+
+    return configs
+
+
 def parse_args() -> Any:
     """
     Parse command-line arguments.
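For orientation, here is a minimal sketch of how the two new helpers compose. SAMPLE_CONFIGS below is a hypothetical stand-in for the script's real BENCHMARK_CONFIGS (defined earlier in the file, outside this diff); the expected outputs follow from the unit tests added later in this commit.

```python
# Hypothetical stand-in for the script's BENCHMARK_CONFIGS, shaped like the
# real one: an OS-keyed dict whose "xplat" branch holds cross-platform configs.
SAMPLE_CONFIGS = {
    "xplat": ["xnnpack_q8", "llama3_spinquant", "llama3_qlora"],
    "android": ["qnn_q8"],
    "ios": ["coreml_fp16", "mps", "llama3_coreml_ane"],
}

# extract_all_configs flattens the nested structure. Without a target_os it
# walks every branch; with one, it keeps "xplat" plus that OS's branch only.
extract_all_configs(SAMPLE_CONFIGS, "android")
# -> ['xnnpack_q8', 'llama3_spinquant', 'llama3_qlora', 'qnn_q8']

# generate_compatible_configs narrows further by model family (per the tests):
generate_compatible_configs("meta-llama/Llama-3.2-1B", "ios")
# -> ['llama3_fb16', 'llama3_coreml_ane']
generate_compatible_configs("mv2", "android")
# -> ['xnnpack_q8', 'qnn_q8']
```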
@@ -82,6 +155,11 @@ def comma_separated(value: str):
         type=comma_separated,  # Use the custom parser for comma-separated values
         help=f"Comma-separated device names. Available devices: {list(DEVICE_POOLS.keys())}",
     )
+    parser.add_argument(
+        "--configs",
+        type=comma_separated,  # Use the custom parser for comma-separated values
+        help=f"Comma-separated benchmark configs. Available configs: {extract_all_configs(BENCHMARK_CONFIGS)}",
+    )
 
     return parser.parse_args()
 
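Note that comma_separated (already used by --devices above) is defined outside this hunk; judging from its usage, it is presumably something like this minimal sketch:

```python
def comma_separated(value: str):
    # Hypothetical reconstruction, not from this commit: turn "a,b,c" into
    # ["a", "b", "c"] so argparse hands downstream code a list of names.
    return [item.strip() for item in value.split(",") if item.strip()]
```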
@@ -123,7 +201,7 @@ def is_valid_huggingface_model_id(model_name: str) -> bool:
     return bool(re.match(pattern, model_name))
 
 
-def get_benchmark_configs() -> Dict[str, Dict]:
+def get_benchmark_configs() -> Dict[str, Dict]:  # noqa: C901
     """
     Gather benchmark configurations for a given set of models on the target operating system and devices.
 
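The regex behind is_valid_huggingface_model_id sits outside this hunk; only its final line is visible above. A sketch of the kind of check it performs, with an assumed pattern that may differ from the script's actual one:

```python
import re

def is_valid_huggingface_model_id(model_name: str) -> bool:
    # Assumed "org/repo" shape for Hugging Face model IDs; the real pattern
    # is defined in gather_benchmark_configs.py and may be stricter.
    pattern = r"^[a-zA-Z0-9_-]+/[a-zA-Z0-9._-]+$"
    return bool(re.match(pattern, model_name))

is_valid_huggingface_model_id("meta-llama/Llama-3.2-1B")  # True
is_valid_huggingface_model_id("mv2")                      # False (in-tree model)
```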
@@ -153,48 +231,26 @@ def get_benchmark_configs() -> Dict[str, Dict]:
     }
     """
     args = parse_args()
-    target_os = args.os
     devices = args.devices
     models = args.models
+    target_os = args.os
+    target_configs = args.configs
 
     benchmark_configs = {"include": []}
 
     for model_name in models:
         configs = []
-        if is_valid_huggingface_model_id(model_name):
-            if model_name.startswith("meta-llama/"):
-                # LLaMA models
-                repo_name = model_name.split("meta-llama/")[1]
-                if "qlora" in repo_name.lower():
-                    configs.append("llama3_qlora")
-                elif "spinquant" in repo_name.lower():
-                    configs.append("llama3_spinquant")
-                else:
-                    configs.append("llama3_fb16")
-                    configs.extend(
-                        [
-                            config
-                            for config in BENCHMARK_CONFIGS.get(target_os, [])
-                            if config.startswith("llama")
-                        ]
+        configs.extend(generate_compatible_configs(model_name, target_os))
+        print(f"Discovered all supported configs for model '{model_name}': {configs}")
+        if target_configs is not None:
+            for config in target_configs:
+                if config not in configs:
+                    raise Exception(
+                        f"Unsupported config '{config}' for model '{model_name}' on '{target_os}'. Skipped.\n"
+                        f"Supported configs are: {configs}"
                     )
-            else:
-                # Non-LLaMA models
-                configs.append("hf_xnnpack_fp32")
-        elif model_name in MODEL_NAME_TO_MODEL:
-            # ExecuTorch in-tree non-GenAI models
-            configs.append("xnnpack_q8")
-            configs.extend(
-                [
-                    config
-                    for config in BENCHMARK_CONFIGS.get(target_os, [])
-                    if not config.startswith("llama")
-                ]
-            )
-        else:
-            # Skip unknown models with a warning
-            logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
-            continue
+            configs = target_configs
+            print(f"Using provided configs {configs} for model '{model_name}'")
 
         # Add configurations for each valid device
         for device in devices:
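The net effect of the rewritten loop: config discovery is delegated to generate_compatible_configs, and a user-supplied --configs list must be a subset of what was discovered before it replaces the discovered list. A worked example of the semantics, using values taken from the tests below:

```python
# With --os ios --models mv2 --configs coreml_fp16,xnnpack_q8:
supported = generate_compatible_configs("mv2", "ios")
# -> ['xnnpack_q8', 'coreml_fp16', 'mps']

requested = ["coreml_fp16", "xnnpack_q8"]
for config in requested:
    if config not in supported:
        # e.g. requesting qnn_q8 on ios aborts the run with
        # "Unsupported config 'qnn_q8' for model 'mv2' on 'ios'."
        raise Exception(f"Unsupported config '{config}'")
configs = requested  # validation passed, so the user's list wins
```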
Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
+import importlib.util
+import os
+import subprocess
+import unittest
+from unittest.mock import mock_open, patch
+
+# Dynamically import the script
+script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
+spec = importlib.util.spec_from_file_location("gather_benchmark_configs", script_path)
+gather_benchmark_configs = importlib.util.module_from_spec(spec)
+spec.loader.exec_module(gather_benchmark_configs)
+
+
+class TestGatherBenchmarkConfigs(unittest.TestCase):
+
+    def test_extract_all_configs_android(self):
+        android_configs = gather_benchmark_configs.extract_all_configs(
+            gather_benchmark_configs.BENCHMARK_CONFIGS, "android"
+        )
+        self.assertIn("xnnpack_q8", android_configs)
+        self.assertIn("qnn_q8", android_configs)
+        self.assertIn("llama3_spinquant", android_configs)
+        self.assertIn("llama3_qlora", android_configs)
+
+    def test_extract_all_configs_ios(self):
+        ios_configs = gather_benchmark_configs.extract_all_configs(
+            gather_benchmark_configs.BENCHMARK_CONFIGS, "ios"
+        )
+
+        self.assertIn("xnnpack_q8", ios_configs)
+        self.assertIn("coreml_fp16", ios_configs)
+        self.assertIn("mps", ios_configs)
+        self.assertIn("llama3_coreml_ane", ios_configs)
+        self.assertIn("llama3_spinquant", ios_configs)
+        self.assertIn("llama3_qlora", ios_configs)
+
+    def test_generate_compatible_configs_llama_model(self):
+        model_name = "meta-llama/Llama-3.2-1B"
+        target_os = "ios"
+        result = gather_benchmark_configs.generate_compatible_configs(
+            model_name, target_os
+        )
+        expected = ["llama3_fb16", "llama3_coreml_ane"]
+        self.assertEqual(result, expected)
+
+        target_os = "android"
+        result = gather_benchmark_configs.generate_compatible_configs(
+            model_name, target_os
+        )
+        expected = ["llama3_fb16"]
+        self.assertEqual(result, expected)
+
+    def test_generate_compatible_configs_quantized_llama_model(self):
+        model_name = "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8"
+        result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
+        expected = ["llama3_spinquant"]
+        self.assertEqual(result, expected)
+
+        model_name = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"
+        result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
+        expected = ["llama3_qlora"]
+        self.assertEqual(result, expected)
+
+    def test_generate_compatible_configs_non_genai_model(self):
+        model_name = "mv2"
+        target_os = "xplat"
+        result = gather_benchmark_configs.generate_compatible_configs(
+            model_name, target_os
+        )
+        expected = ["xnnpack_q8"]
+        self.assertEqual(result, expected)
+
+        target_os = "android"
+        result = gather_benchmark_configs.generate_compatible_configs(
+            model_name, target_os
+        )
+        expected = ["xnnpack_q8", "qnn_q8"]
+        self.assertEqual(result, expected)
+
+        target_os = "ios"
+        result = gather_benchmark_configs.generate_compatible_configs(
+            model_name, target_os
+        )
+        expected = ["xnnpack_q8", "coreml_fp16", "mps"]
+        self.assertEqual(result, expected)
+
+    def test_generate_compatible_configs_unknown_model(self):
+        model_name = "unknown_model"
+        target_os = "ios"
+        result = gather_benchmark_configs.generate_compatible_configs(
+            model_name, target_os
+        )
+        self.assertEqual(result, [])
+
+    def test_is_valid_huggingface_model_id_valid(self):
+        valid_model = "meta-llama/Llama-3.2-1B"
+        self.assertTrue(
+            gather_benchmark_configs.is_valid_huggingface_model_id(valid_model)
+        )
+
+    @patch("builtins.open", new_callable=mock_open)
+    @patch("os.getenv", return_value=None)
+    def test_set_output_no_github_env(self, mock_getenv, mock_file):
+        with patch("builtins.print") as mock_print:
+            gather_benchmark_configs.set_output("test_name", "test_value")
+            mock_print.assert_called_with("::set-output name=test_name::test_value")
+
+    def test_device_pools_contains_all_devices(self):
+        expected_devices = [
+            "apple_iphone_15",
+            "apple_iphone_15+ios_18",
+            "samsung_galaxy_s22",
+            "samsung_galaxy_s24",
+            "google_pixel_8_pro",
+        ]
+        for device in expected_devices:
+            self.assertIn(device, gather_benchmark_configs.DEVICE_POOLS)
+
+    def test_gather_benchmark_configs_cli(self):
+        args = {
+            "models": "mv2,dl3",
+            "os": "ios",
+            "devices": "apple_iphone_15",
+            "configs": None,
+        }
+
+        cmd = ["python", ".ci/scripts/gather_benchmark_configs.py"]
+        for key, value in args.items():
+            if value is not None:
+                cmd.append(f"--{key}")
+                cmd.append(value)
+
+        result = subprocess.run(cmd, capture_output=True, text=True)
+        self.assertEqual(result.returncode, 0, f"Error: {result.stderr}")
+        self.assertIn('"model": "mv2"', result.stdout)
+        self.assertIn('"model": "dl3"', result.stdout)
+        self.assertIn('"config": "coreml_fp16"', result.stdout)
+        self.assertIn('"config": "xnnpack_q8"', result.stdout)
+        self.assertIn('"config": "mps"', result.stdout)
+
+    def test_gather_benchmark_configs_cli_specified_configs(self):
+        args = {
+            "models": "mv2,dl3",
+            "os": "ios",
+            "devices": "apple_iphone_15",
+            "configs": "coreml_fp16,xnnpack_q8",
+        }
+
+        cmd = ["python", ".ci/scripts/gather_benchmark_configs.py"]
+        for key, value in args.items():
+            if value is not None:
+                cmd.append(f"--{key}")
+                cmd.append(value)
+
+        result = subprocess.run(cmd, capture_output=True, text=True)
+        self.assertEqual(result.returncode, 0, f"Error: {result.stderr}")
+        self.assertIn('"model": "mv2"', result.stdout)
+        self.assertIn('"model": "dl3"', result.stdout)
+        self.assertIn('"config": "coreml_fp16"', result.stdout)
+        self.assertIn('"config": "xnnpack_q8"', result.stdout)
+        self.assertNotIn('"config": "mps"', result.stdout)
+
+    def test_gather_benchmark_configs_cli_specified_configs_raise(self):
+        args = {
+            "models": "mv2,dl3",
+            "os": "ios",
+            "devices": "apple_iphone_15",
+            "configs": "qnn_q8",
+        }
+
+        cmd = ["python", ".ci/scripts/gather_benchmark_configs.py"]
+        for key, value in args.items():
+            if value is not None:
+                cmd.append(f"--{key}")
+                cmd.append(value)
+
+        result = subprocess.run(cmd, capture_output=True, text=True)
+        self.assertEqual(result.returncode, 1, f"Error: {result.stderr}")
+        self.assertIn("Unsupported config 'qnn_q8'", result.stderr)
+
+
+if __name__ == "__main__":
+    unittest.main()
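One more helper exercised above, set_output, also predates this commit. Based on the mocked behavior in test_set_output_no_github_env, it plausibly looks like this sketch (the exact implementation in the script may differ):

```python
import os

def set_output(name: str, value: str) -> None:
    # Sketch inferred from the test: append to $GITHUB_OUTPUT when running
    # under GitHub Actions, else fall back to the legacy ::set-output syntax.
    github_output = os.getenv("GITHUB_OUTPUT")
    if github_output:
        with open(github_output, "a") as f:
            f.write(f"{name}={value}\n")
    else:
        print(f"::set-output name={name}::{value}")
```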

.github/workflows/android-perf.yml

Lines changed: 12 additions & 4 deletions
@@ -74,19 +74,27 @@ jobs:
         CRON_DEFAULT_DEVICES: samsung_galaxy_s22
       run: |
         set -eux
+
+        ARGS="--os android"
+
         MODELS="${{ inputs.models }}"
         if [ -z "$MODELS" ]; then
           MODELS="$CRON_DEFAULT_MODELS"
         fi
+        ARGS="$ARGS --models $MODELS"
+
         DEVICES="${{ inputs.devices }}"
         if [ -z "$DEVICES" ]; then
           DEVICES="$CRON_DEFAULT_DEVICES"
         fi
+        ARGS="$ARGS --devices $DEVICES"
+
+        BENCHMARK_CONFIGS="${{ inputs.benchmark_configs }}"
+        if [ -n "$BENCHMARK_CONFIGS" ]; then
+          ARGS="$ARGS --configs $BENCHMARK_CONFIGS"
+        fi
 
-        PYTHONPATH="${PWD}" python .ci/scripts/gather_benchmark_configs.py \
-          --os "android" \
-          --models $MODELS \
-          --devices $DEVICES
+        PYTHONPATH="${PWD}" python .ci/scripts/gather_benchmark_configs.py $ARGS
 
   prepare-test-specs:
     runs-on: linux.2xlarge
