@@ -1165,8 +1165,8 @@ def version_1(cls, ctx, node, **kwargs):
1165
1165
blocklen = len (block_shape )
1166
1166
xlen = len (ctx .get_shape (node .input [0 ]))
1167
1167
1168
- # if 4d tensor & square 2d block_shape , can optimize
1169
- cond1 = xlen == 4
1168
+ # if 3d or 4d tensor & square 2d block_shape , can optimize
1169
+ cond1 = xlen in [ 3 , 4 ]
1170
1170
cond2 = node .inputs [2 ].is_const ()
1171
1171
cond3 = blocklen == 2 and block_shape [0 ] == block_shape [1 ]
1172
1172
if cond1 and cond2 and cond3 :
@@ -1228,7 +1228,7 @@ def mkconst(desc, val, dtype=np.int64):
1228
1228
const_node = ctx .make_const (utils .make_name (nodename ), val .astype (dtype ))
1229
1229
return const_node .output [0 ]
1230
1230
1231
- # support non 4D tensors and dynamic crop vals
1231
+ # support non 3D/ 4D tensors and dynamic crop vals
1232
1232
# dynamic slice starts at opset 10
1233
1233
utils .make_sure (ctx .opset >= 10 , 'non-4D tensor or non-const crops require opset 10+' )
1234
1234
@@ -1311,8 +1311,8 @@ def version_1(cls, ctx, node, **kwargs):
1311
1311
blocklen = len (block_shape )
1312
1312
xlen = len (ctx .get_shape (node .input [0 ]))
1313
1313
1314
- # if 4d tensor & square 2d block_shape , can optimize
1315
- cond1 = xlen == 4
1314
+ # if 3d or 4d tensor & square 2d block_shape , can optimize
1315
+ cond1 = xlen in [ 3 , 4 ]
1316
1316
cond2 = node .inputs [2 ].is_const ()
1317
1317
cond3 = blocklen == 2 and block_shape [0 ] == block_shape [1 ]
1318
1318
if cond1 and cond2 and cond3 :
@@ -1356,7 +1356,7 @@ def mkconst(desc, val, dtype=np.int64):
1356
1356
const_node = ctx .make_const (utils .make_name (nodename ), val .astype (dtype ))
1357
1357
return const_node .output [0 ]
1358
1358
1359
- # support non 4D tensors and dynamic pad vals
1359
+ # support non 3D/ 4D tensors and dynamic pad vals
1360
1360
# dynamic slice starts at opset 10
1361
1361
utils .make_sure (ctx .opset >= 10 , 'non-4D tensor or non-const pads require opset 10+' )
1362
1362
0 commit comments