Skip to content

Commit d04173d

Browse files
committed
Update on "[ET-VK][ez] Update requirements for partitioning to_dim_order_copy"
## Context

The previous registration of the to dim order copy op is incorrect. Currently, there is no implementation for the op in the Vulkan backend, but since Vulkan manages memory layout internally the op node can be removed as long as the only thing being changed is dim order. In some instances the op can be used to modify the dtype, in which case it will not be removed and the Vulkan delegate cannot execute the op correctly. Therefore, update the registration of the op to reflect this restriction. This diff should unblock enabling dim order ops for Vulkan.

Differential Revision: [D68528213](https://our.internmc.facebook.com/intern/diff/D68528213/)

[ghstack-poisoned]
2 parents 93679af + 021df27 commit d04173d

File tree

67 files changed

+1545
-533
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+1545
-533
lines changed

.ci/scripts/__init__.py

Whitespace-only changes.

.ci/scripts/gather_benchmark_configs.py

Lines changed: 104 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
import logging
1010
import os
1111
import re
12-
from typing import Any, Dict
12+
import sys
13+
from typing import Any, Dict, List
1314

15+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
1416
from examples.models import MODEL_NAME_TO_MODEL
1517

1618

@@ -45,6 +47,79 @@
4547
}
4648

4749

50+
def extract_all_configs(data, target_os=None):
    """Recursively collect every leaf config value from a nested structure.

    Dict branches are traversed only for the shared "xplat" key and the
    key matching ``target_os`` (all keys when no target OS is given);
    list elements are flattened; any other value is returned as a leaf.
    """
    if isinstance(data, dict):
        # Limit traversal to the cross-platform branch plus the OS branch.
        wanted_branches = {"xplat", target_os} if target_os else data.keys()
        collected = []
        for branch, subtree in data.items():
            if branch in wanted_branches:
                collected.extend(extract_all_configs(subtree, target_os))
        return collected
    if isinstance(data, list):
        collected = []
        for element in data:
            collected.extend(extract_all_configs(element, target_os))
        return collected
    # Scalar leaf: wrap so callers can always concatenate results.
    return [data]
64+
65+
66+
def generate_compatible_configs(model_name: str, target_os=None) -> List[str]:
    """
    Generate a list of compatible benchmark configurations for a given model name and target OS.

    Args:
        model_name (str): The name of the model to generate configurations for.
        target_os (Optional[str]): The target operating system (e.g., 'android', 'ios').

    Returns:
        List[str]: A list of compatible benchmark configurations.

    Raises:
        None

    Example:
        generate_compatible_configs('meta-llama/Llama-3.2-1B', 'ios') -> ['llama3_fb16', 'llama3_coreml_ane']
    """

    def os_configs(llama_family: bool) -> List[str]:
        # Configs registered for the target OS, filtered by model family.
        return [
            c
            for c in BENCHMARK_CONFIGS.get(target_os, [])
            if c.startswith("llama") == llama_family
        ]

    configs: List[str] = []
    if is_valid_huggingface_model_id(model_name):
        if model_name.startswith("meta-llama/"):
            # LLaMA models: pick the config matching the checkpoint variant.
            repo_name = model_name.split("meta-llama/")[1]
            lowered = repo_name.lower()
            if "qlora" in lowered:
                configs.append("llama3_qlora")
            elif "spinquant" in lowered:
                configs.append("llama3_spinquant")
            else:
                configs.append("llama3_fb16")
                configs.extend(os_configs(True))
        else:
            # Non-LLaMA Hugging Face models get the generic XNNPACK config.
            configs.append("hf_xnnpack_fp32")
    elif model_name in MODEL_NAME_TO_MODEL:
        # ExecuTorch in-tree non-GenAI models.
        configs.append("xnnpack_q8")
        if target_os != "xplat":
            # Add OS-specific (non-LLaMA) configs.
            configs.extend(os_configs(False))
    else:
        # Skip unknown models with a warning.
        logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")

    return configs
121+
122+
48123
def parse_args() -> Any:
49124
"""
50125
Parse command-line arguments.
@@ -82,6 +157,11 @@ def comma_separated(value: str):
82157
type=comma_separated, # Use the custom parser for comma-separated values
83158
help=f"Comma-separated device names. Available devices: {list(DEVICE_POOLS.keys())}",
84159
)
160+
parser.add_argument(
161+
"--configs",
162+
type=comma_separated, # Use the custom parser for comma-separated values
163+
help=f"Comma-separated benchmark configs. Available configs: {extract_all_configs(BENCHMARK_CONFIGS)}",
164+
)
85165

86166
return parser.parse_args()
87167

@@ -98,11 +178,16 @@ def set_output(name: str, val: Any) -> None:
98178
set_output("benchmark_configs", {"include": [...]})
99179
"""
100180

