Skip to content

Commit f97b4d8

Browse files
committed
refactor
1 parent 0e0c37a commit f97b4d8

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

tf2onnx/onnx_opset/nn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -464,14 +464,14 @@ def version_7(cls, ctx, node, **kwargs):
464464

465465
@classmethod
466466
def version_9(cls, ctx, node, **kwargs):
467-
cls._convert_since_9(ctx, node, node_type="Upsample")
467+
cls._convert_since_9(ctx, node, op_type="Upsample")
468468

469469
@classmethod
470470
def version_10(cls, ctx, node, **kwargs):
471-
cls._convert_since_9(ctx, node, node_type="Resize")
471+
cls._convert_since_9(ctx, node, op_type="Resize")
472472

473473
@classmethod
474-
def _convert_since_9(cls, ctx, node, node_type):
474+
def _convert_since_9(cls, ctx, node, op_type):
475475

476476
# float32 out = ResizeBilinear/ResizeNearestNeighbor(T images, int size)
477477
# https://www.tensorflow.org/api_docs/python/tf/image/resize_nearest_neighbor
@@ -506,7 +506,7 @@ def _convert_since_9(cls, ctx, node, node_type):
506506
scales = ctx.make_node("Concat", [const_one_array.output[0], scales_hw.output[0]], {"axis": 0})
507507
# because onnxruntime only supports to scale the last two dims so transpose is inserted
508508
input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": [0, 3, 1, 2]})
509-
upsample = ctx.make_node(node_type, [input_nchw.output[0], scales.output[0]], attr={"mode": mode})
509+
upsample = ctx.make_node(op_type, [input_nchw.output[0], scales.output[0]], attr={"mode": mode})
510510

511511
shapes = node.output_shapes
512512
dtypes = node.output_dtypes

tf2onnx/onnx_opset/tensor.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -953,10 +953,8 @@ def version_10(cls, ctx, node, **kwargs):
953953
shapes = [ctx.get_shape(node.output[0])]
954954
ctx.remove_node(node.name)
955955
new_nonmaxsurppress = ctx.make_node(node.type, node.input).output[0]
956-
attr = {"axes": [1], "ends": [3], "starts": [2]}
957-
inputs_map = {"data": new_nonmaxsurppress, **attr}
958-
slice_op = GraphBuilder(ctx).make_slice(inputs_map)
956+
slice_op = GraphBuilder(ctx).make_slice({"data": new_nonmaxsurppress,
957+
"axes": [1], "ends": [3], "starts": [2]})
959958
squeeze_op = ctx.make_node("Squeeze", [slice_op], attr={"axes": [1]})
960959
ctx.make_node("Cast", inputs=squeeze_op.output, attr={"to": onnx_pb.TensorProto.INT32},
961960
name=node.name, outputs=node.output, dtypes=dtypes, shapes=shapes)
962-

0 commit comments

Comments
 (0)