Skip to content

Commit 7463ec5

Browse files
jeffkilpatrickGitHub Enterprise
authored andcommitted
[AISW-157068] onnxruntime model tests for GPU (microsoft#400)
1 parent ffa8bc3 commit 7463ec5

File tree

7 files changed

+124
-23
lines changed

7 files changed

+124
-23
lines changed

.github/workflows/qualcomm-internal-build-and-test-single-os.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,4 +189,5 @@ jobs:
189189
qcom/build_and_test.py
190190
--target-py-version ${{inputs.target_py_vsn}}
191191
create_venv
192-
test_ort_${{inputs.target_os}}_${{inputs.target_arch}}_pysmoke --only
192+
test_ort_${{inputs.target_os}}_${{inputs.target_arch}}_pysmoke
193+
test_ort_${{inputs.target_os}}_${{inputs.target_arch}}_pygpu --only

qcom/build_and_test.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
QdcTestsTask,
3939
)
4040
from ep_build.tasks.docker import MANYLINUX_2_34_AARCH64_TAG, DockerBuildTask
41-
from ep_build.tasks.python import CreateOrtVenvTask, OrtWheelSmokeTestTask, RunLinterTask
41+
from ep_build.tasks.python import CreateOrtVenvTask, OrtWheelGpuModelTestTask, OrtWheelSmokeTestTask, RunLinterTask
4242
from ep_build.typing import BuildConfigT, TargetPyVersionT
4343
from ep_build.util import (
4444
DEFAULT_PYTHON,
@@ -743,6 +743,22 @@ def test_ort_windows_arm64(self, plan: Plan) -> str:
743743
)
744744
)
745745

746+
if is_host_windows():
747+
748+
@task
749+
@depends(["build_ort_windows_arm64"])
750+
def test_ort_windows_arm64_pygpu(self, plan: Plan) -> str:
751+
assert self.__target_py_version is not None
752+
return plan.add_step(
753+
OrtWheelGpuModelTestTask(
754+
"Running GPU model tests on ARM64",
755+
self.__venv_path,
756+
"arm64",
757+
self.__config,
758+
self.__target_py_version,
759+
)
760+
)
761+
746762
if is_host_windows():
747763

748764
@task
@@ -776,6 +792,22 @@ def test_ort_windows_arm64ec(self, plan: Plan) -> str:
776792
)
777793
)
778794

795+
if is_host_windows():
796+
797+
@task
798+
@depends(["build_ort_windows_arm64ec"])
799+
def test_ort_windows_arm64ec_pygpu(self, plan: Plan) -> str:
800+
assert self.__target_py_version is not None
801+
return plan.add_step(
802+
OrtWheelGpuModelTestTask(
803+
"Running GPU model tests on ARM64ec",
804+
self.__venv_path,
805+
"arm64ec",
806+
self.__config,
807+
self.__target_py_version,
808+
)
809+
)
810+
779811
if is_host_windows():
780812

781813
@task
@@ -792,6 +824,22 @@ def test_ort_windows_arm64ec_pysmoke(self, plan: Plan) -> str:
792824
)
793825
)
794826

827+
if is_host_windows():
828+
829+
@task
830+
@depends(["build_ort_windows_arm64x"])
831+
def test_ort_windows_arm64x_pygpu(self, plan: Plan) -> str:
832+
assert self.__target_py_version is not None
833+
return plan.add_step(
834+
OrtWheelGpuModelTestTask(
835+
"Running GPU model tests on ARM64x",
836+
self.__venv_path,
837+
"arm64x",
838+
self.__config,
839+
self.__target_py_version,
840+
)
841+
)
842+
795843
if is_host_windows():
796844

797845
@task

qcom/ep_build/tasks/python.py

Lines changed: 58 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import functools
55
import operator
6+
import os
67
from collections.abc import Callable, Iterable, Mapping
78
from pathlib import Path
89
from typing import Literal
@@ -175,16 +176,17 @@ def make_wheel_test(self, tmpdir: Path) -> Task:
175176
)
176177

