Skip to content

Commit fdbbd25

Browse files
committed
replace node.get_attr('xxx').i by node.get_attr_value('xxx')
1 parent 1dda7df commit fdbbd25

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

tf2onnx/onnx_opset/math.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -251,12 +251,19 @@ def version_1(cls, ctx, node, **kwargs):
251251
# ONNX: Each input value is divided by (bias+(alpha/size)*sum(xi^2 for every xi in the local region))^beta
252252
# TF: sqr_sum[a, b, c, d] = sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2)
253253
# output = input / (bias + alpha * sqr_sum) ** beta
254-
depth_radius = node.get_attr("depth_radius")
255-
if depth_radius:
256-
size = depth_radius.i
257-
else:
258-
size = 5
254+
size = node.get_attr_value("depth_radius") * 2 + 1
255+
259256
node.set_attr("size", size)
257+
node.set_attr("alpha", size * node.get_attr("alpha").f)
258+
259+
shapes = node.output_shapes[0]
260+
dtypes = node.output_dtypes[0]
261+
262+
ctx.insert_new_node_on_input(node, "Transpose", node.input[0], perm=constants.NHWC_TO_NCHW)
263+
ctx.update_node_shape_dtype(node, override=True)
264+
op_name = utils.make_name(node.name)
265+
ctx.insert_new_node_on_output("Transpose", node.output[0], perm=constants.NCHW_TO_NHWC,
266+
name=op_name, shapes=shapes, dtypes=dtypes)
260267

261268

262269
@tf_op(["MatMul", "BatchMatMul"])

tf2onnx/onnx_opset/tensor.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,9 +1257,7 @@ def version_10(cls, ctx, node, **kwargs):
12571257
seq_dim = node.get_attr("seq_dim")
12581258
utils.make_sure(seq_dim is not None, "sequence dim must be given in {}".format(node.name))
12591259
seq_dim = seq_dim.i
1260-
batch_dim = node.get_attr("batch_dim")
1261-
#batch_dim is set by default 0 in tf
1262-
batch_dim = batch_dim.i
1260+
batch_dim = node.get_attr_value("batch_dim")
12631261

12641262
ctx.remove_node(node.name)
12651263
node = ctx.make_node(

0 commit comments

Comments
 (0)