@@ -589,6 +589,48 @@ def version_11(cls, ctx, node, **kwargs):
589
589
590
590
@tf_op (["CropAndResize" ])
591
591
class CropAndResize :
592
+ @classmethod
593
+ def version_10 (cls , ctx , node , ** kwargs ):
594
+ utils .make_sure (node .inputs [1 ].type == "Const" , "boxes input must be a Const" )
595
+ utils .make_sure (node .inputs [3 ].type == "Const" , "boxes input must be a Const" )
596
+ name = node .name
597
+ output_height = node .inputs [3 ].get_tensor_value ()[0 ]
598
+ output_width = node .inputs [3 ].get_tensor_value ()[1 ]
599
+ rois = node .inputs [1 ].get_tensor_value ()
600
+ rois_shape = ctx .get_shape (node .input [1 ])
601
+ img_shape = ctx .get_shape (node .input [0 ])
602
+ transform_rois = np .zeros (list (rois_shape ), dtype = np .float32 )
603
+ for i in range (rois_shape [0 ]):
604
+ y1 , x1 , y2 , x2 = rois [i ]
605
+ y1 = y1 * (img_shape [1 ] - 1 )
606
+ y2 = y2 * (img_shape [1 ] - 1 )
607
+ x1 = x1 * (img_shape [2 ] - 1 )
608
+ x2 = x2 * (img_shape [2 ] - 1 )
609
+ spacing_h = (y2 - y1 )
610
+ spacing_w = (x2 - x1 )
611
+ b1 = y1 - 0.5 * spacing_h / (output_height - 1 )
612
+ a1 = x1 - 0.5 * spacing_w / (output_width - 1 )
613
+ b2 = y2 + 0.5 * spacing_h / (output_height - 1 )
614
+ a2 = x2 + 0.5 * spacing_w / (output_width - 1 )
615
+ transform_rois [i ][0 ] = a1
616
+ transform_rois [i ][1 ] = b1
617
+ transform_rois [i ][2 ] = a2
618
+ transform_rois [i ][3 ] = b2
619
+ cast_node = ctx .make_node ("Cast" , [node .input [2 ]], attr = {"to" : onnx_pb .TensorProto .INT64 })
620
+ bbox_node = ctx .make_const (utils .make_name ("bbox" ), transform_rois )
621
+ dtypes = [ctx .get_dtype (node .output [0 ])]
622
+ shapes = [ctx .get_shape (node .output [0 ])]
623
+ input_nchw = ctx .make_node ("Transpose" , [node .input [0 ]], {"perm" : [0 , 3 , 1 , 2 ]},
624
+ name = utils .make_name (node .name ))
625
+ crop_and_resize = ctx .make_node ("RoiAlign" , inputs = [input_nchw .output [0 ], bbox_node .output [0 ],
626
+ cast_node .output [0 ]],
627
+ attr = {"output_height" : output_height , "output_width" : output_width ,
628
+ "spatial_scale" : 1.0 , "sampling_ratio" : 1 },
629
+ name = utils .make_name (node .name ), dtypes = dtypes , shapes = shapes )
630
+ ctx .remove_node (name )
631
+ res = ctx .make_node ("Transpose" , crop_and_resize .output , {"perm" : [0 , 2 , 3 , 1 ]},
632
+ name = name , outputs = node .output , shapes = shapes , dtypes = dtypes )
633
+
592
634
@classmethod
593
635
def version_11 (cls , ctx , node , ** kwargs ):
594
636
# create loop of resize to cater to tensorflow CropAndResize, one box one iteration
0 commit comments