Skip to content

Commit d694785

Browse files
authored
Allow 3D optmization in addition to 4D
1 parent d59146c commit d694785

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,8 +1165,8 @@ def version_1(cls, ctx, node, **kwargs):
11651165
blocklen = len(block_shape)
11661166
xlen = len(ctx.get_shape(node.input[0]))
11671167

1168-
# if 4d tensor & square 2d block_shape , can optimize
1169-
cond1 = xlen == 4
1168+
# if 3d or 4d tensor & square 2d block_shape , can optimize
1169+
cond1 = xlen in [3, 4]
11701170
cond2 = node.inputs[2].is_const()
11711171
cond3 = blocklen == 2 and block_shape[0] == block_shape[1]
11721172
if cond1 and cond2 and cond3:
@@ -1228,7 +1228,7 @@ def mkconst(desc, val, dtype=np.int64):
12281228
const_node = ctx.make_const(utils.make_name(nodename), val.astype(dtype))
12291229
return const_node.output[0]
12301230

1231-
# support non 4D tensors and dynamic crop vals
1231+
# support non 3D/4D tensors and dynamic crop vals
12321232
# dynamic slice starts at opset 10
12331233
utils.make_sure(ctx.opset >= 10, 'non-4D tensor or non-const crops require opset 10+')
12341234

@@ -1311,8 +1311,8 @@ def version_1(cls, ctx, node, **kwargs):
13111311
blocklen = len(block_shape)
13121312
xlen = len(ctx.get_shape(node.input[0]))
13131313

1314-
# if 4d tensor & square 2d block_shape , can optimize
1315-
cond1 = xlen == 4
1314+
# if 3d or 4d tensor & square 2d block_shape , can optimize
1315+
cond1 = xlen in [3, 4]
13161316
cond2 = node.inputs[2].is_const()
13171317
cond3 = blocklen == 2 and block_shape[0] == block_shape[1]
13181318
if cond1 and cond2 and cond3:
@@ -1356,7 +1356,7 @@ def mkconst(desc, val, dtype=np.int64):
13561356
const_node = ctx.make_const(utils.make_name(nodename), val.astype(dtype))
13571357
return const_node.output[0]
13581358

1359-
# support non 4D tensors and dynamic pad vals
1359+
# support non 3D/4D tensors and dynamic pad vals
13601360
# dynamic slice starts at opset 10
13611361
utils.make_sure(ctx.opset >= 10, 'non-4D tensor or non-const pads require opset 10+')
13621362

0 commit comments

Comments
 (0)