177178

178-
class OrtWheelSmokeTestTask(OrtWheelTestTask):
179+
class OrtWheelModelTestTask(OrtWheelTestTask):
179180
def __init__(
180181
self,
181182
group_name: str | None,
182183
venv: Path | None,
183184
wheel_pe_arch: WheelPeArchT,
184185
config: BuildConfigT,
185186
py_version: TargetPyVersionT,
187+
test_files_or_dirs: list[str],
188+
get_test_env: Callable[[], Mapping[str, str]],
186189
) -> None:
187-
self.__venv = venv
188190
self.__wheel_pe_arch = wheel_pe_arch
189191
self.__config = config
190192
self.__py_version = py_version
@@ -195,11 +197,8 @@ def __init__(
195197
wheel_pe_arch,
196198
py_version,
197199
self.__find_wheel,
198-
[
199-
str(REPO_ROOT / "qcom" / "model_test" / "smoke_test.py"),
200-
str(REPO_ROOT / "qcom" / "model_test" / "model_zoo_test.py"),
201-
],
202-
get_test_env=self.__get_test_env,
200+
test_files_or_dirs,
201+
get_test_env,
203202
)
204203

205204
def __find_wheel(self) -> Path:
@@ -231,12 +230,58 @@ def __find_wheel(self) -> Path:
231230
raise FileNotFoundError("Could not find onnxruntime wheel.")
232231
return found_wheels[0]
233232

234-
def __get_test_env(self) -> Mapping[str, str]:
235-
"""Get an environment that tells the tests where to find their models."""
236-
return {
237-
"ORT_WHEEL_SMOKE_TEST_ROOT": str(get_onnx_models_root(self.__venv) / "testdata" / "smoke"),
238-
"ORT_MODEL_ZOO_TEST_ROOTS": str(get_model_zoo_root(self.__venv) / "winml-cert"),
239-
}
233+
234+
class OrtWheelSmokeTestTask(OrtWheelModelTestTask):
235+
def __init__(
236+
self,
237+
group_name: str | None,
238+
venv: Path | None,
239+
wheel_pe_arch: WheelPeArchT,
240+
config: BuildConfigT,
241+
py_version: TargetPyVersionT,
242+
) -> None:
243+
super().__init__(
244+
group_name,
245+
venv,
246+
wheel_pe_arch,
247+
config,
248+
py_version,
249+
[
250+
str(REPO_ROOT / "qcom" / "model_test" / "smoke_test.py"),
251+
str(REPO_ROOT / "qcom" / "model_test" / "model_zoo_test.py"),
252+
],
253+
get_test_env=lambda: {
254+
**os.environ,
255+
"ORT_WHEEL_SMOKE_TEST_ROOT": str(get_onnx_models_root(venv) / "testdata" / "smoke"),
256+
"ORT_MODEL_ZOO_TEST_ROOTS": str(get_model_zoo_root(venv) / "winml-cert"),
257+
},
258+
)
259+
260+
261+
class OrtWheelGpuModelTestTask(OrtWheelModelTestTask):
262+
def __init__(
263+
self,
264+
group_name: str | None,
265+
venv: Path | None,
266+
wheel_pe_arch: WheelPeArchT,
267+
config: BuildConfigT,
268+
py_version: TargetPyVersionT,
269+
) -> None:
270+
super().__init__(
271+
group_name,
272+
venv,
273+
wheel_pe_arch,
274+
config,
275+
py_version,
276+
[
277+
str(REPO_ROOT / "qcom" / "model_test" / "model_zoo_test.py"),
278+
],
279+
get_test_env=lambda: {
280+
**os.environ,
281+
"ORT_MODEL_ZOO_TEST_ROOTS": str(get_model_zoo_root(venv) / "winml-cert-gpu"),
282+
"ORT_MODEL_ZOO_BACKEND": "gpu",
283+
},
284+
)
240285

