Skip to content

Commit f9ec94d

Browse files
committed
replace node.get_attr('xxx').i by node.get_attr_value('xxx')
1 parent 4ecf1cd commit f9ec94d

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

tf2onnx/onnx_opset/math.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -251,11 +251,10 @@ def version_1(cls, ctx, node, **kwargs):
251251
# ONNX: Each input value is divided by (bias+(alpha/size)*sum(xi^2 for every xi in the local region))^beta
252252
# TF: sqr_sum[a, b, c, d] = sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2)
253253
# output = input / (bias + alpha * sqr_sum) ** beta
254-
depth_radius = node.get_attr("depth_radius")
255-
if depth_radius:
256-
size = depth_radius.i
257-
else:
258-
size = 5
254+
255+
# by default, depth_radius is 5 in tensorflow
256+
size = node.get_attr_value("depth_radius", 5) * 2 + 1
257+
259258
node.set_attr("size", size)
260259
node.set_attr("alpha", size * node.get_attr("alpha").f)
261260

tf2onnx/onnx_opset/tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1257,7 +1257,7 @@ def version_10(cls, ctx, node, **kwargs):
12571257
seq_dim = node.get_attr("seq_dim")
12581258
utils.make_sure(seq_dim is not None, "sequence dim must be given in {}".format(node.name))
12591259
seq_dim = seq_dim.i
1260-
batch_dim = node.get_attr_value("batch_dim")
1260+
batch_dim = node.get_attr_value("batch_dim", 0)
12611261

12621262
ctx.remove_node(node.name)
12631263
node = ctx.make_node(

0 commit comments

Comments
 (0)