Skip to content

Commit da2d7a6

Browse files
lei-Qiaowayuanho
authored andcommitted
math lrn op set shape after transpose (#610)
* math lrn op set shape afert transpose
1 parent a6f8ccf commit da2d7a6

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

tf2onnx/onnx_opset/math.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,14 @@ def version_1(cls, ctx, node, **kwargs):
261261
node.set_attr("size", size)
262262
node.set_attr("alpha", size * node.get_attr("alpha").f)
263263

264+
shapes = node.output_shapes[0]
265+
dtypes = node.output_dtypes[0]
266+
264267
ctx.insert_new_node_on_input(node, "Transpose", node.input[0], perm=constants.NHWC_TO_NCHW)
268+
ctx.update_node_shape_dtype(node, override=True)
265269
op_name = utils.make_name(node.name)
266-
ctx.insert_new_node_on_output("Transpose", node.output[0], perm=constants.NCHW_TO_NHWC, name=op_name)
270+
ctx.insert_new_node_on_output("Transpose", node.output[0], perm=constants.NCHW_TO_NHWC,
271+
name=op_name, shapes=shapes, dtypes=dtypes)
267272

268273

269274
@tf_op(["MatMul", "BatchMatMul"])

0 commit comments

Comments
 (0)