241286

242287
class RunLinterTask(CompositeTask):

qcom/model_test/model_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def run(self) -> None:
118118
assert len(inputs) == len(expected)
119119

120120
for ds_idx in range(len(inputs)):
121+
logging.debug(f"Inputs: { {n: t.shape for n, t in inputs[ds_idx].items()} }")
121122
logging.debug(f"Expected outputs: { {n: t.shape for n, t in expected[ds_idx].items()} }")
122123
actual = dict(
123124
zip(self.output_names, cast(Sequence[np.ndarray], self.__session.run([], inputs[ds_idx])), strict=False)

qcom/model_test/model_zoo_test.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,26 @@
33

44
import os
55
from pathlib import Path
6+
from typing import cast, get_args
67

78
import pytest
8-
from model_test import ModelTestCase, ModelTestDef, ModelTestSuite
9+
from model_test import BackendT, ModelTestCase, ModelTestDef, ModelTestSuite
910

1011
MODEL_ZOO_ROOTS = [Path(p) for p in os.getenv("ORT_MODEL_ZOO_TEST_ROOTS", "").split(os.pathsep) if len(p) > 0]
11-
12+
MODEL_ZOO_BACKEND = cast(BackendT, os.getenv("ORT_MODEL_ZOO_BACKEND", "htp"))
13+
assert MODEL_ZOO_BACKEND in get_args(BackendT)
14+
MODEL_ZOO_ENABLE_CONTEXT = os.getenv("ORT_MODEL_ZOO_ENABLE_CONTEXT", "1") == "1"
15+
MODEL_ZOO_ENABLE_CPU_FALLBACK = os.getenv("ORT_MODEL_ZOO_ENABLE_CPU_FALLBACK", "0") == "1"
1216

1317
for model_zoo_root in MODEL_ZOO_ROOTS:
1418
TEST_DEFS = list(
1519
ModelTestSuite(
1620
model_zoo_root,
17-
backend_type="htp",
21+
backend_type=MODEL_ZOO_BACKEND,
1822
rtol=None,
1923
atol=None,
20-
enable_context=True,
21-
enable_cpu_fallback=False,
24+
enable_context=MODEL_ZOO_ENABLE_CONTEXT,
25+
enable_cpu_fallback=MODEL_ZOO_ENABLE_CPU_FALLBACK,
2226
).tests
2327
)
2428

qcom/packages.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ java_windows_x86_64:
9393
content_root: jdk-{version}
9494
bindir: bin
9595
model_zoo:
96-
version: 2025-10-08
96+
version: 2025-11-07
9797
url: http://ort-ep-win-01.na.qualcomm.com:8000/model-zoo/model-zoo-{version}.zip
98-
sha256: 1e49e5c570ee16c7bf62973dc8f52201f4caaa45dda0e74fdc190b55e980ffa7
98+
sha256: a0502c3f52202c602e1cae3d22658c6a9a052d1d3be25ad077c1280b03990777
9999
content_root: model-zoo
100100
ninja_linux_aarch64:
101101
version: 1.12.1

qcom/scripts/all/package_manager.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@
3535

3636
AUTOPRUNE = os.environ.get("ORT_BUILD_PRUNE_PACKAGES", "1") == "1"
3737

38-
DEFAULT_MAX_CACHE_SIZE_BYTES = int(os.environ.get("ORT_BUILD_PACKAGE_CACHE_SIZE", f"{7 * 1024 * 1024 * 1024}")) # 7 GiB
38+
DEFAULT_MAX_CACHE_SIZE_BYTES = int(
39+
os.environ.get("ORT_BUILD_PACKAGE_CACHE_SIZE", f"{10 * 1024 * 1024 * 1024}")
40+
) # 10 GiB
3941

4042
DEFAULT_TOOLS_DIR = Path(os.environ.get("ORT_BUILD_TOOLS_PATH", REPO_ROOT / "build" / "tools"))
4143

0 commit comments

Comments
 (0)