Skip to content

Commit ba9932b

Browse files
committed
Pad output to fixed length if batch size > 1.
Fix #349
1 parent cf91771 commit ba9932b

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

efficientdet/inference.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,10 @@ def det_post_process(params: Dict[Any, Any], cls_outputs: Dict[int, tf.Tensor],
321321
min_score_thresh=min_score_thresh,
322322
max_boxes_to_draw=max_boxes_to_draw,
323323
disable_pyfun=params.get('disable_pyfun'))
324+
if params['batch_size'] > 1:
325+
# pad to fixed length if batch size > 1.
326+
padding_size = max_boxes_to_draw - tf.shape(detections)[0]
327+
detections = tf.pad(detections, [[0, padding_size], [0, 0]])
324328
detections_batch.append(detections)
325329
return tf.stack(detections_batch, name='detections')
326330

efficientdet/model_inspect.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,11 @@ def __init__(self,
104104
self.batch_size = batch_size or None
105105
self.labels_shape = [batch_size, model_config.num_classes]
106106

107-
width, height = model_config.image_size
107+
height, width = model_config.image_size
108108
if model_config.data_format == 'channels_first':
109-
self.inputs_shape = [batch_size, 3, width, height]
109+
self.inputs_shape = [batch_size, 3, height, width]
110110
else:
111-
self.inputs_shape = [batch_size, width, height, 3]
111+
self.inputs_shape = [batch_size, height, width, 3]
112112

113113
self.model_config = model_config
114114

@@ -161,13 +161,20 @@ def saved_model_inference(self, image_path_pattern, output_dir, **kwargs):
161161
batch_size = self.batch_size or 1
162162
all_files = list(tf.io.gfile.glob(image_path_pattern))
163163
print('all_files=', all_files)
164-
num_batches = len(all_files) // batch_size
164+
num_batches = (len(all_files) + batch_size - 1) // batch_size
165165

166166
for i in range(num_batches):
167167
batch_files = all_files[i * batch_size: (i + 1) * batch_size]
168-
raw_images = [np.array(Image.open(f)) for f in batch_files]
168+
height, width = self.model_config.image_size
169+
raw_images = [np.array(Image.open(f).resize((width, height)))
170+
for f in batch_files]
171+
size_before_pad = len(raw_images)
172+
if size_before_pad < batch_size:
173+
padding_size = batch_size - size_before_pad
174+
raw_images += [np.zeros_like(raw_images[0])] * padding_size
175+
169176
detections_bs = driver.serve_images(raw_images)
170-
for j in range(len(raw_images)):
177+
for j in range(size_before_pad):
171178
img = driver.visualize(raw_images[j], detections_bs[j], **kwargs)
172179
img_id = str(i * batch_size + j)
173180
output_image_path = os.path.join(output_dir, img_id + '.jpg')

0 commit comments

Comments (0)