101-
if os.getenv("GITHUB_OUTPUT"):
102-
print(f"Setting {val} to GitHub output")
103-
with open(str(os.getenv("GITHUB_OUTPUT")), "a") as env:
104-
print(f"{name}={val}", file=env)
105-
else:
181+
github_output = os.getenv("GITHUB_OUTPUT")
182+
if not github_output:
183+
print(f"::set-output name={name}::{val}")
184+
return
185+
186+
try:
187+
with open(github_output, "a") as env:
188+
env.write(f"{name}={val}\n")
189+
except PermissionError:
190+
# Fall back to printing in case of permission error in unit tests
106191
print(f"::set-output name={name}::{val}")
107192

108193

@@ -123,7 +208,7 @@ def is_valid_huggingface_model_id(model_name: str) -> bool:
123208
return bool(re.match(pattern, model_name))
124209

125210

126-
def get_benchmark_configs() -> Dict[str, Dict]:
211+
def get_benchmark_configs() -> Dict[str, Dict]: # noqa: C901
127212
"""
128213
Gather benchmark configurations for a given set of models on the target operating system and devices.
129214
@@ -153,48 +238,26 @@ def get_benchmark_configs() -> Dict[str, Dict]:
153238
}
154239
"""
155240
args = parse_args()
156-
target_os = args.os
157241
devices = args.devices
158242
models = args.models
243+
target_os = args.os
244+
target_configs = args.configs
159245

160246
benchmark_configs = {"include": []}
161247

162248
for model_name in models:
163249
configs = []
164-
if is_valid_huggingface_model_id(model_name):
165-
if model_name.startswith("meta-llama/"):
166-
# LLaMA models
167-
repo_name = model_name.split("meta-llama/")[1]
168-
if "qlora" in repo_name.lower():
169-
configs.append("llama3_qlora")
170-
elif "spinquant" in repo_name.lower():
171-
configs.append("llama3_spinquant")
172-
else:
173-
configs.append("llama3_fb16")
174-
configs.extend(
175-
[
176-
config
177-
for config in BENCHMARK_CONFIGS.get(target_os, [])
178-
if config.startswith("llama")
179-
]
250+
configs.extend(generate_compatible_configs(model_name, target_os))
251+
print(f"Discovered all supported configs for model '{model_name}': {configs}")
252+
if target_configs is not None:
253+
for config in target_configs:
254+
if config not in configs:
255+
raise Exception(
256+
f"Unsupported config '{config}' for model '{model_name}' on '{target_os}'. Skipped.\n"
257+
f"Supported configs are: {configs}"
180258
)
181-
else:
182-
# Non-LLaMA models
183-
configs.append("hf_xnnpack_fp32")
184-
elif model_name in MODEL_NAME_TO_MODEL:
185-
# ExecuTorch in-tree non-GenAI models
186-
configs.append("xnnpack_q8")
187-
configs.extend(
188-
[
189-
config
190-
for config in BENCHMARK_CONFIGS.get(target_os, [])
191-
if not config.startswith("llama")
192-
]
193-
)
194-
else:
195-
# Skip unknown models with a warning
196-
logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
197-
continue
259+
configs = target_configs
260+
print(f"Using provided configs {configs} for model '{model_name}'")
198261

199262
# Add configurations for each valid device
200263
for device in devices:

.ci/scripts/test_llama.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ fi
112112

113113
if [[ "${MODE}" =~ .*quantize_kv.* ]]; then
114114
QUANTIZE_KV_CACHE=ON
115+
# quantize_kv cache transform uses custom kv cache update op
116+
CUSTOM=ON
115117
else
116118
QUANTIZE_KV_CACHE=OFF
117119
fi

.ci/scripts/test_model.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,13 @@ test_model_with_qnn() {
169169
EXPORT_SCRIPT=inception_v3
170170
elif [[ "${MODEL_NAME}" == "vit" ]]; then
171171
EXPORT_SCRIPT=torchvision_vit
172+
elif [[ "${MODEL_NAME}" == "edsr" ]]; then
173+
EXPORT_SCRIPT=edsr
174+
# Additional deps for edsr
175+
pip install piq
176+
else
177+
echo "Unsupported model $MODEL_NAME"
178+
exit 1
172179
fi
173180

174181
# Use SM8450 for S22, SM8550 for S23, and SM8560 for S24

0 commit comments

Comments (0)