Skip to content

Commit 280b59f

Browse files
authored
AC: update work onnxrt launcher for EP (#2999)
1 parent 973300b commit 280b59f

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

tools/accuracy_checker/openvino/tools/accuracy_checker/launcher/onnx_launcher.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from ..logging import warning
2323
from ..config import PathField, StringField, ListField, ConfigError
2424
from .launcher import Launcher
25-
from ..utils import contains_all
2625
from ..logging import print_info
2726

2827

@@ -54,7 +53,7 @@ def parameters(cls):
5453
parameters = super().parameters()
5554
parameters.update({
5655
'model': PathField(description="Path to model.", file_or_directory=True),
57-
'device': StringField(regex=DEVICE_REGEX, description="Device name.", optional=True, default='CPU'),
56+
'device': StringField(description="Device name.", optional=True, default=''),
5857
'execution_providers': ListField(
5958
value_type=StringField(description="Execution provider name.", ),
6059
default=['CPUExecutionProvider'], optional=True
@@ -106,17 +105,23 @@ def create_inference_session(self, model):
106105
return self._create_session_via_backend_api(model)
107106

108107
def _create_session_via_execution_providers_api(self, model):
109-
session_options = onnx_rt.SessionOptions()
110-
session = onnx_rt.InferenceSession(model, sess_options=session_options)
111108
self.execution_providers = self.get_value_from_config('execution_providers')
112-
available_providers = session.get_providers()
113-
contains_all(available_providers, self.execution_providers)
114-
session.set_providers(self.execution_providers)
109+
device = self.get_value_from_config('device')
110+
self.device = device or 'CPU'
111+
kwargs = {}
112+
if device:
113+
kwargs['provider_options'] = {[{'device_type': self.device}]}
114+
session = onnx_rt.InferenceSession(
115+
model, providers=self.execution_providers, **kwargs)
115116

116117
return session
117118

118119
def _create_session_via_backend_api(self, model):
119-
self.device = re.match(DEVICE_REGEX, self.get_value_from_config('device').lower()).group('device')
120+
device = self.get_value_from_config('device') or 'cpu'
121+
device_match = re.match(DEVICE_REGEX, device.lower())
122+
if not device_match:
123+
raise ConfigError('unknown device: {}'.format(device))
124+
self.device = device_match.group('device')
120125
beckend_rep = backend.prepare(model=str(model), device=self.device.upper())
121126
return beckend_rep._session # pylint: disable=W0212
122127

0 commit comments

Comments
 (0)