 from tf2onnx.graph_builder import GraphBuilder
 from tf2onnx.handler import tf_op
 from tf2onnx.onnx_opset import nn, math
+from tf2onnx.constants import NCHW_TO_NHWC, NHWC_TO_NCHW

 logger = logging.getLogger(__name__)

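For reference, the two permutation lists pulled in by the new `tf2onnx.constants` import are the standard 4-D layout permutations. Quoted from memory of `tf2onnx/constants.py`, so verify there:

```python
# Assumed values from tf2onnx/constants.py:
NHWC_TO_NCHW = [0, 3, 1, 2]  # perm taking (N, H, W, C) to (N, C, H, W)
NCHW_TO_NHWC = [0, 2, 3, 1]  # perm taking (N, C, H, W) to (N, H, W, C)
```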
@@ -1392,9 +1393,8 @@ def any_version(cls, opset, ctx, node, **kwargs):

         # if 3d or 4d tensor & square 2d block_shape, can optimize
         cond1 = xlen in [3, 4]
-        cond2 = node.inputs[2].is_const()
-        cond3 = blocklen == 2 and block_shape[0] == block_shape[1]
-        if cond1 and cond2 and cond3:
+        cond2 = blocklen == 2 and block_shape[0] == block_shape[1]
+        if cond1 and cond2:
             # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d.html
             # the above link says the data format of input tensor should be (batch, spatial_shape, remaining_shape)
             # and we only support 3D and 4D here, and the data format is NHC and NHWC
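Why this optimized path is legitimate: for a square 2-D block, `BatchToSpaceND` on an NHWC tensor is equivalent to permuting to CNHW, running `DepthToSpace` (DCR mode) on that tensor, permuting back, and cropping. A minimal NumPy sketch of the zero-crop case (shapes and values are illustrative, not from the diff):

```python
import numpy as np

b, n, H, W, C = 2, 1, 2, 3, 4                    # block_shape = [b, b]
x = np.arange(b * b * n * H * W * C, dtype=np.float32).reshape(b * b * n, H, W, C)

# Reference semantics of tf.batch_to_space_nd with zero crops:
# split the batch into (b, b, n), then interleave the block into H and W.
ref = (x.reshape(b, b, n, H, W, C)
        .transpose(2, 3, 0, 4, 1, 5)             # (n, H, b, W, b, C)
        .reshape(n, H * b, W * b, C))

# Converter's route: NHWC -> CNHW, DepthToSpace (DCR), CNHW -> NHWC.
cnhw = x.transpose(3, 0, 1, 2)                   # (C, b*b*n, H, W)
d2s = (cnhw.reshape(C, b, b, n, H, W)            # DCR DepthToSpace, blocksize b
           .transpose(0, 3, 4, 1, 5, 2)          # (C, n, H, b, W, b)
           .reshape(C, n, H * b, W * b))
out = d2s.transpose(1, 2, 3, 0)                  # back to NHWC
assert np.array_equal(out, ref)
```

The permute to CNHW makes `DepthToSpace` redistribute the batch dimension (sitting in the channel slot) into the spatial dimensions, which is exactly what TensorFlow's op does.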
@@ -1403,47 +1403,81 @@ def any_version(cls, opset, ctx, node, **kwargs):
             # T out = BatchToSpaceND(T input, int32 block_shape, int32 crops)
             input_tensor = node.inputs[0]
             input_shape = ctx.get_shape(input_tensor.output[0])
-            crops = node.inputs[2].get_tensor_value()

-            # NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
             if len(input_shape) == 3:
                 # insert automatically an Unsqueeze op if the input is 3d
                 unsqz1 = GraphBuilder(ctx).make_unsqueeze(
                     {"axes": [3], "data": input_tensor.output[0]}, return_node=True)
+                # NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
                 trans1 = ctx.make_node("Transpose", unsqz1.output, {"perm": [3, 0, 1, 2]})
             else:
-                trans1 = ctx.make_node("Transpose", input_tensor.output, {"perm": [3, 0, 1, 2]})
+                # Add an explicit NHWC_TO_NCHW transpose before and an NCHW_TO_NHWC transpose after the subgraph.
+                # That enables more optimizations in TransposeOptimizer.
+                trans_nchw = ctx.make_node("Transpose", input_tensor.output, {"perm": NHWC_TO_NCHW})
+                # NCHW TO CNHW
+                trans1 = ctx.make_node("Transpose", trans_nchw.output, {"perm": [1, 0, 2, 3]})
             reorganize_node = ctx.make_node(node.type, trans1.output, attr={"blocksize": block_shape[0]})
-            trans2 = ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]})
-
-            # implement crop logic, the data format is NHWC
-            slice_axis = [1, 2]
-            top, bottom = crops[0]
-            left, right = crops[1]
-            starts = [top, left]
-            ends = []
-            for end in [bottom, right]:
-                if end != 0:
-                    ends.append(-end)
-                else:
-                    ends.append(np.iinfo(np.int32).max)

-            attr = {"axes": slice_axis, "ends": ends, "starts": starts}
-            inputs_map = {"data": trans2.output[0], **attr}
+            # implement crop logic, the data format is NCHW
+            slice_axis = [2, 3]
+            if node.inputs[2].is_const():
+                crops = node.inputs[2].get_tensor_value()
+                top, bottom = crops[0]
+                left, right = crops[1]
+                starts = [top, left]
+                ends = []
+                for end in [bottom, right]:
+                    if end != 0:
+                        ends.append(-end)
+                    else:
+                        ends.append(np.iinfo(np.int32).max)
+                attr = {"axes": slice_axis, "ends": ends, "starts": starts}
+            else:
+                shape = ctx.make_const(name=utils.make_name("shape"), np_val=np.array([-1], dtype=np.int64))
+                reshape = ctx.make_node("Cast",
+                                        ctx.make_node("Reshape", inputs=[node.input[2], shape.output[0]]).output,
+                                        attr={"to": utils.map_numpy_to_onnx_dtype(np.int64)})
+                crops = ctx.make_node("Split", inputs=reshape.output, attr={}, output_count=4).output
+                zero = ctx.make_const(name=utils.make_name("zero"), np_val=np.array([0], dtype=np.int64)).output[0]
+                int32_max = ctx.make_const(name=utils.make_name("int32_max"),
+                                           np_val=np.array([np.iinfo(np.int32).max], dtype=np.int64)).output[0]
+                def crop_to_end(crop):
+                    eq = ctx.make_node("Equal", [crop, zero])
+                    not_eq = ctx.make_node("Not", eq.output)
+                    cast_eq = ctx.make_node("Cast", eq.output, attr={"to": utils.map_numpy_to_onnx_dtype(np.int64)})
+                    cast_not_eq = ctx.make_node("Cast", not_eq.output,
+                                                attr={"to": utils.map_numpy_to_onnx_dtype(np.int64)})
+                    neg = ctx.make_node("Neg", cast_not_eq.output)
+                    add = ctx.make_node("Add",
+                                        [
+                                            ctx.make_node("Mul", [crop, neg.output[0]]).output[0],
+                                            ctx.make_node("Mul", [int32_max, cast_eq.output[0]]).output[0],
+                                        ])
+                    return add.output[0]
+
+                starts = ctx.make_node("Concat", [crops[0], crops[2]], {'axis': 0})
+                ends = ctx.make_node("Concat", [crop_to_end(crops[1]), crop_to_end(crops[3])], {'axis': 0})
+                axes = ctx.make_const(name=utils.make_name("axes"), np_val=np.array(slice_axis, dtype=np.int64))
+                attr = {"axes": axes.output[0], "ends": ends.output[0], "starts": starts.output[0]}
+            inputs_map = {"data": reorganize_node.output[0], **attr}
             dtypes = node.output_dtypes
             shapes = node.output_shapes

+            ctx.remove_node(node.name)
             if len(input_shape) == 3:
                 # add a squeeze op to convert output into 3d
                 kwargs = {**inputs_map}
-                ctx.remove_node(node.name)
-                slice1 = GraphBuilder(ctx).make_slice(kwargs)
-                GraphBuilder(ctx).make_squeeze(
-                    {"axes": [3], "data": slice1, "outputs": node.output}, name=node.name, dtypes=dtypes, shapes=shapes)
+                node_slice = GraphBuilder(ctx).make_slice(kwargs)
+                # CNHW TO NHWC
+                trans2 = ctx.make_node("Transpose", [node_slice], {"perm": [1, 2, 3, 0]})
+                GraphBuilder(ctx).make_squeeze({"axes": [3], "data": trans2.output[0], "outputs": node.output},
+                                               name=node.name, shapes=shapes, dtypes=dtypes)
             else:
-                kwargs = {**inputs_map, "outputs": node.output}
-                ctx.remove_node(node.name)
-                GraphBuilder(ctx).make_slice(kwargs, name=node.name, dtypes=dtypes, shapes=shapes)
+                node_slice = GraphBuilder(ctx).make_slice(inputs_map)
+                # CNHW TO NCHW
+                trans2 = ctx.make_node("Transpose", [node_slice], {"perm": [1, 0, 2, 3]})
+                ctx.make_node("Transpose", trans2.output, {"perm": NCHW_TO_NHWC},
+                              name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
         else:
             def mknode(optype, inputs, attrs=None):
                 nodename = utils.make_name(node.name + '_' + optype.lower())
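A note on the `crop_to_end` helper added above: for each cropped axis it computes the Slice `end` value as `-crop` when `crop != 0` and `INT32_MAX` (meaning "slice to the end") when `crop == 0`, selecting between the two with 0/1 masks instead of a data-dependent branch. A NumPy restatement of the same arithmetic, just to document the identity (illustrative only, not from the diff):

```python
import numpy as np

INT32_MAX = np.iinfo(np.int32).max

def crop_to_end_reference(crop: np.ndarray) -> np.ndarray:
    # Mirrors the ONNX subgraph: Cast(Equal) and Cast(Not(Equal)) act as
    # 0/1 masks, and Neg/Mul/Add select between -crop and INT32_MAX.
    is_zero = (crop == 0).astype(np.int64)        # Cast(Equal(crop, 0))
    non_zero = (crop != 0).astype(np.int64)       # Cast(Not(Equal(crop, 0)))
    return crop * -non_zero + INT32_MAX * is_zero

assert np.array_equal(crop_to_end_reference(np.array([3])), [-3])
assert np.array_equal(crop_to_end_reference(np.array([0])), [INT32_MAX])
```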
@@ -1545,7 +1579,10 @@ def version_1(cls, ctx, node, **kwargs):

         # if 3d or 4d tensor & square 2d block_shape, can optimize
         cond1 = xlen in [3, 4]
-        cond2 = node.inputs[2].is_const()
+        # with opset 11 (or above), we can deal with non-const pads
+        # by creating a subgraph with Split and Concat and passing its
+        # output to Pad's second input
+        cond2 = node.inputs[2].is_const() or ctx.opset >= 11
         cond3 = blocklen == 2 and block_shape[0] == block_shape[1]
         if cond1 and cond2 and cond3:
             # https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd
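Background for the relaxed `cond2`: through opset 10, ONNX `Pad` takes `pads` as a static attribute, so padding amounts had to be known at conversion time; opset 11 moved `pads` to a second input, which is what lets the next hunk feed it from a Split/Concat subgraph. A sketch of the two node forms using `onnx.helper` (not part of the diff; names are placeholders):

```python
from onnx import helper

# Opset <= 10: pads is an attribute and must be constant at conversion time.
pad_v10 = helper.make_node("Pad", inputs=["x"], outputs=["y"],
                           pads=[0, 0, 1, 3, 0, 0, 2, 4])

# Opset >= 11: pads is a runtime input, so any subgraph may produce it.
pad_v11 = helper.make_node("Pad", inputs=["x", "pads"], outputs=["y"])
```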
@@ -1555,29 +1592,60 @@ def version_1(cls, ctx, node, **kwargs):
             # and it only supports NCHW
             # T out = SpaceToBatchND(T input, int32 block_shape, int32 crops)
             input_tensor = node.inputs[0]
+            input_shape = ctx.get_shape(input_tensor.output[0])
             shapes = [ctx.get_shape(node.output[0])]
             dtypes = [ctx.get_dtype(node.output[0])]

-            # implement pads logic, the data format is NHWC
-            paddings = node.inputs[2].get_tensor_value()
-            top, bottom = paddings[0]
-            left, right = paddings[1]
-            pads = [0, top, left, 0,
-                    0, bottom, right, 0]
-            ctx.remove_node(node.name)
-            if ctx.opset <= 10:
-                pad_op = ctx.make_node("Pad", input_tensor.output, attr={"pads": pads})
+            if len(input_shape) == 3:
+                # insert automatically an Unsqueeze op if the input is 3d
+                unsqz1 = GraphBuilder(ctx).make_unsqueeze(
+                    {"axes": [3], "data": input_tensor.output[0]}, return_node=True)
+                # NHWC TO CNHW
+                trans1 = ctx.make_node("Transpose", unsqz1.output, {"perm": [3, 0, 1, 2]})
+            else:
+                # Add an explicit NHWC_TO_NCHW transpose before and an NCHW_TO_NHWC transpose after the subgraph.
+                # That enables more optimizations in TransposeOptimizer.
+                trans_nchw = ctx.make_node("Transpose", input_tensor.output, {"perm": NHWC_TO_NCHW})
+                # NCHW TO CNHW
+                trans1 = ctx.make_node("Transpose", trans_nchw.output, {"perm": [1, 0, 2, 3]})
+            # implement pads logic, the data format is NCHW
+            if ctx.opset <= 10 or node.inputs[2].is_const():
+                paddings = node.inputs[2].get_tensor_value()
+                top, bottom = paddings[0]
+                left, right = paddings[1]
+                pads = [0, 0, top, left,
+                        0, 0, bottom, right]
+                if ctx.opset <= 10:
+                    pad_op = ctx.make_node("Pad", trans1.output, attr={"pads": pads})
+                else:
+                    new_pads = ctx.make_const(name=utils.make_name("pads"), np_val=np.array(pads, dtype=np.int64))
+                    pad_op = ctx.make_node("Pad", [trans1.output[0], new_pads.output[0]])
             else:
                 # TODO: we should be able to support dynamic input here.
-                pads_name = utils.make_name(node.name)
-                ctx.make_const(name=pads_name, np_val=np.array(pads, dtype=np.int64))
-                pad_op = ctx.make_node("Pad", [input_tensor.output[0], pads_name])
+                shape = ctx.make_const(name=utils.make_name("shape"), np_val=np.array([-1], dtype=np.int64))
+                reshape = ctx.make_node("Reshape", inputs=[node.input[2], shape.output[0]])
+                cast = ctx.make_node("Cast", reshape.output, attr={'to': utils.map_numpy_to_onnx_dtype(np.int64)})
+                split = ctx.make_node("Split", inputs=cast.output, attr={}, output_count=4)
+                pads = split.output
+                zero = ctx.make_const(name=utils.make_name("zero"), np_val=np.array([0], dtype=np.int64)).output[0]
+                new_pads = ctx.make_node("Concat", [zero, zero, pads[0], pads[2], zero, zero, pads[1], pads[3]],
+                                         {'axis': 0})
+                pad_op = ctx.make_node("Pad", [trans1.output[0], new_pads.output[0]])
+
+            reorganize_node = ctx.make_node(node.type, pad_op.output, attr={"blocksize": block_shape[0]})

-            # NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
-            trans1 = ctx.make_node("Transpose", pad_op.output, {"perm": [3, 0, 1, 2]})
-            reorganize_node = ctx.make_node(node.type, trans1.output, attr={"blocksize": block_shape[0]})
-            ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]},
-                          name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
+            ctx.remove_node(node.name)
+            if len(input_shape) == 3:
+                # CNHW TO NHWC
+                trans2 = ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]})
+                # add a squeeze op to convert output into 3d
+                GraphBuilder(ctx).make_squeeze({"axes": [3], "data": trans2.output[0], "outputs": node.output},
+                                               name=node.name, shapes=shapes, dtypes=dtypes)
+            else:
+                # CNHW TO NCHW
+                trans2 = ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 0, 2, 3]})
+                ctx.make_node("Transpose", trans2.output, {"perm": NCHW_TO_NHWC},
+                              name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
         else:
             def mknode(optype, inputs, attrs=None):
                 nodename = utils.make_name(node.name + '_' + optype.lower())
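The dynamic-pads branch above only rearranges TF's `paddings` tensor `[[top, bottom], [left, right]]` into ONNX's layout (all per-axis begin pads, then all per-axis end pads) for an NCHW input. A NumPy mirror of the Reshape → Cast → Split → Concat chain, with hypothetical example values:

```python
import numpy as np

paddings = np.array([[1, 2], [3, 4]], dtype=np.int32)  # [[top, bottom], [left, right]]

flat = paddings.reshape(-1).astype(np.int64)           # Reshape(-1) + Cast -> [top, bottom, left, right]
top, bottom, left, right = np.split(flat, 4)           # Split into four 1-element tensors
zero = np.zeros(1, dtype=np.int64)
new_pads = np.concatenate([zero, zero, top, left,      # begin pads for N, C, H, W
                           zero, zero, bottom, right]) # end pads for N, C, H, W
assert new_pads.tolist() == [0, 0, 1, 3, 0, 0, 2, 4]
```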