Skip to content

Commit 0b09b7d

Browse files
authored
AC: MXNet isolated import (#1599)
1 parent 204a24b commit 0b09b7d

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

tools/accuracy_checker/accuracy_checker/launcher/mxnet_launcher.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
"""
22
Copyright (c) 2018-2020 Intel Corporation
3-
43
Licensed under the Apache License, Version 2.0 (the "License");
54
you may not use this file except in compliance with the License.
65
You may obtain a copy of the License at
7-
86
http://www.apache.org/licenses/LICENSE-2.0
9-
107
Unless required by applicable law or agreed to in writing, software
118
distributed under the License is distributed on an "AS IS" BASIS,
129
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -19,7 +16,6 @@
1916
from pathlib import Path
2017

2118
import numpy as np
22-
import mxnet
2319

2420
from .launcher import Launcher, LauncherConfigValidator, ListInputsField
2521
from ..config import PathField, StringField, NumberField, ConfigError
@@ -58,6 +54,13 @@ def parameters(cls):
5854
return parameters
5955

6056
def __init__(self, config_entry: dict, *args, **kwargs):
57+
try:
58+
import mxnet # pylint: disable=C0415
59+
self.mxnet = mxnet
60+
except ImportError as import_error:
61+
raise ValueError(
62+
"MXNet isn't installed. Please, install it before using. \n{}".format(import_error.msg)
63+
)
6164
super().__init__(config_entry, *args, **kwargs)
6265
self._delayed_model_loading = kwargs.get('delayed_model_loading', False)
6366

@@ -78,9 +81,9 @@ def __init__(self, config_entry: dict, *args, **kwargs):
7881
identifier = match.group('identifier')
7982
if identifier is None:
8083
identifier = 0
81-
device_context = mxnet.gpu(int(identifier))
84+
device_context = self.mxnet.gpu(int(identifier))
8285
else:
83-
device_context = mxnet.cpu()
86+
device_context = self.mxnet.cpu()
8487

8588
# Get batch from config or 1
8689
self._batch = self.config.get('batch', 1)
@@ -109,7 +112,7 @@ def batch(self):
109112

110113
def fit_to_input(self, data, input_layer, layout, precision):
111114
data = np.transpose(data, layout)
112-
return mxnet.nd.array(data.astype(precision) if precision else data)
115+
return self.mxnet.nd.array(data.astype(precision) if precision else data)
113116

114117
@property
115118
def inputs(self):
@@ -125,9 +128,9 @@ def predict(self, inputs, metadata=None, **kwargs):
125128
"""
126129
results = []
127130
for infer_input in inputs:
128-
data_iter = mxnet.io.NDArrayIter(
131+
data_iter = self.mxnet.io.NDArrayIter(
129132
data=infer_input, label=None, batch_size=self.batch)
130-
data_batch = mxnet.io.DataBatch(data=data_iter.data_list)
133+
data_batch = self.mxnet.io.DataBatch(data=data_iter.data_list)
131134

132135
# Infer
133136
self.module.forward(data_batch)

tools/accuracy_checker/tests/test_mxnet_launcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""
1616

1717
import pytest
18-
pytest.importorskip('accuracy_checker.launcher.mxnet_launcher')
18+
pytest.importorskip('mxnet')
1919
import cv2
2020
import numpy as np
2121

0 commit comments

Comments
 (0)