Skip to content

Commit 6c6f6d9

Browse files
authored
Merge branch 'main' into fix_custom_ops
2 parents ea149bd + 57ef834 commit 6c6f6d9

File tree

323 files changed

+14237
-3755
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

323 files changed

+14237
-3755
lines changed

.ci/docker/requirements-ci.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
mpmath==1.3.0
2-
numpy==2.0.0; python_version >= '3.10'
2+
numpy>=2.0.0; python_version >= '3.10'
33
PyYAML==6.0.1
44
ruamel.yaml==0.17.32
55
sympy==1.12
@@ -8,7 +8,7 @@ tomli==2.0.1
88
torchsr==1.0.4
99
transformers==4.47.1
1010
zstd==1.5.5.1
11-
pandas==2.2.2; python_version >= '3.10'
11+
pandas>=2.2.2; python_version >= '3.10'
1212
pytest==7.2.0
1313
pytest-cov==4.1.0
1414
expecttest==0.1.6
@@ -21,7 +21,7 @@ sphinx-gallery==0.14.0
2121
breathe==4.34.0
2222
exhale==0.2.3
2323
docutils==0.16
24-
matplotlib==3.9.4
24+
matplotlib>=3.9.4
2525
# PyTorch Theme
2626
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
2727
myst-parser==0.18.1

.ci/scripts/__init__.py

Whitespace-only changes.

.ci/scripts/gather_benchmark_configs.py

Lines changed: 152 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
import logging
1010
import os
1111
import re
12-
from typing import Any, Dict
12+
import sys
13+
from typing import Any, Dict, List, NamedTuple
1314

15+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
1416
from examples.models import MODEL_NAME_TO_MODEL
1517

1618

@@ -45,6 +47,127 @@
4547
}
4648

4749

50+
class DisabledConfig(NamedTuple):
51+
config_name: str
52+
github_issue: str # Link to the GitHub issue
53+
54+
55+
# Updated DISABLED_CONFIGS
56+
DISABLED_CONFIGS: Dict[str, List[DisabledConfig]] = {
57+
"resnet50": [
58+
DisabledConfig(
59+
config_name="qnn_q8",
60+
github_issue="https://github.com/pytorch/executorch/issues/7892",
61+
),
62+
],
63+
"w2l": [
64+
DisabledConfig(
65+
config_name="qnn_q8",
66+
github_issue="https://github.com/pytorch/executorch/issues/7634",
67+
),
68+
],
69+
"mobilebert": [
70+
DisabledConfig(
71+
config_name="mps",
72+
github_issue="https://github.com/pytorch/executorch/issues/7904",
73+
),
74+
],
75+
"edsr": [
76+
DisabledConfig(
77+
config_name="mps",
78+
github_issue="https://github.com/pytorch/executorch/issues/7905",
79+
),
80+
],
81+
"llama": [
82+
DisabledConfig(
83+
config_name="mps",
84+
github_issue="https://github.com/pytorch/executorch/issues/7907",
85+
),
86+
],
87+
}
88+
89+
90+
def extract_all_configs(data, target_os=None):
91+
if isinstance(data, dict):
92+
# If target_os is specified, include "xplat" and the specified branch
93+
include_branches = {"xplat", target_os} if target_os else data.keys()
94+
return [
95+
v
96+
for key, value in data.items()
97+
if key in include_branches
98+
for v in extract_all_configs(value, target_os)
99+
]
100+
elif isinstance(data, list):
101+
return [v for item in data for v in extract_all_configs(item, target_os)]
102+
else:
103+
return [data]
104+
105+
106+
def generate_compatible_configs(model_name: str, target_os=None) -> List[str]:
107+
"""
108+
Generate a list of compatible benchmark configurations for a given model name and target OS.
109+
110+
Args:
111+
model_name (str): The name of the model to generate configurations for.
112+
target_os (Optional[str]): The target operating system (e.g., 'android', 'ios').
113+
114+
Returns:
115+
List[str]: A list of compatible benchmark configurations.
116+
117+
Raises:
118+
None
119+
120+
Example:
121+
generate_compatible_configs('meta-llama/Llama-3.2-1B', 'ios') -> ['llama3_fb16', 'llama3_coreml_ane']
122+
"""
123+
configs = []
124+
if is_valid_huggingface_model_id(model_name):
125+
if model_name.startswith("meta-llama/"):
126+
# LLaMA models
127+
repo_name = model_name.split("meta-llama/")[1]
128+
if "qlora" in repo_name.lower():
129+
configs.append("llama3_qlora")
130+
elif "spinquant" in repo_name.lower():
131+
configs.append("llama3_spinquant")
132+
else:
133+
configs.append("llama3_fb16")
134+
configs.extend(
135+
[
136+
config
137+
for config in BENCHMARK_CONFIGS.get(target_os, [])
138+
if config.startswith("llama")
139+
]
140+
)
141+
else:
142+
# Non-LLaMA models
143+
configs.append("hf_xnnpack_fp32")
144+
elif model_name in MODEL_NAME_TO_MODEL:
145+
# ExecuTorch in-tree non-GenAI models
146+
configs.append("xnnpack_q8")
147+
if target_os != "xplat":
148+
# Add OS-specific configs
149+
configs.extend(
150+
[
151+
config
152+
for config in BENCHMARK_CONFIGS.get(target_os, [])
153+
if not config.startswith("llama")
154+
]
155+
)
156+
else:
157+
# Skip unknown models with a warning
158+
logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
159+
160+
# Remove disabled configs for the given model
161+
disabled_configs = DISABLED_CONFIGS.get(model_name, [])
162+
disabled_config_names = {disabled.config_name for disabled in disabled_configs}
163+
for disabled in disabled_configs:
164+
print(
165+
f"Excluding disabled config: '{disabled.config_name}' for model '{model_name}' on '{target_os}'. Linked GitHub issue: {disabled.github_issue}"
166+
)
167+
configs = [config for config in configs if config not in disabled_config_names]
168+
return configs
169+
170+
48171
def parse_args() -> Any:
49172
"""
50173
Parse command-line arguments.
@@ -82,6 +205,11 @@ def comma_separated(value: str):
82205
type=comma_separated, # Use the custom parser for comma-separated values
83206
help=f"Comma-separated device names. Available devices: {list(DEVICE_POOLS.keys())}",
84207
)
208+
parser.add_argument(
209+
"--configs",
210+
type=comma_separated, # Use the custom parser for comma-separated values
211+
help=f"Comma-separated benchmark configs. Available configs: {extract_all_configs(BENCHMARK_CONFIGS)}",
212+
)
85213

86214
return parser.parse_args()
87215

@@ -98,11 +226,16 @@ def set_output(name: str, val: Any) -> None:
98226
set_output("benchmark_configs", {"include": [...]})
99227
"""
100228

