33
44import functools
55import operator
6+ import os
67from collections .abc import Callable , Iterable , Mapping
78from pathlib import Path
89from typing import Literal
@@ -175,16 +176,17 @@ def make_wheel_test(self, tmpdir: Path) -> Task:
175176 )
176177
177178
178- class OrtWheelSmokeTestTask (OrtWheelTestTask ):
179+ class OrtWheelModelTestTask (OrtWheelTestTask ):
179180 def __init__ (
180181 self ,
181182 group_name : str | None ,
182183 venv : Path | None ,
183184 wheel_pe_arch : WheelPeArchT ,
184185 config : BuildConfigT ,
185186 py_version : TargetPyVersionT ,
187+ test_files_or_dirs : list [str ],
188+ get_test_env : Callable [[], Mapping [str , str ]],
186189 ) -> None :
187- self .__venv = venv
188190 self .__wheel_pe_arch = wheel_pe_arch
189191 self .__config = config
190192 self .__py_version = py_version
@@ -195,11 +197,8 @@ def __init__(
195197 wheel_pe_arch ,
196198 py_version ,
197199 self .__find_wheel ,
198- [
199- str (REPO_ROOT / "qcom" / "model_test" / "smoke_test.py" ),
200- str (REPO_ROOT / "qcom" / "model_test" / "model_zoo_test.py" ),
201- ],
202- get_test_env = self .__get_test_env ,
200+ test_files_or_dirs ,
201+ get_test_env ,
203202 )
204203
205204 def __find_wheel (self ) -> Path :
@@ -231,12 +230,58 @@ def __find_wheel(self) -> Path:
231230 raise FileNotFoundError ("Could not find onnxruntime wheel." )
232231 return found_wheels [0 ]
233232
234- def __get_test_env (self ) -> Mapping [str , str ]:
235- """Get an environment that tells the tests where to find their models."""
236- return {
237- "ORT_WHEEL_SMOKE_TEST_ROOT" : str (get_onnx_models_root (self .__venv ) / "testdata" / "smoke" ),
238- "ORT_MODEL_ZOO_TEST_ROOTS" : str (get_model_zoo_root (self .__venv ) / "winml-cert" ),
239- }
233+
234+ class OrtWheelSmokeTestTask (OrtWheelModelTestTask ):
235+ def __init__ (
236+ self ,
237+ group_name : str | None ,
238+ venv : Path | None ,
239+ wheel_pe_arch : WheelPeArchT ,
240+ config : BuildConfigT ,
241+ py_version : TargetPyVersionT ,
242+ ) -> None :
243+ super ().__init__ (
244+ group_name ,
245+ venv ,
246+ wheel_pe_arch ,
247+ config ,
248+ py_version ,
249+ [
250+ str (REPO_ROOT / "qcom" / "model_test" / "smoke_test.py" ),
251+ str (REPO_ROOT / "qcom" / "model_test" / "model_zoo_test.py" ),
252+ ],
253+ get_test_env = lambda : {
254+ ** os .environ ,
255+ "ORT_WHEEL_SMOKE_TEST_ROOT" : str (get_onnx_models_root (venv ) / "testdata" / "smoke" ),
256+ "ORT_MODEL_ZOO_TEST_ROOTS" : str (get_model_zoo_root (venv ) / "winml-cert" ),
257+ },
258+ )
259+
260+
261+ class OrtWheelGpuModelTestTask (OrtWheelModelTestTask ):
262+ def __init__ (
263+ self ,
264+ group_name : str | None ,
265+ venv : Path | None ,
266+ wheel_pe_arch : WheelPeArchT ,
267+ config : BuildConfigT ,
268+ py_version : TargetPyVersionT ,
269+ ) -> None :
270+ super ().__init__ (
271+ group_name ,
272+ venv ,
273+ wheel_pe_arch ,
274+ config ,
275+ py_version ,
276+ [
277+ str (REPO_ROOT / "qcom" / "model_test" / "model_zoo_test.py" ),
278+ ],
279+ get_test_env = lambda : {
280+ ** os .environ ,
281+ "ORT_MODEL_ZOO_TEST_ROOTS" : str (get_model_zoo_root (venv ) / "winml-cert-gpu" ),
282+ "ORT_MODEL_ZOO_BACKEND" : "gpu" ,
283+ },
284+ )
240285
241286
242287class RunLinterTask (CompositeTask ):
0 commit comments