Commit 08faac8

Update on "Reuse GELU implementation from PyTorch core"
kernels/optimized doesn't need to support embedded systems, so it can just take a header-only dependency on PyTorch. Note that, because we will pick up Sleef internally and ignore it externally thanks to ATen Vec, this PR gets to enable the optimized GELU in OSS.

Testing: CI, to make sure this doesn't break the mobile build modes; happy to take advice on anything not currently covered that might break.

Differential Revision: [D66335522](https://our.internmc.facebook.com/intern/diff/D66335522/)

[ghstack-poisoned]
2 parents: 2006d40 + dc5a503
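
Reviewer note on the behavioral contract being merged: the optimized kernel must stay numerically faithful to PyTorch core's exact GELU, GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))). A minimal sketch of that check in plain PyTorch (illustration only, not code from this commit; it assumes a recent torch with the approximate= keyword):

import math

import torch

# Reference definition: GELU(x) = x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2))).
x = torch.linspace(-6.0, 6.0, steps=1001, dtype=torch.float32)
reference = 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))

# PyTorch core's exact GELU, which kernels/optimized now reuses.
exact = torch.nn.functional.gelu(x, approximate="none")

# The tanh form is a distinct approximation; it should not be compared
# against the exact curve with a tight tolerance.
tanh_approx = torch.nn.functional.gelu(x, approximate="tanh")

torch.testing.assert_close(exact, reference, rtol=1e-6, atol=1e-6)
print("max |exact - tanh| =", (exact - tanh_approx).abs().max().item())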

File tree

755 files changed: +19085 / -5972 lines
Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-cb4af2b4752220c3ca3de6e7e95b3a6fdc31f794
+4eeb6f34ef415aa8701a36453380e9c1ba2c8f3a

.ci/docker/requirements-ci.txt

Lines changed: 3 additions & 3 deletions

@@ -1,5 +1,5 @@
 mpmath==1.3.0
-numpy==2.0.0; python_version >= '3.10'
+numpy>=2.0.0; python_version >= '3.10'
 PyYAML==6.0.1
 ruamel.yaml==0.17.32
 sympy==1.12
@@ -8,7 +8,7 @@ tomli==2.0.1
 torchsr==1.0.4
 transformers==4.47.1
 zstd==1.5.5.1
-pandas==2.2.2; python_version >= '3.10'
+pandas>=2.2.2; python_version >= '3.10'
 pytest==7.2.0
 pytest-cov==4.1.0
 expecttest==0.1.6
@@ -21,7 +21,7 @@ sphinx-gallery==0.14.0
 breathe==4.34.0
 exhale==0.2.3
 docutils==0.16
-matplotlib==3.9.4
+matplotlib>=3.9.4
 # PyTorch Theme
 -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
 myst-parser==0.18.1

.ci/scripts/gather_benchmark_configs.py

Lines changed: 53 additions & 1 deletion

@@ -10,7 +10,7 @@
 import os
 import re
 import sys
-from typing import Any, Dict, List
+from typing import Any, Dict, List, NamedTuple
 
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
 from examples.models import MODEL_NAME_TO_MODEL
@@ -47,6 +47,50 @@
 }
 
 
+class DisabledConfig(NamedTuple):
+    config_name: str
+    github_issue: str  # Link to the GitHub issue
+
+
+# Updated DISABLED_CONFIGS
+DISABLED_CONFIGS: Dict[str, List[DisabledConfig]] = {
+    "resnet50": [
+        DisabledConfig(
+            config_name="qnn_q8",
+            github_issue="https://github.com/pytorch/executorch/issues/7892",
+        ),
+    ],
+    "w2l": [
+        DisabledConfig(
+            config_name="qnn_q8",
+            github_issue="https://github.com/pytorch/executorch/issues/7634",
+        ),
+    ],
+    "mobilebert": [
+        DisabledConfig(
+            config_name="mps",
+            github_issue="https://github.com/pytorch/executorch/issues/7904",
+        ),
+        DisabledConfig(
+            config_name="qnn_q8",
+            github_issue="https://github.com/pytorch/executorch/issues/7946",
+        ),
+    ],
+    "edsr": [
+        DisabledConfig(
+            config_name="mps",
+            github_issue="https://github.com/pytorch/executorch/issues/7905",
+        ),
+    ],
+    "llama": [
+        DisabledConfig(
+            config_name="mps",
+            github_issue="https://github.com/pytorch/executorch/issues/7907",
+        ),
+    ],
+}
+
+
 def extract_all_configs(data, target_os=None):
     if isinstance(data, dict):
         # If target_os is specified, include "xplat" and the specified branch
@@ -117,6 +161,14 @@ def generate_compatible_configs(model_name: str, target_os=None) -> List[str]:
     # Skip unknown models with a warning
     logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
 
+    # Remove disabled configs for the given model
+    disabled_configs = DISABLED_CONFIGS.get(model_name, [])
+    disabled_config_names = {disabled.config_name for disabled in disabled_configs}
+    for disabled in disabled_configs:
+        print(
+            f"Excluding disabled config: '{disabled.config_name}' for model '{model_name}' on '{target_os}'. Linked GitHub issue: {disabled.github_issue}"
+        )
+    configs = [config for config in configs if config not in disabled_config_names]
     return configs
 
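The filtering step added above is self-contained enough to sketch on its own. In the following, DisabledConfig and the DISABLED_CONFIGS entry mirror the diff, while filter_disabled is a hypothetical stand-in for the tail of generate_compatible_configs():

from typing import Dict, List, NamedTuple


class DisabledConfig(NamedTuple):
    config_name: str
    github_issue: str


# One entry copied from the diff above.
DISABLED_CONFIGS: Dict[str, List[DisabledConfig]] = {
    "resnet50": [
        DisabledConfig(
            config_name="qnn_q8",
            github_issue="https://github.com/pytorch/executorch/issues/7892",
        ),
    ],
}


def filter_disabled(model_name: str, configs: List[str]) -> List[str]:
    # Mirrors the appended logic: drop any config disabled for this model,
    # keeping the remaining configs in their original order.
    disabled = {d.config_name for d in DISABLED_CONFIGS.get(model_name, [])}
    return [c for c in configs if c not in disabled]


print(filter_disabled("resnet50", ["xnnpack_q8", "qnn_q8"]))  # -> ['xnnpack_q8']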

.ci/scripts/test_llama.sh

Lines changed: 2 additions & 0 deletions

@@ -112,6 +112,8 @@ fi
 
 if [[ "${MODE}" =~ .*quantize_kv.* ]]; then
   QUANTIZE_KV_CACHE=ON
+  # quantize_kv cache transform uses custom kv cache update op
+  CUSTOM=ON
 else
   QUANTIZE_KV_CACHE=OFF
 fi

.ci/scripts/test_model.sh

Lines changed: 7 additions & 0 deletions

@@ -170,6 +170,13 @@ test_model_with_qnn() {
     EXPORT_SCRIPT=inception_v3
   elif [[ "${MODEL_NAME}" == "vit" ]]; then
     EXPORT_SCRIPT=torchvision_vit
+  elif [[ "${MODEL_NAME}" == "edsr" ]]; then
+    EXPORT_SCRIPT=edsr
+    # Additional deps for edsr
+    pip install piq
+  else
+    echo "Unsupported model $MODEL_NAME"
+    exit 1
   fi
 
   # Use SM8450 for S22, SM8550 for S23, and SM8560 for S24

.ci/scripts/tests/test_gather_benchmark_configs.py

Lines changed: 89 additions & 21 deletions

@@ -1,36 +1,41 @@
 import importlib.util
 import os
+import re
 import subprocess
 import sys
 import unittest
 from unittest.mock import mock_open, patch
 
 import pytest
 
-# Dynamically import the script
-script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
-spec = importlib.util.spec_from_file_location("gather_benchmark_configs", script_path)
-gather_benchmark_configs = importlib.util.module_from_spec(spec)
-spec.loader.exec_module(gather_benchmark_configs)
-
 
 @pytest.mark.skipif(
     sys.platform != "linux", reason="The script under test runs on Linux runners only"
 )
 class TestGatehrBenchmarkConfigs(unittest.TestCase):
 
+    @classmethod
+    def setUpClass(cls):
+        # Dynamically import the script
+        script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
+        spec = importlib.util.spec_from_file_location(
+            "gather_benchmark_configs", script_path
+        )
+        cls.gather_benchmark_configs = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(cls.gather_benchmark_configs)
+
     def test_extract_all_configs_android(self):
-        android_configs = gather_benchmark_configs.extract_all_configs(
-            gather_benchmark_configs.BENCHMARK_CONFIGS, "android"
+        android_configs = self.gather_benchmark_configs.extract_all_configs(
+            self.gather_benchmark_configs.BENCHMARK_CONFIGS, "android"
         )
         self.assertIn("xnnpack_q8", android_configs)
         self.assertIn("qnn_q8", android_configs)
         self.assertIn("llama3_spinquant", android_configs)
         self.assertIn("llama3_qlora", android_configs)
 
     def test_extract_all_configs_ios(self):
-        ios_configs = gather_benchmark_configs.extract_all_configs(
-            gather_benchmark_configs.BENCHMARK_CONFIGS, "ios"
+        ios_configs = self.gather_benchmark_configs.extract_all_configs(
+            self.gather_benchmark_configs.BENCHMARK_CONFIGS, "ios"
         )
 
         self.assertIn("xnnpack_q8", ios_configs)
@@ -40,51 +45,114 @@ def test_extract_all_configs_ios(self):
         self.assertIn("llama3_spinquant", ios_configs)
         self.assertIn("llama3_qlora", ios_configs)
 
+    def test_skip_disabled_configs(self):
+        # Use patch as a context manager to avoid modifying DISABLED_CONFIGS and BENCHMARK_CONFIGS
+        with patch.dict(
+            self.gather_benchmark_configs.DISABLED_CONFIGS,
+            {
+                "mv3": [
+                    self.gather_benchmark_configs.DisabledConfig(
+                        config_name="disabled_config1",
+                        github_issue="https://github.com/org/repo/issues/123",
+                    ),
+                    self.gather_benchmark_configs.DisabledConfig(
+                        config_name="disabled_config2",
+                        github_issue="https://github.com/org/repo/issues/124",
+                    ),
+                ]
+            },
+        ), patch.dict(
+            self.gather_benchmark_configs.BENCHMARK_CONFIGS,
+            {
+                "ios": [
+                    "disabled_config1",
+                    "disabled_config2",
+                    "enabled_config1",
+                    "enabled_config2",
+                ]
+            },
+        ):
+            result = self.gather_benchmark_configs.generate_compatible_configs(
+                "mv3", target_os="ios"
+            )
+
+            # Assert that disabled configs are excluded
+            self.assertNotIn("disabled_config1", result)
+            self.assertNotIn("disabled_config2", result)
+            # Assert enabled configs are included
+            self.assertIn("enabled_config1", result)
+            self.assertIn("enabled_config2", result)
+
+    def test_disabled_configs_have_github_links(self):
+        github_issue_regex = re.compile(r"https://github\.com/.+/.+/issues/\d+")
+
+        for (
+            model_name,
+            disabled_configs,
+        ) in self.gather_benchmark_configs.DISABLED_CONFIGS.items():
+            for disabled in disabled_configs:
+                with self.subTest(model_name=model_name, config=disabled.config_name):
+                    # Assert that disabled is an instance of DisabledConfig
+                    self.assertIsInstance(
+                        disabled, self.gather_benchmark_configs.DisabledConfig
+                    )
+
+                    # Assert that github_issue is provided and matches the expected pattern
+                    self.assertTrue(
+                        disabled.github_issue
+                        and github_issue_regex.match(disabled.github_issue),
+                        f"Invalid or missing GitHub issue link for '{disabled.config_name}' in model '{model_name}'.",
+                    )
+
     def test_generate_compatible_configs_llama_model(self):
         model_name = "meta-llama/Llama-3.2-1B"
         target_os = "ios"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
             model_name, target_os
         )
         expected = ["llama3_fb16", "llama3_coreml_ane"]
         self.assertEqual(result, expected)
 
         target_os = "android"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
             model_name, target_os
         )
         expected = ["llama3_fb16"]
         self.assertEqual(result, expected)
 
     def test_generate_compatible_configs_quantized_llama_model(self):
         model_name = "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8"
-        result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
+        result = self.gather_benchmark_configs.generate_compatible_configs(
+            model_name, None
+        )
         expected = ["llama3_spinquant"]
         self.assertEqual(result, expected)
 
         model_name = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"
-        result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
+        result = self.gather_benchmark_configs.generate_compatible_configs(
+            model_name, None
+        )
         expected = ["llama3_qlora"]
         self.assertEqual(result, expected)
 
     def test_generate_compatible_configs_non_genai_model(self):
         model_name = "mv2"
         target_os = "xplat"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
             model_name, target_os
         )
         expected = ["xnnpack_q8"]
         self.assertEqual(result, expected)
 
         target_os = "android"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
             model_name, target_os
         )
         expected = ["xnnpack_q8", "qnn_q8"]
         self.assertEqual(result, expected)
 
         target_os = "ios"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
             model_name, target_os
         )
         expected = ["xnnpack_q8", "coreml_fp16", "mps"]
@@ -93,22 +161,22 @@ def test_generate_compatible_configs_non_genai_model(self):
     def test_generate_compatible_configs_unknown_model(self):
         model_name = "unknown_model"
         target_os = "ios"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
             model_name, target_os
        )
         self.assertEqual(result, [])
 
     def test_is_valid_huggingface_model_id_valid(self):
         valid_model = "meta-llama/Llama-3.2-1B"
         self.assertTrue(
-            gather_benchmark_configs.is_valid_huggingface_model_id(valid_model)
+            self.gather_benchmark_configs.is_valid_huggingface_model_id(valid_model)
         )
 
     @patch("builtins.open", new_callable=mock_open)
     @patch("os.getenv", return_value=None)
     def test_set_output_no_github_env(self, mock_getenv, mock_file):
         with patch("builtins.print") as mock_print:
-            gather_benchmark_configs.set_output("test_name", "test_value")
+            self.gather_benchmark_configs.set_output("test_name", "test_value")
             mock_print.assert_called_with("::set-output name=test_name::test_value")
 
     def test_device_pools_contains_all_devices(self):
@@ -120,7 +188,7 @@ def test_device_pools_contains_all_devices(self):
             "google_pixel_8_pro",
         ]
         for device in expected_devices:
-            self.assertIn(device, gather_benchmark_configs.DEVICE_POOLS)
+            self.assertIn(device, self.gather_benchmark_configs.DEVICE_POOLS)
 
     def test_gather_benchmark_configs_cli(self):
         args = {
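
Moving the importlib machinery from module scope into setUpClass means the script under test is loaded only when the class actually runs, so test collection on non-Linux hosts (where the whole class is skipped) no longer executes it. The underlying dynamic-import pattern, as a standalone sketch (load_module_from_path is an illustrative helper, not part of the commit):

import importlib.util


def load_module_from_path(module_name: str, path: str):
    # Build a module spec from an arbitrary file path, then execute it.
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load {module_name} from {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file's top-level code
    return module


# e.g. load_module_from_path("gather_benchmark_configs",
#                            ".ci/scripts/gather_benchmark_configs.py")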

.github/workflows/_android.yml

Lines changed: 4 additions & 1 deletion

@@ -7,7 +7,10 @@ on:
 jobs:
   build-llm-demo:
     name: build-llm-demo
-    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
     with:
       runner: linux.2xlarge
       docker-image: executorch-ubuntu-22.04-clang12-android
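
The same linux_job.yml -> linux_job_v2.yml swap, with the accompanying permissions block, repeats in the workflow files below; the id-token: write grant suggests the v2 reusable workflow authenticates via GitHub's OIDC token rather than stored credentials (an inference from the permissions block, not something the commit states).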

.github/workflows/_unittest.yml

Lines changed: 4 additions & 1 deletion

@@ -14,7 +14,10 @@ on:
 
 jobs:
   linux:
-    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
     with:
       runner: linux.2xlarge
       docker-image: ${{ inputs.docker-image }}

.github/workflows/android-perf.yml

Lines changed: 8 additions & 2 deletions

@@ -155,7 +155,10 @@ jobs:
 
   export-models:
     name: export-models
-    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
     needs: set-parameters
     secrets: inherit
     strategy:
@@ -332,7 +335,10 @@
 
   build-benchmark-app:
     name: build-benchmark-app
-    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
     needs: set-parameters
     with:
       runner: linux.2xlarge

.github/workflows/android-release-artifacts.yml

Lines changed: 4 additions & 1 deletion

@@ -31,7 +31,10 @@ jobs:
   build-aar:
     name: build-aar
     needs: check-if-aar-exists
-    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
     with:
       runner: linux.2xlarge
       docker-image: executorch-ubuntu-22.04-clang12-android
