@@ -587,6 +587,67 @@ def version_11(cls, ctx, node, **kwargs):
587
587
cls .version_1 (ctx , node , ** kwargs )
588
588
589
589
590
+ @tf_op (["CropAndResize" ])
591
+ class CropAndResize :
592
+ @classmethod
593
+ def version_11 (cls , ctx , node , ** kwargs ):
594
+ # create loop of resize to cater to tensorflow CropAndResize, one box one iteration
595
+ mode = "nearest" if node .get_attr ("method" ) is not None and node .get_attr (
596
+ "method" ).s == b"nearest" else "linear"
597
+ extrapolation_value = float (node .get_attr ("extrapolation_value" , "0" ).f )
598
+ input_x = node .inputs [0 ]
599
+ boxes = node .inputs [1 ]
600
+ box_ind = node .inputs [2 ]
601
+ crop_size = node .inputs [3 ]
602
+ trip_name = utils .make_name (node .name + "_i" )
603
+ cond_name = utils .make_name (node .name + "_cond" )
604
+ cond_out_name = utils .make_name (node .name + "cond_out" )
605
+ g = ctx .create_new_graph_with_same_config ()
606
+ g .add_graph_input (trip_name , TensorProto .INT64 , [1 ])
607
+ g .add_graph_input (cond_name , TensorProto .BOOL , [])
608
+ g .parent_graph = ctx
609
+ const_zero = g .make_const (utils .make_name (node .name + "_const_zero" ), np .array ([0 ], dtype = np .int32 ))
610
+ const_zero_long = g .make_const (utils .make_name (node .name + "_const_zero_long" ), np .array ([0 ], dtype = np .int64 ))
611
+ const_one = g .make_const (utils .make_name (node .name + "_const_one" ), np .array ([1 ], dtype = np .int32 ))
612
+ const_one_long = g .make_const (utils .make_name (node .name + "_const_one_long" ), np .array ([1 ], dtype = np .int64 ))
613
+ index_end = g .make_node ("Add" , [trip_name , const_one_long .output [0 ]])
614
+ box_index_from = g .make_node ("Slice" , [box_ind .output [0 ], trip_name , index_end .output [0 ]], name = "Slice_a" )
615
+ box_index_to = g .make_node ("Add" , [box_index_from .output [0 ], const_one .output [0 ]])
616
+ target_x = g .make_node ("Slice" , [input_x .output [0 ], box_index_from .output [0 ], box_index_to .output [0 ],
617
+ const_zero .output [0 ]], name = "Slice_b" )
618
+ transposed_x = g .make_node ("Transpose" , [target_x .output [0 ]], attr = {'perm' : constants .NHWC_TO_NCHW })
619
+ const_zero_zero = g .make_const (utils .make_name (node .name + "_const_zero_zero" ),
620
+ np .array ([0 , 0 ], dtype = np .float32 ))
621
+ const_one_one = g .make_const (utils .make_name (node .name + "_const_one_one" ),
622
+ np .array ([1 , 1 ], dtype = np .float32 ))
623
+ const_four = g .make_const (utils .make_name (node .name + "_const_four" ), np .array ([4 ], dtype = np .int64 ))
624
+ const_empty_float = g .make_const (utils .make_name ("const_empty_float" ), np .array ([], dtype = np .float32 ))
625
+ box = g .make_node ("Slice" , [boxes .output [0 ], trip_name , index_end .output [0 ], const_zero_long .output [0 ]],
626
+ name = "Slice_c" )
627
+ roi_raw = g .make_node ("Reshape" , [box .output [0 ], const_four .output [0 ]])
628
+ roi_raw_first_half = GraphBuilder (g ).make_slice ({"data" : roi_raw .output [0 ], "ends" : [2 ], "starts" : [0 ]})
629
+ roi_raw_second_half = GraphBuilder (g ).make_slice ({"data" : roi_raw .output [0 ], "ends" : [4 ], "starts" : [2 ]})
630
+ roi_concat_1 = g .make_node ("Concat" , [const_zero_zero .output [0 ], roi_raw_first_half ], attr = {'axis' : 0 })
631
+ roi_concat_2 = g .make_node ("Concat" , [const_one_one .output [0 ], roi_raw_second_half ], attr = {'axis' : 0 })
632
+ final_roi = g .make_node ("Concat" , [roi_concat_1 .output [0 ], roi_concat_2 .output [0 ]], attr = {'axis' : 0 })
633
+ final_crop_size = build_dynamic_target_size (g , transposed_x , crop_size )
634
+ resized_x = g .make_node ("Resize" , [transposed_x .output [0 ], final_roi .output [0 ], const_empty_float .output [0 ],
635
+ final_crop_size .output [0 ]],
636
+ attr = {"mode" : mode , "extrapolation_value" : extrapolation_value ,
637
+ "coordinate_transformation_mode" : "tf_crop_and_resize" })
638
+ recovered_x = g .make_node ("Transpose" , [resized_x .output [0 ]], attr = {'perm' : constants .NCHW_TO_NHWC })
639
+ squeeze_x = g .make_node ("Squeeze" , inputs = [recovered_x .output [0 ]], attr = {"axes" : [0 ]})
640
+ g .make_node ("Identity" , [cond_name ], outputs = [cond_out_name ])
641
+ g .add_graph_output (cond_out_name , TensorProto .BOOL , [])
642
+ g .add_graph_output (squeeze_x .output [0 ], ctx .get_dtype (node .input [0 ]), [- 1 , - 1 , - 1 ])
643
+ trip_node = ctx .make_node ("Size" , [box_ind .output [0 ]])
644
+ cond_const = ctx .make_const (utils .make_name ("cond" ), np .ones ((), dtype = np .bool ))
645
+ ctx .remove_node (node .name )
646
+ inner_loop = ctx .make_node ("Loop" , [trip_node .output [0 ], cond_const .output [0 ]], name = node .name ,
647
+ outputs = node .output )
648
+ inner_loop .set_body_graph_as_attr ("body" , g )
649
+
650
+
590
651
@tf_op (["ResizeBilinear" , "ResizeNearestNeighbor" ])
591
652
class Resize :
592
653
@classmethod
0 commit comments