
Commit a09d1ca (2 parents: 911d78a + 1f3b9f2)

Merge pull request #776 from Deepomatic/dev-resize-by-size
Use the 'sizes' parameter of the Resize op instead of 'scales'
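
For context: since opset 11, the ONNX Resize op takes four inputs (X, roi, scales, sizes), and the output shape can be pinned directly by passing an empty scales tensor together with a concrete sizes tensor; that is the form this PR switches to. A minimal sketch of a Resize graph built that way (illustrative only, not code from this commit; the shapes and the graph name resize_by_size are invented):

    import onnx
    from onnx import TensorProto, helper

    # Resize-11 inputs: X, roi, scales, sizes. An empty 'scales' plus a
    # concrete 'sizes' means the output shape is taken verbatim from 'sizes'.
    roi = helper.make_tensor("roi", TensorProto.FLOAT, [0], [])
    scales = helper.make_tensor("scales", TensorProto.FLOAT, [0], [])
    sizes = helper.make_tensor("sizes", TensorProto.INT64, [4], [1, 3, 24, 24])  # [n c nh nw]

    resize = helper.make_node(
        "Resize", ["X", "roi", "scales", "sizes"], ["Y"],
        mode="nearest", nearest_mode="floor",
        coordinate_transformation_mode="asymmetric")

    graph = helper.make_graph(
        [resize], "resize_by_size",
        [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 8, 8])],
        [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 24, 24])],
        initializer=[roi, scales, sizes])

    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)])
    onnx.checker.check_model(model)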

1 file changed: +62 −35

tf2onnx/onnx_opset/nn.py (62 additions, 35 deletions)
@@ -192,6 +192,28 @@ def conv_kernel_shape(ctx, node, input_idx, spatial=2):
     return kernel_shape
 
 
+def build_dynamic_target_size(ctx, transposed_input, target_hw):
+    """
+    Build the target size tensor for the Resize op.
+
+    Args:
+    - ctx: the graph context
+    - transposed_input: a tensor of rank 4 with shape [n c h w]
+    - target_hw: a rank-1 tensor containing the target size for a resize: [nh nw]
+
+    Returns:
+    A rank-1 tensor containing [n c nh nw]
+    """
+    # Take the first half [n c] of the input shape
+    shape_of_transposed_input = ctx.make_node("Shape", [transposed_input.output[0]])
+    first_half_of_shape = GraphBuilder(ctx).make_slice(
+        {"data": shape_of_transposed_input.output[0], "ends": [2], "starts": [0]})
+    target_size_int64 = ctx.make_node("Cast", [target_hw.output[0]], attr={'to': TensorProto.INT64})
+    # Build a tensor containing [n c nh nw]
+    final_target_size = ctx.make_node("Concat", [first_half_of_shape, target_size_int64.output[0]], {'axis': 0})
+    return final_target_size
+
+
 @tf_op(["Conv1D", "Conv2D", "Conv3D"])
 class ConvOp:
     @classmethod
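
At runtime, the subgraph assembled by build_dynamic_target_size (Shape, Slice, Cast, Concat) reduces to a shape concatenation. A rough numpy equivalent, assuming an NCHW input (a sketch for illustration, not part of the change):

    import numpy as np

    def dynamic_target_size(transposed_input, target_hw):
        # Shape -> Slice(starts=[0], ends=[2]): keep the leading [n, c]
        n_c = np.asarray(transposed_input.shape[:2], dtype=np.int64)
        # Cast -> Concat(axis=0): append the requested [nh, nw] as int64
        return np.concatenate([n_c, target_hw.astype(np.int64)])

    x = np.zeros((1, 3, 8, 8), dtype=np.float32)       # NCHW input
    print(dynamic_target_size(x, np.array([24, 24])))  # [ 1  3 24 24]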
@@ -594,15 +616,12 @@ def version_11(cls, ctx, node, **kwargs):
         target_x = g.make_node("Slice", [input_x.output[0], box_index_from.output[0], box_index_to.output[0],
                                          const_zero.output[0]], name="Slice_b")
         transposed_x = g.make_node("Transpose", [target_x.output[0]], attr={'perm': constants.NHWC_TO_NCHW})
-        shape_of_transposed_x = g.make_node("Shape", [transposed_x.output[0]])
         const_zero_zero = g.make_const(utils.make_name(node.name + "_const_zero_zero"),
                                        np.array([0, 0], dtype=np.float32))
         const_one_one = g.make_const(utils.make_name(node.name + "_const_one_one"),
                                      np.array([1, 1], dtype=np.float32))
         const_four = g.make_const(utils.make_name(node.name + "_const_four"), np.array([4], dtype=np.int64))
         const_empty_float = g.make_const(utils.make_name("const_empty_float"), np.array([], dtype=np.float32))
-        first_half_of_shape = GraphBuilder(g).make_slice(
-            {"data": shape_of_transposed_x.output[0], "ends": [2], "starts": [0]})
         box = g.make_node("Slice", [boxes.output[0], trip_name, index_end.output[0], const_zero_long.output[0]],
                           name="Slice_c")
         roi_raw = g.make_node("Reshape", [box.output[0], const_four.output[0]])
@@ -611,8 +630,7 @@ def version_11(cls, ctx, node, **kwargs):
         roi_concat_1 = g.make_node("Concat", [const_zero_zero.output[0], roi_raw_first_half], attr={'axis': 0})
         roi_concat_2 = g.make_node("Concat", [const_one_one.output[0], roi_raw_second_half], attr={'axis': 0})
         final_roi = g.make_node("Concat", [roi_concat_1.output[0], roi_concat_2.output[0]], attr={'axis': 0})
-        crop_size_int64 = g.make_node("Cast", [crop_size.output[0]], attr={'to': TensorProto.INT64})
-        final_crop_size = g.make_node("Concat", [first_half_of_shape, crop_size_int64.output[0]], {'axis': 0})
+        final_crop_size = build_dynamic_target_size(g, transposed_x, crop_size)
         resized_x = g.make_node("Resize", [transposed_x.output[0], final_roi.output[0], const_empty_float.output[0],
                                            final_crop_size.output[0]],
                                 attr={"mode": mode, "extrapolation_value": extrapolation_value,
@@ -661,50 +679,59 @@ def version_10(cls, ctx, node, **kwargs):
 
     @classmethod
     def version_11(cls, ctx, node, **kwargs):
-        cls._convert_since_9(ctx, node, op_type="Resize", roi_required=True)
+        cls._convert_since_9(ctx, node, op_type="Resize", use_target_size=True)
 
     @classmethod
-    def _convert_since_9(cls, ctx, node, op_type, roi_required=False):
+    def _convert_since_9(cls, ctx, node, op_type, use_target_size=False):
 
         # float32 out = ResizeBilinear/ResizeNearestNeighbor(T images, int size)
         # https://www.tensorflow.org/api_docs/python/tf/image/resize_nearest_neighbor
         # wants the input to be NHWC - adjust target_shape to this.
         mode = "linear" if node.type == "ResizeBilinear" else "nearest"
 
-        # first create "scales" info for onnx upsample
-        # if shape of input and output known then "scale" is calculated statically and set as a const node
-        shape = ctx.get_shape(node.input[0])
-        if shape and shape[2] != -1 and shape[1] != -1 and node.inputs[1].is_const():
-            target_shape = node.inputs[1].get_tensor_value()
-            n, h, w, c = shape
-            nh, nw = target_shape
-            # scales is nchw
-            # the reason not storing data at raw field is because of the bug: https://github.com/onnx/onnx/issues/1852
-            scale_val = np.array([1.0, 1.0, float(nh) / h, float(nw) / w]).astype(np.float32)
-            scales = ctx.make_const(utils.make_name("scales"), scale_val, raw=False)
-        else:
-            ori_shape = ctx.make_node("Shape", [node.input[0]])
-            attr = {"axes": [0], "starts": [1], "ends": [3]}
-            inputs_map = {"data": ori_shape.output[0], **attr}
-            ori_shape_hw = GraphBuilder(ctx).make_slice(inputs_map)
-            ori_shape_hw_float = ctx.make_node("Cast", [ori_shape_hw], attr={"to": onnx_pb.TensorProto.FLOAT})
-
-            target_hw = node.inputs[1]
-            target_hw_float = ctx.make_node("Cast", target_hw.output, attr={"to": onnx_pb.TensorProto.FLOAT})
-
-            scales_hw = ctx.make_node("Div", [target_hw_float.output[0], ori_shape_hw_float.output[0]])
-
-            const_one_array = ctx.make_const(utils.make_name("one"), np.array([1.0, 1.0]).astype(np.float32))
-            # scales is nchw
-            scales = ctx.make_node("Concat", [const_one_array.output[0], scales_hw.output[0]], {"axis": 0})
         # because onnxruntime only supports to scale the last two dims so transpose is inserted
         input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": constants.NHWC_TO_NCHW})
-        if roi_required:
+        if use_target_size:
+            final_target_size = build_dynamic_target_size(ctx, input_nchw, node.inputs[1])
             roi = ctx.make_const(utils.make_name("roi"), np.array([]).astype(np.float32))
-            upsample = ctx.make_node("Resize", [input_nchw.output[0], roi.output[0], scales.output[0]],
+            const_empty_float = ctx.make_const(utils.make_name("const_empty_float"), np.array([], dtype=np.float32))
+            resize_inputs = [
+                input_nchw.output[0],
+                roi.output[0],
+                const_empty_float.output[0],
+                final_target_size.output[0]
+            ]
+            upsample = ctx.make_node("Resize", resize_inputs,
                                      attr={"mode": mode, "nearest_mode": "floor",
                                            "coordinate_transformation_mode": "asymmetric"})
         else:
+            # first create "scales" info for onnx upsample
+            # if shape of input and output known then "scale" is calculated statically and set as a const node
+            shape = ctx.get_shape(node.input[0])
+            if shape and shape[2] != -1 and shape[1] != -1 and node.inputs[1].is_const():
+                target_shape = node.inputs[1].get_tensor_value()
+                n, h, w, c = shape
+                nh, nw = target_shape
+                # scales is nchw
+                # the reason not storing data at raw field is because of the bug:
+                # https://github.com/onnx/onnx/issues/1852
+                scale_val = np.array([1.0, 1.0, float(nh) / h, float(nw) / w]).astype(np.float32)
+                scales = ctx.make_const(utils.make_name("scales"), scale_val, raw=False)
+            else:
+                ori_shape = ctx.make_node("Shape", [node.input[0]])
+                attr = {"axes": [0], "starts": [1], "ends": [3]}
+                inputs_map = {"data": ori_shape.output[0], **attr}
+                ori_shape_hw = GraphBuilder(ctx).make_slice(inputs_map)
+                ori_shape_hw_float = ctx.make_node("Cast", [ori_shape_hw], attr={"to": onnx_pb.TensorProto.FLOAT})
+
+                target_hw = node.inputs[1]
+                target_hw_float = ctx.make_node("Cast", target_hw.output, attr={"to": onnx_pb.TensorProto.FLOAT})
+
+                scales_hw = ctx.make_node("Div", [target_hw_float.output[0], ori_shape_hw_float.output[0]])
+
+                const_one_array = ctx.make_const(utils.make_name("one"), np.array([1.0, 1.0]).astype(np.float32))
+                # scales is nchw
+                scales = ctx.make_node("Concat", [const_one_array.output[0], scales_hw.output[0]], {"axis": 0})
             upsample = ctx.make_node(op_type, [input_nchw.output[0], scales.output[0]], attr={"mode": mode})
 
         shapes = node.output_shapes
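
A plausible motivation for the sizes-based path (our reading; the commit message does not spell it out): when the target size is only known at runtime, the old code derived scales with a float32 Div, and recovering an output dimension from such a scale can round the wrong way, while the [n c nh nw] sizes tensor reproduces the requested size exactly. For example:

    import numpy as np

    in_dim, target = 3, 7
    scale = np.float32(target / in_dim)              # the float32 'scales' entry: 2.3333333
    out_dim = int(np.floor(in_dim * float(scale)))   # floor(6.9999997...)
    print(out_dim)                                   # 6, not the requested 7

The statically-known branch is unaffected: when input and target shapes are known at conversion time, the scales constant is still computed exactly as before.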