101-
if os.getenv("GITHUB_OUTPUT"):
102-
print(f"Setting {val} to GitHub output")
103-
with open(str(os.getenv("GITHUB_OUTPUT")), "a") as env:
104-
print(f"{name}={val}", file=env)
105-
else:
229+
github_output = os.getenv("GITHUB_OUTPUT")
230+
if not github_output:
231+
print(f"::set-output name={name}::{val}")
232+
return
233+
234+
try:
235+
with open(github_output, "a") as env:
236+
env.write(f"{name}={val}\n")
237+
except PermissionError:
238+
# Fall back to printing in case of permission error in unit tests
106239
print(f"::set-output name={name}::{val}")
107240

108241

@@ -123,7 +256,7 @@ def is_valid_huggingface_model_id(model_name: str) -> bool:
123256
return bool(re.match(pattern, model_name))
124257

125258

126-
def get_benchmark_configs() -> Dict[str, Dict]:
259+
def get_benchmark_configs() -> Dict[str, Dict]: # noqa: C901
127260
"""
128261
Gather benchmark configurations for a given set of models on the target operating system and devices.
129262
@@ -153,48 +286,26 @@ def get_benchmark_configs() -> Dict[str, Dict]:
153286
}
154287
"""
155288
args = parse_args()
156-
target_os = args.os
157289
devices = args.devices
158290
models = args.models
291+
target_os = args.os
292+
target_configs = args.configs
159293

160294
benchmark_configs = {"include": []}
161295

162296
for model_name in models:
163297
configs = []
164-
if is_valid_huggingface_model_id(model_name):
165-
if model_name.startswith("meta-llama/"):
166-
# LLaMA models
167-
repo_name = model_name.split("meta-llama/")[1]
168-
if "qlora" in repo_name.lower():
169-
configs.append("llama3_qlora")
170-
elif "spinquant" in repo_name.lower():
171-
configs.append("llama3_spinquant")
172-
else:
173-
configs.append("llama3_fb16")
174-
configs.extend(
175-
[
176-
config
177-
for config in BENCHMARK_CONFIGS.get(target_os, [])
178-
if config.startswith("llama")
179-
]
298+
configs.extend(generate_compatible_configs(model_name, target_os))
299+
print(f"Discovered all supported configs for model '{model_name}': {configs}")
300+
if target_configs is not None:
301+
for config in target_configs:
302+
if config not in configs:
303+
raise Exception(
304+
f"Unsupported config '{config}' for model '{model_name}' on '{target_os}'. Skipped.\n"
305+
f"Supported configs are: {configs}"
180306
)
181-
else:
182-
# Non-LLaMA models
183-
configs.append("hf_xnnpack_fp32")
184-
elif model_name in MODEL_NAME_TO_MODEL:
185-
# ExecuTorch in-tree non-GenAI models
186-
configs.append("xnnpack_q8")
187-
configs.extend(
188-
[
189-
config
190-
for config in BENCHMARK_CONFIGS.get(target_os, [])
191-
if not config.startswith("llama")
192-
]
193-
)
194-
else:
195-
# Skip unknown models with a warning
196-
logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
197-
continue
307+
configs = target_configs
308+
print(f"Using provided configs {configs} for model '{model_name}'")
198309

199310
# Add configurations for each valid device
200311
for device in devices:

.ci/scripts/test_llama.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ fi
112112

113113
if [[ "${MODE}" =~ .*quantize_kv.* ]]; then
114114
QUANTIZE_KV_CACHE=ON
115+
# quantize_kv cache transform uses custom kv cache update op
116+
CUSTOM=ON
115117
else
116118
QUANTIZE_KV_CACHE=OFF
117119
fi

.ci/scripts/test_model.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,13 @@ test_model_with_qnn() {
169169
EXPORT_SCRIPT=inception_v3
170170
elif [[ "${MODEL_NAME}" == "vit" ]]; then
171171
EXPORT_SCRIPT=torchvision_vit
172+
elif [[ "${MODEL_NAME}" == "edsr" ]]; then
173+
EXPORT_SCRIPT=edsr
174+
# Additional deps for edsr
175+
pip install piq
176+
else
177+
echo "Unsupported model $MODEL_NAME"
178+
exit 1
172179
fi
173180

174181
# Use SM8450 for S22, SM8550 for S23, and SM8560 for S24

0 commit comments

Comments
 (0)