Skip to content

Commit da76eea

Browse files
authored
add provider hook to run_pretrained_models (#1642)
Signed-off-by: Guenther Schmuelling <[email protected]>
1 parent 4061ca9 commit da76eea

File tree

1 file changed: 7 additions (+7) and 2 deletions (-2)

tests/run_pretrained_models.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,14 +335,19 @@ def run_onnxruntime(self, name, model_proto, inputs, outputs, external_tensor_st
335335
as_text=utils.is_debug_mode(),
336336
external_tensor_storage=external_tensor_storage)
337337
logger.info("Model saved to %s", model_path)
338+
providers = ['CPUExecutionProvider']
339+
if rt.get_device() == "GPU":
340+
gpus = os.environ.get("CUDA_VISIBLE_DEVICES")
341+
if gpus is None or len(gpus) > 1:
342+
providers = ['CUDAExecutionProvider']
343+
338344
opt = rt.SessionOptions()
339345
if self.use_custom_ops:
340346
from onnxruntime_extensions import get_library_path
341347
opt.register_custom_ops_library(get_library_path())
342-
m = rt.InferenceSession(model_path, opt)
343348
if self.ort_profile is not None:
344349
opt.enable_profiling = True
345-
m = rt.InferenceSession(model_path, opt)
350+
m = rt.InferenceSession(model_path, sess_options=opt, providers=providers)
346351
results = m.run(outputs, inputs)
347352
if self.perf:
348353
n = 0

0 commit comments

Comments (0)