Commit 5d58e9c

fix math lrn op numerical problem
1 parent ffe6792 commit 5d58e9c

2 files changed: +20 -7 lines

2 files changed

+20
-7
lines changed

tests/test_backend.py

Lines changed: 12 additions & 5 deletions
@@ -489,13 +489,20 @@ def test_conv2d_with_input_transpose(self):
                              process_args={"inputs_as_nchw": [_INPUT]},
                              onnx_feed_dict={_INPUT: x_val_for_onnx})

-    @unittest.skip("")
-    def test_lrn(self):
-        # FIXME: numerical results are not correct
+    def test_lrn_default(self):
         x_shape = [1, 3, 4, 3]
         x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
-        _ = tf.placeholder(tf.float32, shape=x_val.shape, name=_TFINPUT)
-        op = tf.nn.local_response_normalization(x_val)
+        x_ = tf.placeholder(tf.float32, shape=x_shape, name=_TFINPUT)
+        op = tf.nn.local_response_normalization(x_)
+        _ = tf.identity(op, name=_TFOUTPUT)
+        self._run_test_case([_OUTPUT], {_INPUT: x_val}, rtol=1e-05)
+
+    def test_lrn(self):
+        # can't set bias = 0
+        x_shape = [1, 2, 2, 8]
+        x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
+        x_ = tf.placeholder(tf.float32, shape=x_shape, name=_TFINPUT)
+        op = tf.nn.local_response_normalization(x_, depth_radius=4, bias=2, alpha=2, beta=1)
         _ = tf.identity(op, name=_TFOUTPUT)
         self._run_test_case([_OUTPUT], {_INPUT: x_val}, rtol=1e-05)

tf2onnx/onnx_opset/math.py

Lines changed: 8 additions & 2 deletions
@@ -253,10 +253,16 @@ def version_1(cls, ctx, node, **kwargs):
         # output = input / (bias + alpha * sqr_sum) ** beta
         depth_radius = node.get_attr("depth_radius")
         if depth_radius:
-            size = depth_radius.i
+            size = depth_radius.i * 2 + 1
         else:
-            size = 5
+            size = 5 * 2 + 1
+
         node.set_attr("size", size)
+        node.set_attr("alpha", size * node.get_attr("alpha").f)
+
+        ctx.insert_new_node_on_input(node, "Transpose", node.input[0], perm=[0, 3, 1, 2])
+        op_name = utils.make_name(node.name)
+        ctx.insert_new_node_on_output("Transpose", node.output[0], perm=[0, 2, 3, 1], name=op_name)


 @tf_op(["MatMul", "BatchMatMul"])
