44import onnxruntime as ort
55
class OnnxInferenceModel:
    """Base class for all inference models that use onnxruntime.

    Attributes:
        model_path (str, optional): Path to the model file or folder. Defaults to "".
        force_cpu (bool, optional): Force the model to run on CPU even when a GPU
            is available. Defaults to False (use GPU when onnxruntime reports one).
        default_model_name (str, optional): Model file name appended when
            ``model_path`` is a folder. Defaults to "model.onnx".
    """
    def __init__(
        self,
        model_path: str = "",
        force_cpu: bool = False,
        default_model_name: str = "model.onnx"
    ):
        self.model_path = model_path
        self.force_cpu = force_cpu
        self.default_model_name = default_model_name

        # NOTE(review): the following folder->file resolution reconstructs lines
        # hidden in the diff hunk gap; it is implied by the docstring ("model
        # folder") and the otherwise-unused default_model_name — verify against
        # the full file.
        if not self.model_path.endswith(".onnx"):
            self.model_path = os.path.join(self.model_path, self.default_model_name)

        if not os.path.exists(self.model_path):
            raise Exception(f"Model path ({self.model_path}) does not exist")

        # Prefer CUDA with a CPU fallback when onnxruntime reports a GPU and the
        # caller did not force CPU; otherwise run CPU-only. (The previous
        # follow-up set_providers(['CPUExecutionProvider']) call was dead code:
        # when force_cpu is true this ternary already selects CPU-only.)
        providers = (
            ['CUDAExecutionProvider', 'CPUExecutionProvider']
            if ort.get_device() == "GPU" and not force_cpu
            else ['CPUExecutionProvider']
        )

        self.model = ort.InferenceSession(self.model_path, providers=providers)

        # Cache input/output metadata via the public session API (the original
        # read the private _inputs_meta/_outputs_meta attributes, which expose
        # the same name fields).
        self.input_shape = self.model.get_inputs()[0].shape[1:]  # drop batch dim
        self.input_name = self.model.get_inputs()[0].name
        self.output_name = self.model.get_outputs()[0].name
0 commit comments