Commit 6d5f998

solve the onnxruntime inference issue (#13154)
Parent commit: f8ca01d

tools/infer/utility.py

Lines changed: 10 additions & 2 deletions
@@ -197,10 +197,18 @@ def create_predictor(args, mode, logger):
             raise ValueError("not find model file path {}".format(model_file_path))
         if args.use_gpu:
             sess = ort.InferenceSession(
-                model_file_path, providers=["CUDAExecutionProvider"]
+                model_file_path,
+                providers=[
+                    (
+                        "CUDAExecutionProvider",
+                        {"device_id": args.gpu_id, "cudnn_conv_algo_search": "DEFAULT"},
+                    )
+                ],
             )
         else:
-            sess = ort.InferenceSession(model_file_path)
+            sess = ort.InferenceSession(
+                model_file_path, providers=["CPUExecutionProvider"]
+            )
         return sess, sess.get_inputs()[0], None, None

     else:
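
For readers outside the repo, a minimal, self-contained sketch of the session-creation logic as it stands after this commit. The model path and gpu_id value are placeholders rather than values from PaddleOCR, and the provider-availability check is added here for illustration; only the onnxruntime calls themselves (InferenceSession, get_available_providers, get_providers) are real API.

import onnxruntime as ort

model_file_path = "model.onnx"  # placeholder path, not from the commit
gpu_id = 0                      # placeholder standing in for args.gpu_id

if "CUDAExecutionProvider" in ort.get_available_providers():
    # A (name, options) tuple pins the session to a specific GPU and
    # selects the cuDNN convolution algorithm search strategy.
    sess = ort.InferenceSession(
        model_file_path,
        providers=[
            (
                "CUDAExecutionProvider",
                {"device_id": gpu_id, "cudnn_conv_algo_search": "DEFAULT"},
            )
        ],
    )
else:
    # Recent onnxruntime releases require providers to be stated
    # explicitly when more than one is compiled in, so the CPU
    # provider is named rather than left implicit.
    sess = ort.InferenceSession(
        model_file_path, providers=["CPUExecutionProvider"]
    )

print(sess.get_providers())  # confirm which providers were actually applied

Passing a (name, options) tuple instead of a bare provider string is the documented way to attach provider options such as device_id, which is what lets the session honor args.gpu_id instead of always running on GPU 0.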
