Skip to content

Commit d91315c

Browse files
Merge pull request #874 from RandySheriffH/rashuai/RefactorResize
Rashuai/refactor resize
2 parents aef2a5e + 63cd1a9 commit d91315c

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

tf2onnx/onnx_opset/nn.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -721,7 +721,30 @@ def version_10(cls, ctx, node, **kwargs):
721721

722722
@classmethod
723723
def version_11(cls, ctx, node, **kwargs):
724-
cls._convert_since_9(ctx, node, op_type="Resize", use_target_size=True)
724+
mode = "linear" if node.type == "ResizeBilinear" else "nearest"
725+
roi = ctx.make_const(utils.make_name("roi"), np.array([]).astype(np.float32))
726+
const_zero = ctx.make_const(utils.make_name("const_zero"), np.array([0]).astype(np.int64))
727+
const_two = ctx.make_const(utils.make_name("const_two"), np.array([2]).astype(np.int64))
728+
const_empty_float = ctx.make_const(utils.make_name("const_empty_float"), np.array([]).astype(np.float32))
729+
input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": constants.NHWC_TO_NCHW})
730+
shape_input = ctx.make_node("Shape", [input_nchw.output[0]])
731+
sliced_shape = ctx.make_node("Slice", [shape_input.output[0], const_zero.output[0], const_two.output[0]])
732+
size_int64 = ctx.make_node("Cast", [node.input[1]], attr={"to": onnx_pb.TensorProto.INT64})
733+
concat_shape = ctx.make_node("Concat", [sliced_shape.output[0], size_int64.output[0]], {'axis': 0})
734+
resize_inputs = [
735+
input_nchw.output[0],
736+
roi.output[0],
737+
const_empty_float.output[0],
738+
concat_shape.output[0]
739+
]
740+
resize = ctx.make_node("Resize", resize_inputs,
741+
attr={"mode": mode, "nearest_mode": "floor",
742+
"coordinate_transformation_mode": "asymmetric"})
743+
shapes = node.output_shapes
744+
dtypes = node.output_dtypes
745+
ctx.remove_node(node.name)
746+
ctx.make_node("Transpose", resize.output, {"perm": constants.NCHW_TO_NHWC},
747+
name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
725748

726749
@classmethod
727750
def _convert_since_9(cls, ctx, node, op_type, use_target_size=False):

0 commit comments

Comments
 (0)