Skip to content

Commit 3d39433

Browse files
committed
fix math lrn op numerical problem
1 parent 701dbfe commit 3d39433

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tf2onnx/onnx_opset/math.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,14 +255,15 @@ def version_1(cls, ctx, node, **kwargs):
255255
if depth_radius:
256256
size = depth_radius.i * 2 + 1
257257
else:
258+
# by default, depth_radius is 5 in tensorflow
258259
size = 5 * 2 + 1
259260

260261
node.set_attr("size", size)
261262
node.set_attr("alpha", size * node.get_attr("alpha").f)
262263

263-
ctx.insert_new_node_on_input(node, "Transpose", node.input[0], perm=[0, 3, 1, 2])
264+
ctx.insert_new_node_on_input(node, "Transpose", node.input[0], perm=constants.NHWC_TO_NCHW)
264265
op_name = utils.make_name(node.name)
265-
ctx.insert_new_node_on_output("Transpose", node.output[0], perm=[0, 2, 3, 1], name=op_name)
266+
ctx.insert_new_node_on_output("Transpose", node.output[0], perm=constants.NCHW_TO_NHWC, name=op_name)
266267

267268

268269
@tf_op(["MatMul", "BatchMatMul"])

0 commit comments

Comments
 (0)