@@ -441,6 +441,80 @@ def version_7(cls, ctx, node, **kwargs):
441
441
conv_convert_inputs (ctx , node , with_kernel = False )
442
442
443
443
444
+ @tf_op ("BatchToSpaceND" , onnx_op = "DepthToSpace" )
445
+ class BatchToSpace :
446
+ @classmethod
447
+ def version_4 (cls , ctx , node , ** kwargs ):
448
+ # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d.html
449
+ # the above link says the data format of input tensor should be (batch, spatial_shape, remaining_shape)
450
+ # and we only support 4D here, so the data format is NHWC
451
+ # onnx op "DepthToSpace" does the same work on input tensor except that it works on "C",
452
+ # and it only supports NCHW
453
+ # T out = BatchToSpaceND(T input, int32 block_shape, int32 crops)
454
+ input_tensor = node .inputs [0 ]
455
+ blocksize = node .inputs [1 ].get_tensor_value ()
456
+ crops = node .inputs [2 ].get_tensor_value ()
457
+
458
+ utils .make_sure (len (ctx .get_shape (input_tensor .output [0 ])) == 4 , "only supports 4D for now" )
459
+ utils .make_sure (len (blocksize ) == 2 and blocksize [0 ] == blocksize [1 ],
460
+ "only support same blocksize at different dims" )
461
+
462
+ ctx .remove_node (node .name )
463
+ # NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
464
+ trans1 = ctx .make_node ("Transpose" , input_tensor .output , {"perm" : [3 , 0 , 1 , 2 ]})
465
+ reorganize_node = ctx .make_node (node .type , trans1 .output , attr = {"blocksize" : blocksize [0 ]})
466
+ trans2 = ctx .make_node ("Transpose" , reorganize_node .output , {"perm" : [1 , 2 , 3 , 0 ]})
467
+
468
+ # implement crop logic, the data format is NHWC
469
+ slice_axis = [1 , 2 ]
470
+ top , bottom = crops [0 ]
471
+ left , right = crops [1 ]
472
+ starts = [top , left ]
473
+ ends = []
474
+ for end in [bottom , right ]:
475
+ if end != 0 :
476
+ ends .append (- end )
477
+ else :
478
+ ends .append (np .iinfo (np .int32 ).max )
479
+
480
+ ctx .make_node ("Slice" , trans2 .output , attr = {"axes" : slice_axis , "ends" : ends , "starts" : starts },
481
+ name = node .name , outputs = node .output )
482
+
483
+
484
+ @tf_op ("SpaceToBatchND" , onnx_op = "SpaceToDepth" )
485
+ class SpaceToBatch :
486
+ @classmethod
487
+ def version_4 (cls , ctx , node , ** kwargs ):
488
+ # https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd
489
+ # the above link says the data format of input tensor should be (batch, spatial_shape, remaining_shape)
490
+ # and we only support 4D here, so the data format is NHWC
491
+ # onnx op "SpaceToDepth" does the same work on input tensor except that it works on "C",
492
+ # and it only supports NCHW
493
+ # T out = SpaceToBatchND(T input, int32 block_shape, int32 crops)
494
+ input_tensor = node .inputs [0 ]
495
+ blocksize = node .inputs [1 ].get_tensor_value ()
496
+ paddings = node .inputs [2 ].get_tensor_value ()
497
+
498
+ utils .make_sure (len (ctx .get_shape (input_tensor .output [0 ])) == 4 , "only supports 4D for now" )
499
+ utils .make_sure (len (blocksize ) == 2 and blocksize [0 ] == blocksize [1 ],
500
+ "only support same blocksize at different dims" )
501
+
502
+ ctx .remove_node (node .name )
503
+
504
+ # implement pads logic, the data format is NHWC
505
+ top , bottom = paddings [0 ]
506
+ left , right = paddings [1 ]
507
+ pads = [0 , top , left , 0 ,
508
+ 0 , bottom , right , 0 ]
509
+
510
+ pad_op = ctx .make_node ("Pad" , input_tensor .output , attr = {"pads" : pads })
511
+
512
+ # NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
513
+ trans1 = ctx .make_node ("Transpose" , pad_op .output , {"perm" : [3 , 0 , 1 , 2 ]})
514
+ reorganize_node = ctx .make_node (node .type , trans1 .output , attr = {"blocksize" : blocksize [0 ]})
515
+ ctx .make_node ("Transpose" , reorganize_node .output , {"perm" : [1 , 2 , 3 , 0 ]}, name = node .name , outputs = node .output )
516
+
517
+
444
518
@tf_op (["ResizeBilinear" , "ResizeNearestNeighbor" ])
445
519
class ResizeX :
446
520
@classmethod
0 commit comments