Commit 1db1270
Enhance BatchToSpaceND and SpaceToBatchND handlers (#1386)

* Enhance BatchToSpaceND and SpaceToBatchND handlers
  - add explicit NHWC_TO_NCHW and NCHW_TO_NHWC transposes so they can be eliminated with their counterparts in TransposeOptimizer
  - handle non-const pads and crops

  Signed-off-by: Mateusz Tabaka <[email protected]>

* test [2, 2] block_shape in test_batch_to_spacend_non_const_7d

  Signed-off-by: Mateusz Tabaka <[email protected]>

Co-authored-by: TomWildenhain-Microsoft <[email protected]>
1 parent b540e3c commit 1db1270
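The explicit transposes pay off because two back-to-back Transpose ops whose permutations compose to the identity are exactly what TransposeOptimizer can delete. A minimal sketch (not tf2onnx code; `compose` is a hypothetical helper, the perm values mirror tf2onnx.constants) of why an emitted NHWC_TO_NCHW transpose cancels against a neighbouring NCHW_TO_NHWC one:

# Why the explicit transposes can be eliminated: their composition is a no-op.
NHWC_TO_NCHW = [0, 3, 1, 2]   # value of tf2onnx.constants.NHWC_TO_NCHW
NCHW_TO_NHWC = [0, 2, 3, 1]   # value of tf2onnx.constants.NCHW_TO_NHWC

def compose(first, second):
    # Transpose(first) followed by Transpose(second) equals one Transpose with this perm.
    return [first[axis] for axis in second]

print(compose(NHWC_TO_NCHW, NCHW_TO_NHWC))  # [0, 1, 2, 3] -> an identity Transpose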

File tree

tests/test_backend.py
tf2onnx/onnx_opset/tensor.py

2 files changed: +117 additions, −49 deletions


tests/test_backend.py

Lines changed: 3 additions & 3 deletions
@@ -3321,11 +3321,11 @@ def func(input_x, block_size, pad):
 
     @check_opset_min_version(11, "BatchToSpaceND")
     def test_batch_to_spacend_non_const_7d(self):
-        x_type, y_type, z_type = np.int64, np.int64, np.int64
+        x_type, y_type, z_type = np.float32, np.int64, np.int64
         # test 3D upto 7D input tensors
         for x_shape in [[12, 4, 4], [12, 4, 8, 3], [12, 4, 8, 3, 2], [12, 4, 8, 3, 2, 3], [12, 4, 8, 3, 2, 1, 3]]:
             # test 1D upto 2D block shapes
-            for block_shape in [[2, 3], [2]]:
+            for block_shape in [[2, 3], [2, 2], [2]]:
                 # crop 1 layer at end of each dim
                 # x and z can be dynamic.
                 # y = block_shape cannot be dynamic without change to Transpose op spec

@@ -3340,7 +3340,7 @@ def func(x, z):
 
     @check_opset_min_version(11, "SpaceToBatchND")
     def test_space_to_batchnd_non_const_7d(self):
-        x_type, y_type, z_type = np.int64, np.int64, np.int64
+        x_type, y_type, z_type = np.float32, np.int64, np.int64
         # test 3D upto 7D input tensors
         for x_shape in [[2, 4, 4], [1, 4, 8, 3], [1, 4, 8, 3, 2], [1, 4, 8, 3, 2, 3], [1, 4, 8, 3, 2, 1, 3]]:
             # test 1D upto 2D block shapes
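For reference, a minimal sketch (illustrative only, not the tf2onnx test harness; `func` and the shapes are assumptions) of the case the updated test now exercises: float32 data, a square [2, 2] block_shape, and crops fed as a graph input so they are non-const when the graph is converted:

# Non-const crops scenario: crops arrive as a graph input, not a constant.
import tensorflow as tf

@tf.function(input_signature=[
    tf.TensorSpec([12, 4, 8, 3], tf.float32),  # x: batch divisible by 2 * 2
    tf.TensorSpec([2, 2], tf.int64),           # crops: non-const at conversion time
])
def func(x, crops):
    return tf.raw_ops.BatchToSpaceND(input=x, block_shape=[2, 2], crops=crops)

y = func(tf.ones([12, 4, 8, 3]), tf.constant([[0, 1], [0, 1]], tf.int64))
print(y.shape)  # (3, 7, 15, 3): batch / 4, spatial dims * 2, then 1 cropped off each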

tf2onnx/onnx_opset/tensor.py

Lines changed: 114 additions & 46 deletions
@@ -20,6 +20,7 @@
 from tf2onnx.graph_builder import GraphBuilder
 from tf2onnx.handler import tf_op
 from tf2onnx.onnx_opset import nn, math
+from tf2onnx.constants import NCHW_TO_NHWC, NHWC_TO_NCHW
 
 logger = logging.getLogger(__name__)
 

@@ -1392,9 +1393,8 @@ def any_version(cls, opset, ctx, node, **kwargs):
 
         # if 3d or 4d tensor & square 2d block_shape , can optimize
         cond1 = xlen in [3, 4]
-        cond2 = node.inputs[2].is_const()
-        cond3 = blocklen == 2 and block_shape[0] == block_shape[1]
-        if cond1 and cond2 and cond3:
+        cond2 = blocklen == 2 and block_shape[0] == block_shape[1]
+        if cond1 and cond2:
             # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d.html
             # the above link says the data format of input tensor should be (batch, spatial_shape, remaining_shape)
             # and we only support 3D and 4D here, and the data format is NHC and NHWC
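This optimized branch relies on TF BatchToSpaceND over an NHWC tensor with a square block being expressible as ONNX DepthToSpace over the same tensor transposed so that the TF batch dimension occupies the ONNX channel slot. A minimal numpy check of that equivalence (illustrative only; the variable names are not from the handler):

# BatchToSpaceND (NHWC) == DepthToSpace in DCR mode applied to a "CNHW" view.
import numpy as np
import tensorflow as tf

b = 2
x = np.random.rand(12, 4, 8, 3).astype(np.float32)            # NHWC, batch = 12
expected = tf.raw_ops.BatchToSpaceND(input=x, block_shape=[b, b],
                                     crops=[[0, 0], [0, 0]]).numpy()

cnhw = x.transpose(3, 0, 1, 2)                                 # TF batch now acts as channels
c, n, h, w = cnhw.shape
# numpy version of ONNX DepthToSpace, DCR mode, blocksize b
d2s = (cnhw.reshape(c, b, b, n // (b * b), h, w)
           .transpose(0, 3, 4, 1, 5, 2)
           .reshape(c, n // (b * b), h * b, w * b))
print(np.allclose(expected, d2s.transpose(1, 2, 3, 0)))        # back to NHWC -> True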
@@ -1403,47 +1403,81 @@ def any_version(cls, opset, ctx, node, **kwargs):
             # T out = BatchToSpaceND(T input, int32 block_shape, int32 crops)
             input_tensor = node.inputs[0]
             input_shape = ctx.get_shape(input_tensor.output[0])
-            crops = node.inputs[2].get_tensor_value()
 
-            # NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
             if len(input_shape) == 3:
                 # insert automatically an Unsqueeze op if the input is 3d
                 unsqz1 = GraphBuilder(ctx).make_unsqueeze(
                     {"axes": [3], "data": input_tensor.output[0]}, return_node=True)
+                # NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
                 trans1 = ctx.make_node("Transpose", unsqz1.output, {"perm": [3, 0, 1, 2]})
             else:
-                trans1 = ctx.make_node("Transpose", input_tensor.output, {"perm": [3, 0, 1, 2]})
+                # Add explicit NHWC_TO_NCHW transpose before and NCHW_TO_NHWC transpose after subgraph.
+                # That enables more optimizations in TransposeOptimizer.
+                trans_nchw = ctx.make_node("Transpose", input_tensor.output, {"perm": NHWC_TO_NCHW})
+                # NCHW TO CNHW
+                trans1 = ctx.make_node("Transpose", trans_nchw.output, {"perm": [1, 0, 2, 3]})
             reorganize_node = ctx.make_node(node.type, trans1.output, attr={"blocksize": block_shape[0]})
-            trans2 = ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]})
-
-            # implement crop logic, the data format is NHWC
-            slice_axis = [1, 2]
-            top, bottom = crops[0]
-            left, right = crops[1]
-            starts = [top, left]
-            ends = []
-            for end in [bottom, right]:
-                if end != 0:
-                    ends.append(-end)
-                else:
-                    ends.append(np.iinfo(np.int32).max)
 
-            attr = {"axes": slice_axis, "ends": ends, "starts": starts}
-            inputs_map = {"data": trans2.output[0], **attr}
+            # implement crop logic, the data format is NCHW
+            slice_axis = [2, 3]
+            if node.inputs[2].is_const():
+                crops = node.inputs[2].get_tensor_value()
+                top, bottom = crops[0]
+                left, right = crops[1]
+                starts = [top, left]
+                ends = []
+                for end in [bottom, right]:
+                    if end != 0:
+                        ends.append(-end)
+                    else:
+                        ends.append(np.iinfo(np.int32).max)
+                attr = {"axes": slice_axis, "ends": ends, "starts": starts}
+            else:
+                shape = ctx.make_const(name=utils.make_name("shape"), np_val=np.array([-1], dtype=np.int64))
+                reshape = ctx.make_node("Cast",
+                                        ctx.make_node("Reshape", inputs=[node.input[2], shape.output[0]]).output,
+                                        attr={"to": utils.map_numpy_to_onnx_dtype(np.int64)})
+                crops = ctx.make_node("Split", inputs=reshape.output, attr={}, output_count=4).output
+                zero = ctx.make_const(name=utils.make_name("zero"), np_val=np.array([0], dtype=np.int64)).output[0]
+                int32_max = ctx.make_const(name=utils.make_name("int32_max"),
+                                           np_val=np.array([np.iinfo(np.int32).max], dtype=np.int64)).output[0]
+                def crop_to_end(crop):
+                    eq = ctx.make_node("Equal", [crop, zero])
+                    not_eq = ctx.make_node("Not", eq.output)
+                    cast_eq = ctx.make_node("Cast", eq.output, attr={"to": utils.map_numpy_to_onnx_dtype(np.int64)})
+                    cast_not_eq = ctx.make_node("Cast", not_eq.output,
+                                                attr={"to": utils.map_numpy_to_onnx_dtype(np.int64)})
+                    neg = ctx.make_node("Neg", cast_not_eq.output)
+                    add = ctx.make_node("Add",
+                                        [
+                                            ctx.make_node("Mul", [crop, neg.output[0]]).output[0],
+                                            ctx.make_node("Mul", [int32_max, cast_eq.output[0]]).output[0],
+                                        ])
+                    return add.output[0]
+
+                starts = ctx.make_node("Concat", [crops[0], crops[2]], {'axis': 0})
+                ends = ctx.make_node("Concat", [crop_to_end(crops[1]), crop_to_end(crops[3])], {'axis': 0})
+                axes = ctx.make_const(name=utils.make_name("axes"), np_val=np.array(slice_axis, dtype=np.int64))
+                attr = {"axes": axes.output[0], "ends": ends.output[0], "starts": starts.output[0]}
+            inputs_map = {"data": reorganize_node.output[0], **attr}
             dtypes = node.output_dtypes
             shapes = node.output_shapes
 
+            ctx.remove_node(node.name)
             if len(input_shape) == 3:
                 # add a squeeze op to convert output into 3d
                 kwargs = {**inputs_map}
-                ctx.remove_node(node.name)
-                slice1 = GraphBuilder(ctx).make_slice(kwargs)
-                GraphBuilder(ctx).make_squeeze(
-                    {"axes": [3], "data": slice1, "outputs": node.output}, name=node.name, dtypes=dtypes, shapes=shapes)
+                node_slice = GraphBuilder(ctx).make_slice(kwargs)
+                # CNHW TO NHWC
+                trans2 = ctx.make_node("Transpose", [node_slice], {"perm": [1, 2, 3, 0]})
+                GraphBuilder(ctx).make_squeeze({"axes": [3], "data": trans2.output[0], "outputs": node.output},
+                                               name=node.name, shapes=shapes, dtypes=dtypes)
             else:
-                kwargs = {**inputs_map, "outputs": node.output}
-                ctx.remove_node(node.name)
-                GraphBuilder(ctx).make_slice(kwargs, name=node.name, dtypes=dtypes, shapes=shapes)
+                node_slice = GraphBuilder(ctx).make_slice(inputs_map)
+                # CNHW TO NCHW
+                trans2 = ctx.make_node("Transpose", [node_slice], {"perm": [1, 0, 2, 3]})
+                ctx.make_node("Transpose", trans2.output, {"perm": NCHW_TO_NHWC},
+                              name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
         else:
             def mknode(optype, inputs, attrs=None):
                 nodename = utils.make_name(node.name + '_' + optype.lower())
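When crops are not constant, the handler cannot precompute the Slice "ends", so the crop_to_end closure above builds a small Equal/Not/Cast/Neg/Mul/Add subgraph per crop value. A minimal numpy sketch of the arithmetic that subgraph performs (standalone illustration, not tf2onnx code):

# Slice end must be -crop, except crop == 0, which must become INT32_MAX
# so the whole dimension is kept.
import numpy as np

INT32_MAX = np.iinfo(np.int32).max

def crop_to_end(crop):
    eq = crop == 0                            # Equal(crop, zero)
    not_eq = ~eq                              # Not
    cast_eq = eq.astype(np.int64)             # Cast
    cast_not_eq = not_eq.astype(np.int64)     # Cast
    neg = -cast_not_eq                        # Neg
    return crop * neg + INT32_MAX * cast_eq   # Mul, Mul, Add

print(crop_to_end(np.array([3])))   # [-3]         -> drop 3 elements from the end
print(crop_to_end(np.array([0])))   # [2147483647] -> keep the full dimension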
@@ -1545,7 +1579,10 @@ def version_1(cls, ctx, node, **kwargs):
 
         # if 3d or 4d tensor & square 2d block_shape , can optimize
         cond1 = xlen in [3, 4]
-        cond2 = node.inputs[2].is_const()
+        # with opset 11 (or above), we can deal with non-const pads
+        # by creating a subgraph with Split and Concat and pass its output
+        # to Pad's second input
+        cond2 = node.inputs[2].is_const() or ctx.opset >= 11
         cond3 = blocklen == 2 and block_shape[0] == block_shape[1]
         if cond1 and cond2 and cond3:
             # https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd

@@ -1555,29 +1592,60 @@ def version_1(cls, ctx, node, **kwargs):
             # and it only supports NCHW
             # T out = SpaceToBatchND(T input, int32 block_shape, int32 crops)
             input_tensor = node.inputs[0]
+            input_shape = ctx.get_shape(input_tensor.output[0])
             shapes = [ctx.get_shape(node.output[0])]
             dtypes = [ctx.get_dtype(node.output[0])]
 
-            # implement pads logic, the data format is NHWC
-            paddings = node.inputs[2].get_tensor_value()
-            top, bottom = paddings[0]
-            left, right = paddings[1]
-            pads = [0, top, left, 0,
-                    0, bottom, right, 0]
-            ctx.remove_node(node.name)
-            if ctx.opset <= 10:
-                pad_op = ctx.make_node("Pad", input_tensor.output, attr={"pads": pads})
+            if len(input_shape) == 3:
+                # insert automatically an Unsqueeze op if the input is 3d
+                unsqz1 = GraphBuilder(ctx).make_unsqueeze(
+                    {"axes": [3], "data": input_tensor.output[0]}, return_node=True)
+                # NHWC TO CNHW
+                trans1 = ctx.make_node("Transpose", unsqz1.output, {"perm": [3, 0, 1, 2]})
+            else:
+                # Add explicit NHWC_TO_NCHW transpose before and NCHW_TO_NHWC transpose after subgraph.
+                # That enables more optimizations in TransposeOptimizer.
+                trans_nchw = ctx.make_node("Transpose", input_tensor.output, {"perm": NHWC_TO_NCHW})
+                # NCHW TO CNHW
+                trans1 = ctx.make_node("Transpose", trans_nchw.output, {"perm": [1, 0, 2, 3]})
+            # implement pads logic, the data format is NCHW
+            if ctx.opset <= 10 or node.inputs[2].is_const():
+                paddings = node.inputs[2].get_tensor_value()
+                top, bottom = paddings[0]
+                left, right = paddings[1]
+                pads = [0, 0, top, left,
+                        0, 0, bottom, right]
+                if ctx.opset <= 10:
+                    pad_op = ctx.make_node("Pad", trans1.output, attr={"pads": pads})
+                else:
+                    new_pads = ctx.make_const(name=utils.make_name("pads"), np_val=np.array(pads, dtype=np.int64))
+                    pad_op = ctx.make_node("Pad", [trans1.output[0], new_pads.output[0]])
             else:
                 # TODO: we should be able to support dynamic input here.
-                pads_name = utils.make_name(node.name)
-                ctx.make_const(name=pads_name, np_val=np.array(pads, dtype=np.int64))
-                pad_op = ctx.make_node("Pad", [input_tensor.output[0], pads_name])
+                shape = ctx.make_const(name=utils.make_name("shape"), np_val=np.array([-1], dtype=np.int64))
+                reshape = ctx.make_node("Reshape", inputs=[node.input[2], shape.output[0]])
+                cast = ctx.make_node("Cast", reshape.output, attr={'to': utils.map_numpy_to_onnx_dtype(np.int64)})
+                split = ctx.make_node("Split", inputs=cast.output, attr={}, output_count=4)
+                pads = split.output
+                zero = ctx.make_const(name=utils.make_name("zero"), np_val=np.array([0], dtype=np.int64)).output[0]
+                new_pads = ctx.make_node("Concat", [zero, zero, pads[0], pads[2], zero, zero, pads[1], pads[3]],
+                                         {'axis': 0})
+                pad_op = ctx.make_node("Pad", [trans1.output[0], new_pads.output[0]])
+
+            reorganize_node = ctx.make_node(node.type, pad_op.output, attr={"blocksize": block_shape[0]})
 
-            # NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
-            trans1 = ctx.make_node("Transpose", pad_op.output, {"perm": [3, 0, 1, 2]})
-            reorganize_node = ctx.make_node(node.type, trans1.output, attr={"blocksize": block_shape[0]})
-            ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]},
-                          name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
+            ctx.remove_node(node.name)
+            if len(input_shape) == 3:
+                # CNHW TO NHWC
+                trans2 = ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]})
+                # add a squeeze op to convert output into 3d
+                GraphBuilder(ctx).make_squeeze({"axes": [3], "data": trans2.output[0], "outputs": node.output},
+                                               name=node.name, shapes=shapes, dtypes=dtypes)
+            else:
+                # CNHW TO NCHW
+                trans2 = ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 0, 2, 3]})
+                ctx.make_node("Transpose", trans2.output, {"perm": NCHW_TO_NHWC},
+                              name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
         else:
            def mknode(optype, inputs, attrs=None):
                nodename = utils.make_name(node.name + '_' + optype.lower())
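Similarly, for non-const paddings the SpaceToBatchND handler flattens the TF [[top, bottom], [left, right]] tensor and reassembles it into the pads layout ONNX Pad expects for NCHW data: all "begin" pads followed by all "end" pads. A minimal numpy sketch of that rearrangement (the helper name is illustrative, not tf2onnx code):

# Mirrors the Reshape/Cast/Split/Concat subgraph built for a non-const paddings input.
import numpy as np

def tf_paddings_to_onnx_pads(paddings):
    flat = paddings.reshape(-1).astype(np.int64)   # Reshape to [-1], then Cast
    top, bottom, left, right = np.split(flat, 4)   # Split into four 1-element pieces
    zero = np.array([0], dtype=np.int64)
    # Concat: [N_begin, C_begin, H_begin, W_begin, N_end, C_end, H_end, W_end]
    return np.concatenate([zero, zero, top, left, zero, zero, bottom, right])

print(tf_paddings_to_onnx_pads(np.array([[1, 2], [3, 4]])))   # [0 0 1 3 0 0 2 4]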
