Skip to content
This repository was archived by the owner on Nov 3, 2022. It is now read-only.

Commit 4be176a

Browse files
authored
Fix tests to deal with OOM error on CNTK (#6)
1 parent d68bbb5 commit 4be176a

File tree

1 file changed

+70
-32
lines changed

1 file changed

+70
-32
lines changed

tests/applications_test.py

Lines changed: 70 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
from keras.preprocessing import image
2929
from keras import backend
3030

31+
from multiprocessing import Process, Queue
32+
3133

3234
MOBILENET_LIST = [(mobilenet.MobileNet, 1024),
3335
(mobilenet_v2.MobileNetV2, 1280)]
@@ -38,43 +40,76 @@
3840
(nasnet.NASNetLarge, 4032)]
3941

4042

41-
@keras_test
42-
def _test_application_basic(app, last_dim=1000, module=None):
43-
if module is not None:
44-
weights = 'imagenet'
45-
else:
46-
weights = None
47-
model = app(weights=weights)
48-
assert model.output_shape == (None, last_dim)
49-
if module is None:
50-
return
51-
52-
img_path = 'tests/data/elephant.jpg'
53-
target_size = tuple(model.input_shape[1: 3])
43+
def _get_elephant(target_size):
5444
# For models that don't include a Flatten step,
5545
# the default is to accept variable-size inputs
5646
# even when loading ImageNet weights (since it is possible).
5747
# In this case, default to 299x299.
5848
if target_size[0] is None:
5949
target_size = (299, 299)
60-
img = image.load_img(img_path, target_size=target_size)
50+
img = image.load_img('tests/data/elephant.jpg',
51+
target_size=tuple(target_size))
6152
x = image.img_to_array(img)
62-
x = np.expand_dims(x, axis=0)
53+
return np.expand_dims(x, axis=0)
54+
55+
56+
def _get_output_shape(model_fn, preprocess_input=None):
57+
if backend.backend() == 'cntk':
58+
# Create model in a subprocess so that
59+
# the memory consumed by InceptionResNetV2 will be
60+
# released back to the system after this test
61+
# (to deal with OOM error on CNTK backend).
62+
# TODO: remove the use of multiprocessing from these tests
63+
# once a memory clearing mechanism
64+
# is implemented in the CNTK backend.
65+
def target(queue):
66+
model = model_fn()
67+
if preprocess_input is None:
68+
queue.put(model.output_shape)
69+
else:
70+
x = _get_elephant(model.input_shape[1:3])
71+
x = preprocess_input(x)
72+
queue.put((model.output_shape, model.predict(x)))
73+
queue = Queue()
74+
p = Process(target=target, args=(queue,))
75+
p.start()
76+
p.join()
77+
# The error in a subprocess won't propagate
78+
# to the main process, so we check if the model
79+
# is successfully created by checking if the output shape
80+
# has been put into the queue
81+
assert not queue.empty(), 'Model creation failed.'
82+
return queue.get_nowait()
83+
else:
84+
model = model_fn()
85+
if preprocess_input is None:
86+
return model.output_shape
87+
else:
88+
x = _get_elephant(model.input_shape[1:3])
89+
x = preprocess_input(x)
90+
return (model.output_shape, model.predict(x))
91+
6392

64-
preprocess_input = getattr(module, 'preprocess_input')
65-
decode_predictions = getattr(module, 'decode_predictions')
66-
x = preprocess_input(x)
93+
@keras_test
94+
def _test_application_basic(app, last_dim=1000, module=None):
95+
if module is None:
96+
output_shape = _get_output_shape(lambda: app(weights=None))
97+
assert output_shape == (None, None, None, last_dim)
98+
else:
99+
output_shape, preds = _get_output_shape(
100+
lambda: app(weights='imagenet'), module.preprocess_input)
101+
assert output_shape == (None, last_dim)
67102

68-
preds = model.predict(x)
69-
names = [p[1] for p in decode_predictions(preds)[0]]
70-
# Test correct label is in top 3 (weak correctness test).
71-
assert 'African_elephant' in names[:3]
103+
names = [p[1] for p in module.decode_predictions(preds)[0]]
104+
# Test correct label is in top 3 (weak correctness test).
105+
assert 'African_elephant' in names[:3]
72106

73107

74108
@keras_test
75109
def _test_application_notop(app, last_dim):
76-
model = app(weights=None, include_top=False)
77-
assert model.output_shape == (None, None, None, last_dim)
110+
output_shape = _get_output_shape(
111+
lambda: app(weights=None, include_top=False))
112+
assert output_shape == (None, None, None, last_dim)
78113

79114

80115
@keras_test
@@ -83,23 +118,26 @@ def _test_application_variable_input_channels(app, last_dim):
83118
input_shape = (1, None, None)
84119
else:
85120
input_shape = (None, None, 1)
86-
model = app(weights=None, include_top=False, input_shape=input_shape)
87-
assert model.output_shape == (None, None, None, last_dim)
121+
output_shape = _get_output_shape(
122+
lambda: app(weights=None, include_top=False, input_shape=input_shape))
123+
assert output_shape == (None, None, None, last_dim)
88124

89125
if backend.image_data_format() == 'channels_first':
90126
input_shape = (4, None, None)
91127
else:
92128
input_shape = (None, None, 4)
93-
model = app(weights=None, include_top=False, input_shape=input_shape)
94-
assert model.output_shape == (None, None, None, last_dim)
129+
output_shape = _get_output_shape(
130+
lambda: app(weights=None, include_top=False, input_shape=input_shape))
131+
assert output_shape == (None, None, None, last_dim)
95132

96133

97134
@keras_test
98135
def _test_app_pooling(app, last_dim):
99-
model = app(weights=None,
100-
include_top=False,
101-
pooling=random.choice(['avg', 'max']))
102-
assert model.output_shape == (None, last_dim)
136+
output_shape = _get_output_shape(
137+
lambda: app(weights=None,
138+
include_top=False,
139+
pooling=random.choice(['avg', 'max'])))
140+
assert output_shape == (None, last_dim)
103141

104142

105143
def test_resnet50():

0 commit comments

Comments
 (0)