
Commit 2006d40

Update on "Reuse GELU implementation from PyTorch core"
kernels/optimized doesn't need to support embedded systems, so it can just take a header-only dep on PyTorch.

Note that, because we will pick up Sleef internally and ignore it externally thanks to ATen vec, this PR gets to enable optimized GELU in OSS.

Testing: CI to make sure this doesn't break mobile build modes; happy to take advice on anything not currently covered that might break.

Differential Revision: [D66335522](https://our.internmc.facebook.com/intern/diff/D66335522/)

[ghstack-poisoned]
2 parents 028526f + b256712 commit 2006d40
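
For context, the GELU being reused is PyTorch core's standard erf-based definition, with the tanh approximation available as a separate variant. A minimal Python sanity check of those numerics (illustrative only, not part of this commit; assumes a recent torch build):

    import torch

    x = torch.randn(1024, dtype=torch.float32)
    # Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
    expected = 0.5 * x * (1.0 + torch.erf(x / (2.0 ** 0.5)))
    assert torch.allclose(torch.nn.functional.gelu(x), expected, atol=1e-6)
    # The tanh approximation is a separate code path:
    approx = torch.nn.functional.gelu(x, approximate="tanh")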

File tree

236 files changed: +6184 −5138 lines

Large commits have some content hidden by default; only a subset of the changed files is shown below.


.ci/scripts/__init__.py

Whitespace-only changes.

.ci/scripts/gather_benchmark_configs.py

Lines changed: 104 additions & 41 deletions
@@ -9,8 +9,10 @@
 import logging
 import os
 import re
-from typing import Any, Dict
+import sys
+from typing import Any, Dict, List

+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
 from examples.models import MODEL_NAME_TO_MODEL


@@ -45,6 +47,79 @@
 }


+def extract_all_configs(data, target_os=None):
+    if isinstance(data, dict):
+        # If target_os is specified, include "xplat" and the specified branch
+        include_branches = {"xplat", target_os} if target_os else data.keys()
+        return [
+            v
+            for key, value in data.items()
+            if key in include_branches
+            for v in extract_all_configs(value, target_os)
+        ]
+    elif isinstance(data, list):
+        return [v for item in data for v in extract_all_configs(item, target_os)]
+    else:
+        return [data]
+
+
+def generate_compatible_configs(model_name: str, target_os=None) -> List[str]:
+    """
+    Generate a list of compatible benchmark configurations for a given model name and target OS.
+
+    Args:
+        model_name (str): The name of the model to generate configurations for.
+        target_os (Optional[str]): The target operating system (e.g., 'android', 'ios').
+
+    Returns:
+        List[str]: A list of compatible benchmark configurations.
+
+    Raises:
+        None
+
+    Example:
+        generate_compatible_configs('meta-llama/Llama-3.2-1B', 'ios') -> ['llama3_fb16', 'llama3_coreml_ane']
+    """
+    configs = []
+    if is_valid_huggingface_model_id(model_name):
+        if model_name.startswith("meta-llama/"):
+            # LLaMA models
+            repo_name = model_name.split("meta-llama/")[1]
+            if "qlora" in repo_name.lower():
+                configs.append("llama3_qlora")
+            elif "spinquant" in repo_name.lower():
+                configs.append("llama3_spinquant")
+            else:
+                configs.append("llama3_fb16")
+                configs.extend(
+                    [
+                        config
+                        for config in BENCHMARK_CONFIGS.get(target_os, [])
+                        if config.startswith("llama")
+                    ]
+                )
+        else:
+            # Non-LLaMA models
+            configs.append("hf_xnnpack_fp32")
+    elif model_name in MODEL_NAME_TO_MODEL:
+        # ExecuTorch in-tree non-GenAI models
+        configs.append("xnnpack_q8")
+        if target_os != "xplat":
+            # Add OS-specific configs
+            configs.extend(
+                [
+                    config
+                    for config in BENCHMARK_CONFIGS.get(target_os, [])
+                    if not config.startswith("llama")
+                ]
+            )
+    else:
+        # Skip unknown models with a warning
+        logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
+
+    return configs
+
+
 def parse_args() -> Any:
     """
     Parse command-line arguments.
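
To make the recursion in extract_all_configs concrete, here is an illustrative call against a hypothetical nested mapping (not the real BENCHMARK_CONFIGS, whose contents sit outside this hunk):

    configs = {
        "xplat": ["xnnpack_q8"],
        "android": ["qnn_q8"],
        "ios": ["coreml_fp16", ["mps"]],
    }
    extract_all_configs(configs)             # -> ["xnnpack_q8", "qnn_q8", "coreml_fp16", "mps"]
    extract_all_configs(configs, "android")  # -> ["xnnpack_q8", "qnn_q8"]  ("xplat" plus the requested OS)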
@@ -82,6 +157,11 @@ def comma_separated(value: str):
         type=comma_separated,  # Use the custom parser for comma-separated values
         help=f"Comma-separated device names. Available devices: {list(DEVICE_POOLS.keys())}",
     )
+    parser.add_argument(
+        "--configs",
+        type=comma_separated,  # Use the custom parser for comma-separated values
+        help=f"Comma-separated benchmark configs. Available configs: {extract_all_configs(BENCHMARK_CONFIGS)}",
+    )

     return parser.parse_args()

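
The new --configs flag reuses the existing comma_separated parser; its body sits outside this hunk, but a plausible sketch of what such a parser does (an assumption, not the committed implementation) is:

    def comma_separated(value: str):
        # "coreml_fp16,xnnpack_q8" -> ["coreml_fp16", "xnnpack_q8"]
        return value.split(",")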

@@ -98,11 +178,16 @@ def set_output(name: str, val: Any) -> None:
     set_output("benchmark_configs", {"include": [...]})
     """

-    if os.getenv("GITHUB_OUTPUT"):
-        print(f"Setting {val} to GitHub output")
-        with open(str(os.getenv("GITHUB_OUTPUT")), "a") as env:
-            print(f"{name}={val}", file=env)
-    else:
+    github_output = os.getenv("GITHUB_OUTPUT")
+    if not github_output:
+        print(f"::set-output name={name}::{val}")
+        return
+
+    try:
+        with open(github_output, "a") as env:
+            env.write(f"{name}={val}\n")
+    except PermissionError:
+        # Fall back to printing in case of permission error in unit tests
         print(f"::set-output name={name}::{val}")


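
The rewritten set_output appends name=value lines to the file GitHub Actions exposes through GITHUB_OUTPUT, and falls back to the legacy ::set-output workflow command only when the variable is unset or the file is unwritable. A small illustrative exercise of both paths, assuming set_output is imported from this script:

    import os
    import tempfile

    path = os.path.join(tempfile.mkdtemp(), "github_output.txt")
    os.environ["GITHUB_OUTPUT"] = path
    set_output("benchmark_configs", '{"include": []}')
    with open(path) as fh:
        print(fh.read())  # benchmark_configs={"include": []}

    del os.environ["GITHUB_OUTPUT"]
    set_output("benchmark_configs", '{"include": []}')  # prints ::set-output name=benchmark_configs::{"include": []}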

@@ -123,7 +208,7 @@ def is_valid_huggingface_model_id(model_name: str) -> bool:
     return bool(re.match(pattern, model_name))


-def get_benchmark_configs() -> Dict[str, Dict]:
+def get_benchmark_configs() -> Dict[str, Dict]:  # noqa: C901
     """
     Gather benchmark configurations for a given set of models on the target operating system and devices.

@@ -153,48 +238,26 @@ def get_benchmark_configs() -> Dict[str, Dict]:
     }
     """
     args = parse_args()
-    target_os = args.os
     devices = args.devices
     models = args.models
+    target_os = args.os
+    target_configs = args.configs

     benchmark_configs = {"include": []}

     for model_name in models:
         configs = []
-        if is_valid_huggingface_model_id(model_name):
-            if model_name.startswith("meta-llama/"):
-                # LLaMA models
-                repo_name = model_name.split("meta-llama/")[1]
-                if "qlora" in repo_name.lower():
-                    configs.append("llama3_qlora")
-                elif "spinquant" in repo_name.lower():
-                    configs.append("llama3_spinquant")
-                else:
-                    configs.append("llama3_fb16")
-                    configs.extend(
-                        [
-                            config
-                            for config in BENCHMARK_CONFIGS.get(target_os, [])
-                            if config.startswith("llama")
-                        ]
+        configs.extend(generate_compatible_configs(model_name, target_os))
+        print(f"Discovered all supported configs for model '{model_name}': {configs}")
+        if target_configs is not None:
+            for config in target_configs:
+                if config not in configs:
+                    raise Exception(
+                        f"Unsupported config '{config}' for model '{model_name}' on '{target_os}'. Skipped.\n"
+                        f"Supported configs are: {configs}"
                     )
-            else:
-                # Non-LLaMA models
-                configs.append("hf_xnnpack_fp32")
-        elif model_name in MODEL_NAME_TO_MODEL:
-            # ExecuTorch in-tree non-GenAI models
-            configs.append("xnnpack_q8")
-            configs.extend(
-                [
-                    config
-                    for config in BENCHMARK_CONFIGS.get(target_os, [])
-                    if not config.startswith("llama")
-                ]
-            )
-        else:
-            # Skip unknown models with a warning
-            logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
-            continue
+            configs = target_configs
+            print(f"Using provided configs {configs} for model '{model_name}'")

         # Add configurations for each valid device
         for device in devices:
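
As exercised by the new tests below, the script can now be given an explicit config filter that is validated against the configs discovered for each model, e.g. (run from the repository root):

    python .ci/scripts/gather_benchmark_configs.py --models mv2,dl3 --os ios --devices apple_iphone_15 --configs coreml_fp16,xnnpack_q8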
New test file — Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
+import importlib.util
+import os
+import subprocess
+import sys
+import unittest
+from unittest.mock import mock_open, patch
+
+import pytest
+
+# Dynamically import the script
+script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
+spec = importlib.util.spec_from_file_location("gather_benchmark_configs", script_path)
+gather_benchmark_configs = importlib.util.module_from_spec(spec)
+spec.loader.exec_module(gather_benchmark_configs)
+
+
+@pytest.mark.skipif(
+    sys.platform != "linux", reason="The script under test runs on Linux runners only"
+)
+class TestGatherBenchmarkConfigs(unittest.TestCase):
+
+    def test_extract_all_configs_android(self):
+        android_configs = gather_benchmark_configs.extract_all_configs(
+            gather_benchmark_configs.BENCHMARK_CONFIGS, "android"
+        )
+        self.assertIn("xnnpack_q8", android_configs)
+        self.assertIn("qnn_q8", android_configs)
+        self.assertIn("llama3_spinquant", android_configs)
+        self.assertIn("llama3_qlora", android_configs)
+
+    def test_extract_all_configs_ios(self):
+        ios_configs = gather_benchmark_configs.extract_all_configs(
+            gather_benchmark_configs.BENCHMARK_CONFIGS, "ios"
+        )
+
+        self.assertIn("xnnpack_q8", ios_configs)
+        self.assertIn("coreml_fp16", ios_configs)
+        self.assertIn("mps", ios_configs)
+        self.assertIn("llama3_coreml_ane", ios_configs)
+        self.assertIn("llama3_spinquant", ios_configs)
+        self.assertIn("llama3_qlora", ios_configs)
+
+    def test_generate_compatible_configs_llama_model(self):
+        model_name = "meta-llama/Llama-3.2-1B"
+        target_os = "ios"
+        result = gather_benchmark_configs.generate_compatible_configs(
+            model_name, target_os
+        )
+        expected = ["llama3_fb16", "llama3_coreml_ane"]
+        self.assertEqual(result, expected)
+
+        target_os = "android"
+        result = gather_benchmark_configs.generate_compatible_configs(
+            model_name, target_os
+        )
+        expected = ["llama3_fb16"]
+        self.assertEqual(result, expected)
+
+    def test_generate_compatible_configs_quantized_llama_model(self):
+        model_name = "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8"
+        result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
+        expected = ["llama3_spinquant"]
+        self.assertEqual(result, expected)
+
+        model_name = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"
+        result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
+        expected = ["llama3_qlora"]
+        self.assertEqual(result, expected)
+
+    def test_generate_compatible_configs_non_genai_model(self):
+        model_name = "mv2"
+        target_os = "xplat"
+        result = gather_benchmark_configs.generate_compatible_configs(
+            model_name, target_os
+        )
+        expected = ["xnnpack_q8"]
+        self.assertEqual(result, expected)
+
+        target_os = "android"
+        result = gather_benchmark_configs.generate_compatible_configs(
+            model_name, target_os
+        )
+        expected = ["xnnpack_q8", "qnn_q8"]
+        self.assertEqual(result, expected)
+
+        target_os = "ios"
+        result = gather_benchmark_configs.generate_compatible_configs(
+            model_name, target_os
+        )
+        expected = ["xnnpack_q8", "coreml_fp16", "mps"]
+        self.assertEqual(result, expected)
+
+    def test_generate_compatible_configs_unknown_model(self):
+        model_name = "unknown_model"
+        target_os = "ios"
+        result = gather_benchmark_configs.generate_compatible_configs(
+            model_name, target_os
+        )
+        self.assertEqual(result, [])
+
+    def test_is_valid_huggingface_model_id_valid(self):
+        valid_model = "meta-llama/Llama-3.2-1B"
+        self.assertTrue(
+            gather_benchmark_configs.is_valid_huggingface_model_id(valid_model)
+        )
+
+    @patch("builtins.open", new_callable=mock_open)
+    @patch("os.getenv", return_value=None)
+    def test_set_output_no_github_env(self, mock_getenv, mock_file):
+        with patch("builtins.print") as mock_print:
+            gather_benchmark_configs.set_output("test_name", "test_value")
+            mock_print.assert_called_with("::set-output name=test_name::test_value")
+
+    def test_device_pools_contains_all_devices(self):
+        expected_devices = [
+            "apple_iphone_15",
+            "apple_iphone_15+ios_18",
+            "samsung_galaxy_s22",
+            "samsung_galaxy_s24",
+            "google_pixel_8_pro",
+        ]
+        for device in expected_devices:
+            self.assertIn(device, gather_benchmark_configs.DEVICE_POOLS)
+
+    def test_gather_benchmark_configs_cli(self):
+        args = {
+            "models": "mv2,dl3",
+            "os": "ios",
+            "devices": "apple_iphone_15",
+            "configs": None,
+        }
+
+        cmd = ["python", ".ci/scripts/gather_benchmark_configs.py"]
+        for key, value in args.items():
+            if value is not None:
+                cmd.append(f"--{key}")
+                cmd.append(value)
+
+        result = subprocess.run(cmd, capture_output=True, text=True)
+        self.assertEqual(result.returncode, 0, f"Error: {result.stderr}")
+        self.assertIn('"model": "mv2"', result.stdout)
+        self.assertIn('"model": "dl3"', result.stdout)
+        self.assertIn('"config": "coreml_fp16"', result.stdout)
+        self.assertIn('"config": "xnnpack_q8"', result.stdout)
+        self.assertIn('"config": "mps"', result.stdout)
+
+    def test_gather_benchmark_configs_cli_specified_configs(self):
+        args = {
+            "models": "mv2,dl3",
+            "os": "ios",
+            "devices": "apple_iphone_15",
+            "configs": "coreml_fp16,xnnpack_q8",
+        }
+
+        cmd = ["python", ".ci/scripts/gather_benchmark_configs.py"]
+        for key, value in args.items():
+            if value is not None:
+                cmd.append(f"--{key}")
+                cmd.append(value)
+
+        result = subprocess.run(cmd, capture_output=True, text=True)
+        self.assertEqual(result.returncode, 0, f"Error: {result.stderr}")
+        self.assertIn('"model": "mv2"', result.stdout)
+        self.assertIn('"model": "dl3"', result.stdout)
+        self.assertIn('"config": "coreml_fp16"', result.stdout)
+        self.assertIn('"config": "xnnpack_q8"', result.stdout)
+        self.assertNotIn('"config": "mps"', result.stdout)
+
+    def test_gather_benchmark_configs_cli_specified_configs_raise(self):
+        args = {
+            "models": "mv2,dl3",
+            "os": "ios",
+            "devices": "apple_iphone_15",
+            "configs": "qnn_q8",
+        }
+
+        cmd = ["python", ".ci/scripts/gather_benchmark_configs.py"]
+        for key, value in args.items():
+            if value is not None:
+                cmd.append(f"--{key}")
+                cmd.append(value)
+
+        result = subprocess.run(cmd, capture_output=True, text=True)
+        self.assertEqual(result.returncode, 1, f"Error: {result.stderr}")
+        self.assertIn("Unsupported config 'qnn_q8'", result.stderr)
+
+
+if __name__ == "__main__":
+    unittest.main()
