Skip to content

Commit 6ec0a23

Browse files
committed
update smoke test
1 parent 2d47bca commit 6ec0a23

File tree

5 files changed

+71
-129
lines changed

5 files changed

+71
-129
lines changed

.ci/scripts/gather_test_models.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,10 @@ def model_should_run_on_target_os(model: str, target_os: str) -> bool:
113113
return model not in ["llava"]
114114

115115

116-
def export_models_for_ci() -> dict[str, dict]:
116+
def export_models_for_ci(target_os: str, event: str) -> dict[str, Any]:
117117
"""
118118
This gathers all the example models that we want to test on GitHub OSS CI
119119
"""
120-
args = parse_args()
121-
target_os = args.target_os
122-
event = args.event
123120

124121
# This is the JSON syntax for configuration matrix used by GitHub
125122
# https://docs.github.com/en/actions/using-jobs/using-a-matrix-for-your-jobs
@@ -175,9 +172,16 @@ def export_models_for_ci() -> dict[str, dict]:
175172
record["runner"] = CUSTOM_RUNNERS[target_os][name]
176173

177174
models["include"].append(record)
175+
176+
return models
178177

179-
set_output("models", json.dumps(models))
178+
180179

181180

182181
if __name__ == "__main__":
183-
export_models_for_ci()
182+
args = parse_args()
183+
models = export_models_for_ci(
184+
target_os = args.target_os,
185+
event = args.event
186+
)
187+
set_output("models", json.dumps(models))

.github/workflows/build-wheels-linux.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ jobs:
4141
- repository: pytorch/executorch
4242
pre-script: build/packaging/pre_build_script.sh
4343
post-script: build/packaging/post_build_script.sh
44-
smoke-test-script: build/packaging/smoke_test.py
44+
smoke-test-script: build/util/packaging/wheel_test.py --target-os linux
4545
package-name: executorch
4646
name: ${{ matrix.repository }}
4747
uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main

.github/workflows/build-wheels-macos.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ jobs:
4141
- repository: pytorch/executorch
4242
pre-script: build/packaging/pre_build_script.sh
4343
post-script: build/packaging/post_build_script.sh
44-
smoke-test-script: build/packaging/smoke_test.py
44+
smoke-test-script: build/util/packaging/wheel_test.py --target-os macos
4545
package-name: executorch
4646
name: ${{ matrix.repository }}
4747
uses: pytorch/test-infra/.github/workflows/build_wheels_macos.yml@main

build/packaging/smoke_test.py

Lines changed: 0 additions & 121 deletions
This file was deleted.

build/util/packaging/wheel_test.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import sys
2+
import os
3+
import argparse
4+
import subprocess
5+
# Since the .ci folder starts with a period, it is not possible to import it
6+
# directly, so we must use the importlib module to load it.
7+
_PYTHON_PATH = os.getenv("PYTHONPATH")
8+
if _PYTHON_PATH is None:
9+
print("PYTHONPATH is not set")
10+
exit(1)
11+
12+
sys.path.append(os.path.join(_PYTHON_PATH, ".ci"))
13+
from scripts import gather_test_models
14+
15+
16+
def _create_arg_parser() -> argparse.ArgumentParser:
17+
parser = argparse.ArgumentParser()
18+
19+
parser.add_argument(
20+
"--target-os",
21+
type=str,
22+
required=True,
23+
choices=gather_test_models.DEFAULT_RUNNERS.keys(),
24+
help="the target OS",
25+
)
26+
27+
return parser
28+
29+
def _run_test(model_name: str, build_tool: str, backend: str) -> None:
30+
subprocess.run(
31+
[
32+
os.path.join(_PYTHON_PATH, ".ci/scripts/test_model.sh"),
33+
model_name,
34+
build_tool,
35+
backend,
36+
],
37+
env={**os.environ, "PYTHON_EXECUTABLE": "python"},
38+
check=True,
39+
)
40+
41+
if __name__ == "__main__":
42+
args = _create_arg_parser().parse_args()
43+
models = gather_test_models.export_models_for_ci(
44+
target_os = args.target_os,
45+
# Event refers to the type of models that will be downloaded. "pull_request"
46+
# uses higher priority and fast models.
47+
event="pull_request",
48+
).get("include", [])
49+
50+
if len(models) == 0:
51+
print("No models found")
52+
exit(1)
53+
54+
for model in models:
55+
_run_test(
56+
model_name=model["model"],
57+
build_tool=model["build-tool"],
58+
backend=model["backend"],
59+
)

0 commit comments

Comments
 (0)