Skip to content

Commit d3ddd45

Browse files
Convert Shapes to lists, add extra dimension to images (#3455)
1 parent 9738d79 commit d3ddd45

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

demos/image_translation_demo/python/image_translation_demo/models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ def __init__(self, core, model_path, device='CPU'):
2121

2222
inputs = [node.get_any_name() for node in model.inputs]
2323
self.input_semantics, self.reference_image, self.reference_semantics = inputs
24-
self.input_semantic_size = model.input(self.input_semantics).shape
25-
self.input_image_size = model.input(self.reference_image).shape
24+
self.input_semantic_size = list(model.input(self.input_semantics).shape)
25+
self.input_image_size = list(model.input(self.reference_image).shape)
2626

2727
compiled_model = core.compile_model(model, device)
2828
self.output_tensor = compiled_model.outputs[0]
@@ -46,7 +46,7 @@ def __init__(self, core, model_path, device='CPU'):
4646
raise RuntimeError("The SegmentationModel expects 1 output layer")
4747

4848
self.input_tensor_name = model.inputs[0].get_any_name()
49-
self.input_size = model.inputs[0].shape
49+
self.input_size = list(model.inputs[0].shape)
5050

5151
compiled_model = core.compile_model(model, device)
5252
self.output_tensor = compiled_model.outputs[0]

demos/image_translation_demo/python/image_translation_demo/preprocessing.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,19 @@ def preprocess_semantics(semantic_mask, input_size):
3434
interpolation=cv2.INTER_NEAREST)
3535
# create one-hot label map
3636
semantic_mask = scatter(semantic_mask, classes=input_size[1])
37+
38+
if len(semantic_mask.shape) == 3:
39+
return np.expand_dims(semantic_mask, axis=0)
3740
return semantic_mask
3841

3942

4043
def preprocess_for_seg_model(image, input_size):
4144
image = cv2.resize(image, dsize=tuple(input_size[2:]), interpolation=cv2.INTER_LINEAR)
4245
image = np.transpose(image, (2, 0, 1))
43-
return image
46+
return np.expand_dims(image, axis=0)
4447

4548

4649
def preprocess_image(image, input_size):
4750
image = cv2.resize(image, dsize=tuple(input_size[2:]), interpolation=cv2.INTER_CUBIC)
4851
image = np.transpose(image, (2, 0, 1))
49-
return image
52+
return np.expand_dims(image, axis=0)

0 commit comments

Comments
 (0)