@@ -192,6 +192,33 @@ def conv_kernel_shape(ctx, node, input_idx, spatial=2):
     return kernel_shape


+def build_dynamic_target_size(ctx, transposed_input, target_hw):
+    """
+    Build the target size tensor for the Resize op.
+
+    Args:
+        - ctx: the graph context
+        - transposed_input: a tensor of rank 4 with shape [n c h w]
+        - target_hw: a 1-D tensor of size 2 holding the target size for a resize: [nh nw]
+
+    Returns:
+        A 1-D tensor of size 4 containing [n c nh nw]
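+
+    Example (illustrative):
+        an input of shape [1, 3, 32, 32] with target_hw holding [64, 64]
+        yields a tensor containing [1, 3, 64, 64].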
+    """
+    # Take the first half [n c] of the input shape.
+    shape_of_transposed_input = ctx.make_node("Shape", [transposed_input.output[0]])
+    first_half_of_shape = GraphBuilder(ctx).make_slice(
+        {"data": shape_of_transposed_input.output[0], "ends": [2], "starts": [0]})
+    target_size_int64 = ctx.make_node("Cast", [target_hw.output[0]], attr={'to': TensorProto.INT64})
+    # Build a tensor containing [n c nh nw].
+    final_target_size = ctx.make_node("Concat", [first_half_of_shape, target_size_int64.output[0]], {'axis': 0})
+    return final_target_size
+
+
 @tf_op(["Conv1D", "Conv2D", "Conv3D"])
 class ConvOp:
     @classmethod
@@ -594,15 +621,12 @@ def version_11(cls, ctx, node, **kwargs):
         target_x = g.make_node("Slice", [input_x.output[0], box_index_from.output[0], box_index_to.output[0],
                                          const_zero.output[0]], name="Slice_b")
         transposed_x = g.make_node("Transpose", [target_x.output[0]], attr={'perm': constants.NHWC_TO_NCHW})
-        shape_of_transposed_x = g.make_node("Shape", [transposed_x.output[0]])
         const_zero_zero = g.make_const(utils.make_name(node.name + "_const_zero_zero"),
                                        np.array([0, 0], dtype=np.float32))
         const_one_one = g.make_const(utils.make_name(node.name + "_const_one_one"),
                                      np.array([1, 1], dtype=np.float32))
         const_four = g.make_const(utils.make_name(node.name + "_const_four"), np.array([4], dtype=np.int64))
         const_empty_float = g.make_const(utils.make_name("const_empty_float"), np.array([], dtype=np.float32))
-        first_half_of_shape = GraphBuilder(g).make_slice(
-            {"data": shape_of_transposed_x.output[0], "ends": [2], "starts": [0]})
         box = g.make_node("Slice", [boxes.output[0], trip_name, index_end.output[0], const_zero_long.output[0]],
                           name="Slice_c")
         roi_raw = g.make_node("Reshape", [box.output[0], const_four.output[0]])
@@ -611,8 +635,8 @@ def version_11(cls, ctx, node, **kwargs):
         roi_concat_1 = g.make_node("Concat", [const_zero_zero.output[0], roi_raw_first_half], attr={'axis': 0})
         roi_concat_2 = g.make_node("Concat", [const_one_one.output[0], roi_raw_second_half], attr={'axis': 0})
         final_roi = g.make_node("Concat", [roi_concat_1.output[0], roi_concat_2.output[0]], attr={'axis': 0})
-        crop_size_int64 = g.make_node("Cast", [crop_size.output[0]], attr={'to': TensorProto.INT64})
-        final_crop_size = g.make_node("Concat", [first_half_of_shape, crop_size_int64.output[0]], {'axis': 0})
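+        # Compute [n c nh nw] at runtime from the transposed input and the requested crop size.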
+        final_crop_size = build_dynamic_target_size(g, transposed_x, crop_size)
         resized_x = g.make_node("Resize", [transposed_x.output[0], final_roi.output[0], const_empty_float.output[0],
                                            final_crop_size.output[0]],
                                 attr={"mode": mode, "extrapolation_value": extrapolation_value,
@@ -661,50 +685,56 @@ def version_10(cls, ctx, node, **kwargs):
 
     @classmethod
     def version_11(cls, ctx, node, **kwargs):
-        cls._convert_since_9(ctx, node, op_type="Resize", roi_required=True)
+        cls._convert_since_9(ctx, node, op_type="Resize", use_target_size=True)
 
     @classmethod
-    def _convert_since_9(cls, ctx, node, op_type, roi_required=False):
+    def _convert_since_9(cls, ctx, node, op_type, use_target_size=False):
 
         # float32 out = ResizeBilinear/ResizeNearestNeighbor(T images, int size)
         # https://www.tensorflow.org/api_docs/python/tf/image/resize_nearest_neighbor
         # wants the input to be NHWC - adjust target_shape to this.
         mode = "linear" if node.type == "ResizeBilinear" else "nearest"
 
-        # first create "scales" info for onnx upsample
-        # if shape of input and output known then "scale" is calculated statically and set as a const node
-        shape = ctx.get_shape(node.input[0])
-        if shape and shape[2] != -1 and shape[1] != -1 and node.inputs[1].is_const():
-            target_shape = node.inputs[1].get_tensor_value()
-            n, h, w, c = shape
-            nh, nw = target_shape
-            # scales is nchw
-            # the reason not storing data at raw field is because of the bug: https://github.com/onnx/onnx/issues/1852
-            scale_val = np.array([1.0, 1.0, float(nh) / h, float(nw) / w]).astype(np.float32)
-            scales = ctx.make_const(utils.make_name("scales"), scale_val, raw=False)
-        else:
-            ori_shape = ctx.make_node("Shape", [node.input[0]])
-            attr = {"axes": [0], "starts": [1], "ends": [3]}
-            inputs_map = {"data": ori_shape.output[0], **attr}
-            ori_shape_hw = GraphBuilder(ctx).make_slice(inputs_map)
-            ori_shape_hw_float = ctx.make_node("Cast", [ori_shape_hw], attr={"to": onnx_pb.TensorProto.FLOAT})
-
-            target_hw = node.inputs[1]
-            target_hw_float = ctx.make_node("Cast", target_hw.output, attr={"to": onnx_pb.TensorProto.FLOAT})
-
-            scales_hw = ctx.make_node("Div", [target_hw_float.output[0], ori_shape_hw_float.output[0]])
-
-            const_one_array = ctx.make_const(utils.make_name("one"), np.array([1.0, 1.0]).astype(np.float32))
-            # scales is nchw
-            scales = ctx.make_node("Concat", [const_one_array.output[0], scales_hw.output[0]], {"axis": 0})
         # because onnxruntime only supports to scale the last two dims so transpose is inserted
         input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": constants.NHWC_TO_NCHW})
-        if roi_required:
+        if use_target_size:
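+            # Build [n c nh nw] at runtime so the conversion also works when the input shape is not fully known.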
+            final_target_size = build_dynamic_target_size(ctx, input_nchw, node.inputs[1])
             roi = ctx.make_const(utils.make_name("roi"), np.array([]).astype(np.float32))
-            upsample = ctx.make_node("Resize", [input_nchw.output[0], roi.output[0], scales.output[0]],
+            const_empty_float = ctx.make_const(utils.make_name("const_empty_float"), np.array([], dtype=np.float32))
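+            # Resize-11 takes (X, roi, scales, sizes); scales is passed as an empty
+            # tensor so the "sizes" input determines the output shape.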
+            upsample = ctx.make_node("Resize", [input_nchw.output[0], roi.output[0], const_empty_float.output[0], final_target_size.output[0]],
                                     attr={"mode": mode, "nearest_mode": "floor",
                                           "coordinate_transformation_mode": "asymmetric"})
         else:
+            # first create "scales" info for onnx upsample
+            # if shape of input and output known then "scale" is calculated statically and set as a const node
+            shape = ctx.get_shape(node.input[0])
+            if shape and shape[2] != -1 and shape[1] != -1 and node.inputs[1].is_const():
+                target_shape = node.inputs[1].get_tensor_value()
+                n, h, w, c = shape
+                nh, nw = target_shape
+                # scales is nchw
+                # the reason not storing data at raw field is because of the bug: https://github.com/onnx/onnx/issues/1852
+                scale_val = np.array([1.0, 1.0, float(nh) / h, float(nw) / w]).astype(np.float32)
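+                # e.g. (illustrative) resizing h, w = 100, 200 to nh, nw = 50, 100 gives scales [1.0, 1.0, 0.5, 0.5]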
+                scales = ctx.make_const(utils.make_name("scales"), scale_val, raw=False)
+            else:
+                ori_shape = ctx.make_node("Shape", [node.input[0]])
+                attr = {"axes": [0], "starts": [1], "ends": [3]}
+                inputs_map = {"data": ori_shape.output[0], **attr}
+                ori_shape_hw = GraphBuilder(ctx).make_slice(inputs_map)
+                ori_shape_hw_float = ctx.make_node("Cast", [ori_shape_hw], attr={"to": onnx_pb.TensorProto.FLOAT})
+
+                target_hw = node.inputs[1]
+                target_hw_float = ctx.make_node("Cast", target_hw.output, attr={"to": onnx_pb.TensorProto.FLOAT})
+
+                scales_hw = ctx.make_node("Div", [target_hw_float.output[0], ori_shape_hw_float.output[0]])
+
+                const_one_array = ctx.make_const(utils.make_name("one"), np.array([1.0, 1.0]).astype(np.float32))
+                # scales is nchw
+                scales = ctx.make_node("Concat", [const_one_array.output[0], scales_hw.output[0]], {"axis": 0})
             upsample = ctx.make_node(op_type, [input_nchw.output[0], scales.output[0]], attr={"mode": mode})
 
         shapes = node.output_shapes