|
13 | 13 | import sys
|
14 | 14 |
|
15 | 15 | import numpy as np
|
16 |
| -from onnx import onnx_pb |
| 16 | +from onnx import onnx_pb, helper |
17 | 17 | from onnx.onnx_pb import TensorProto
|
18 | 18 |
|
19 | 19 | from tf2onnx import constants, utils
|
@@ -1485,13 +1485,18 @@ def version_10(cls, ctx, node, **kwargs):
|
1485 | 1485 | outputs=[node.output[1]], dtypes=dtypes[1], shapes=shapes[1])
|
1486 | 1486 | elif "pad_to_max_output_size" in node.attr: # V4
|
1487 | 1487 | shape_op = ctx.make_node("Shape", inputs=[squeeze_op.output[0]])
|
1488 |
| - const_zero = ctx.make_const(utils.make_name("const_zero"), np.array([0], dtype=np.int64)) |
1489 | 1488 | sub_op = ctx.make_node("Sub", inputs=[max_output_size, shape_op.output[0]])
|
1490 |
| - raw_pad = ctx.make_node("Concat", inputs=[const_zero.output[0], sub_op.output[0]], attr={'axis': 0}) |
1491 |
| - raw_pad_float = ctx.make_node("Cast", inputs=[raw_pad.output[0]], attr={"to": onnx_pb.TensorProto.FLOAT}) |
| 1489 | + raw_pad_float = ctx.make_node("Cast", inputs=[sub_op.output[0]], attr={"to": onnx_pb.TensorProto.FLOAT}) |
1492 | 1490 | relu_op = ctx.make_node("Relu", inputs=[raw_pad_float.output[0]])
|
1493 |
| - pad_val = ctx.make_node("Cast", inputs=[relu_op.output[0]], attr={"to": onnx_pb.TensorProto.INT64}) |
1494 |
| - pad_op = ctx.make_node("Pad", inputs=[squeeze_op.output[0], pad_val.output[0]]) |
| 1491 | + pad_amt = ctx.make_node("Cast", inputs=[relu_op.output[0]], attr={"to": onnx_pb.TensorProto.INT64}) |
| 1492 | + if ctx.opset <= 10: # Dynamic padding not supported before opset 11 |
| 1493 | + zero_tensor = helper.make_tensor("value", onnx_pb.TensorProto.INT64, dims=[1], vals=[0]) |
| 1494 | + padding = ctx.make_node("ConstantOfShape", inputs=[pad_amt.output[0]], attr={"value": zero_tensor}) |
| 1495 | + pad_op = ctx.make_node("Concat", inputs=[squeeze_op.output[0], padding.output[0]], attr={'axis': 0}) |
| 1496 | + else: |
| 1497 | + const_zero = ctx.make_const(utils.make_name("const_zero"), np.array([0], dtype=np.int64)) |
| 1498 | + pad_val = ctx.make_node("Concat", inputs=[const_zero.output[0], pad_amt.output[0]], attr={'axis': 0}) |
| 1499 | + pad_op = ctx.make_node("Pad", inputs=[squeeze_op.output[0], pad_val.output[0]]) |
1495 | 1500 | ctx.make_node("Cast", inputs=pad_op.output, name="cast_A", attr={"to": onnx_pb.TensorProto.INT32},
|
1496 | 1501 | outputs=[node.output[0]], dtypes=dtypes[0], shapes=shapes[0], op_name_scope=node.name)
|
1497 | 1502 | reduce_op = ctx.make_node("ReduceSum", inputs=shape_op.output, attr={"axes": [0], "keepdims": 0})
|
|
0 commit comments