2828from keras .preprocessing import image
2929from keras import backend
3030
31+ from multiprocessing import Process , Queue
32+
3133
3234MOBILENET_LIST = [(mobilenet .MobileNet , 1024 ),
3335 (mobilenet_v2 .MobileNetV2 , 1280 )]
3840 (nasnet .NASNetLarge , 4032 )]
3941
4042
41- @keras_test
42- def _test_application_basic (app , last_dim = 1000 , module = None ):
43- if module is not None :
44- weights = 'imagenet'
45- else :
46- weights = None
47- model = app (weights = weights )
48- assert model .output_shape == (None , last_dim )
49- if module is None :
50- return
51-
52- img_path = 'tests/data/elephant.jpg'
53- target_size = tuple (model .input_shape [1 : 3 ])
43+ def _get_elephant (target_size ):
5444 # For models that don't include a Flatten step,
5545 # the default is to accept variable-size inputs
5646 # even when loading ImageNet weights (since it is possible).
5747 # In this case, default to 299x299.
5848 if target_size [0 ] is None :
5949 target_size = (299 , 299 )
60- img = image .load_img (img_path , target_size = target_size )
50+ img = image .load_img ('tests/data/elephant.jpg' ,
51+ target_size = tuple (target_size ))
6152 x = image .img_to_array (img )
62- x = np .expand_dims (x , axis = 0 )
53+ return np .expand_dims (x , axis = 0 )
54+
55+
56+ def _get_output_shape (model_fn , preprocess_input = None ):
57+ if backend .backend () == 'cntk' :
58+ # Create model in a subprocess so that
59+ # the memory consumed by InceptionResNetV2 will be
60+ # released back to the system after this test
61+ # (to deal with OOM error on CNTK backend).
62+ # TODO: remove the use of multiprocessing from these tests
63+ # once a memory clearing mechanism
64+ # is implemented in the CNTK backend.
65+ def target (queue ):
66+ model = model_fn ()
67+ if preprocess_input is None :
68+ queue .put (model .output_shape )
69+ else :
70+ x = _get_elephant (model .input_shape [1 :3 ])
71+ x = preprocess_input (x )
72+ queue .put ((model .output_shape , model .predict (x )))
73+ queue = Queue ()
74+ p = Process (target = target , args = (queue ,))
75+ p .start ()
76+ p .join ()
77+ # The error in a subprocess won't propagate
78+ # to the main process, so we check if the model
79+ # is successfully created by checking if the output shape
80+ # has been put into the queue
81+ assert not queue .empty (), 'Model creation failed.'
82+ return queue .get_nowait ()
83+ else :
84+ model = model_fn ()
85+ if preprocess_input is None :
86+ return model .output_shape
87+ else :
88+ x = _get_elephant (model .input_shape [1 :3 ])
89+ x = preprocess_input (x )
90+ return (model .output_shape , model .predict (x ))
91+
6392
64- preprocess_input = getattr (module , 'preprocess_input' )
65- decode_predictions = getattr (module , 'decode_predictions' )
66- x = preprocess_input (x )
93+ @keras_test
94+ def _test_application_basic (app , last_dim = 1000 , module = None ):
95+ if module is None :
96+ output_shape = _get_output_shape (lambda : app (weights = None ))
97+ assert output_shape == (None , None , None , last_dim )
98+ else :
99+ output_shape , preds = _get_output_shape (
100+ lambda : app (weights = 'imagenet' ), module .preprocess_input )
101+ assert output_shape == (None , last_dim )
67102
68- preds = model .predict (x )
69- names = [p [1 ] for p in decode_predictions (preds )[0 ]]
70- # Test correct label is in top 3 (weak correctness test).
71- assert 'African_elephant' in names [:3 ]
103+ names = [p [1 ] for p in module .decode_predictions (preds )[0 ]]
104+ # Test correct label is in top 3 (weak correctness test).
105+ assert 'African_elephant' in names [:3 ]
72106
73107
74108@keras_test
75109def _test_application_notop (app , last_dim ):
76- model = app (weights = None , include_top = False )
77- assert model .output_shape == (None , None , None , last_dim )
110+ output_shape = _get_output_shape (
111+ lambda : app (weights = None , include_top = False ))
112+ assert output_shape == (None , None , None , last_dim )
78113
79114
80115@keras_test
@@ -83,23 +118,26 @@ def _test_application_variable_input_channels(app, last_dim):
83118 input_shape = (1 , None , None )
84119 else :
85120 input_shape = (None , None , 1 )
86- model = app (weights = None , include_top = False , input_shape = input_shape )
87- assert model .output_shape == (None , None , None , last_dim )
121+ output_shape = _get_output_shape (
122+ lambda : app (weights = None , include_top = False , input_shape = input_shape ))
123+ assert output_shape == (None , None , None , last_dim )
88124
89125 if backend .image_data_format () == 'channels_first' :
90126 input_shape = (4 , None , None )
91127 else :
92128 input_shape = (None , None , 4 )
93- model = app (weights = None , include_top = False , input_shape = input_shape )
94- assert model .output_shape == (None , None , None , last_dim )
129+ output_shape = _get_output_shape (
130+ lambda : app (weights = None , include_top = False , input_shape = input_shape ))
131+ assert output_shape == (None , None , None , last_dim )
95132
96133
97134@keras_test
98135def _test_app_pooling (app , last_dim ):
99- model = app (weights = None ,
100- include_top = False ,
101- pooling = random .choice (['avg' , 'max' ]))
102- assert model .output_shape == (None , last_dim )
136+ output_shape = _get_output_shape (
137+ lambda : app (weights = None ,
138+ include_top = False ,
139+ pooling = random .choice (['avg' , 'max' ])))
140+ assert output_shape == (None , last_dim )
103141
104142
105143def test_resnet50 ():
0 commit comments