|
22 | 22 | from ..logging import warning
|
23 | 23 | from ..config import PathField, StringField, ListField, ConfigError
|
24 | 24 | from .launcher import Launcher
|
25 |
| -from ..utils import contains_all |
26 | 25 | from ..logging import print_info
|
27 | 26 |
|
28 | 27 |
|
@@ -54,7 +53,7 @@ def parameters(cls):
|
54 | 53 | parameters = super().parameters()
|
55 | 54 | parameters.update({
|
56 | 55 | 'model': PathField(description="Path to model.", file_or_directory=True),
|
57 |
| - 'device': StringField(regex=DEVICE_REGEX, description="Device name.", optional=True, default='CPU'), |
| 56 | + 'device': StringField(description="Device name.", optional=True, default=''), |
58 | 57 | 'execution_providers': ListField(
|
59 | 58 | value_type=StringField(description="Execution provider name.", ),
|
60 | 59 | default=['CPUExecutionProvider'], optional=True
|
@@ -106,17 +105,23 @@ def create_inference_session(self, model):
|
106 | 105 | return self._create_session_via_backend_api(model)
|
107 | 106 |
|
108 | 107 | def _create_session_via_execution_providers_api(self, model):
|
109 |
| - session_options = onnx_rt.SessionOptions() |
110 |
| - session = onnx_rt.InferenceSession(model, sess_options=session_options) |
111 | 108 | self.execution_providers = self.get_value_from_config('execution_providers')
|
112 |
| - available_providers = session.get_providers() |
113 |
| - contains_all(available_providers, self.execution_providers) |
114 |
| - session.set_providers(self.execution_providers) |
| 109 | + device = self.get_value_from_config('device') |
| 110 | + self.device = device or 'CPU' |
| 111 | + kwargs = {} |
| 112 | + if device: |
| 113 | + kwargs['provider_options'] = {[{'device_type': self.device}]} |
| 114 | + session = onnx_rt.InferenceSession( |
| 115 | + model, providers=self.execution_providers, **kwargs) |
115 | 116 |
|
116 | 117 | return session
|
117 | 118 |
|
118 | 119 | def _create_session_via_backend_api(self, model):
|
119 |
| - self.device = re.match(DEVICE_REGEX, self.get_value_from_config('device').lower()).group('device') |
| 120 | + device = self.get_value_from_config('device') or 'cpu' |
| 121 | + device_match = re.match(DEVICE_REGEX, device.lower()) |
| 122 | + if not device_match: |
| 123 | + raise ConfigError('unknown device: {}'.format(device)) |
| 124 | + self.device = device_match.group('device') |
120 | 125 | beckend_rep = backend.prepare(model=str(model), device=self.device.upper())
|
121 | 126 | return beckend_rep._session # pylint: disable=W0212
|
122 | 127 |
|
|
0 commit comments