@@ -1452,7 +1452,8 @@ def version_10(cls, ctx, node, **kwargs):
1452
1452
raise ValueError ("dtype " + str (node_dtype ) + " is not supported in onnx for now" )
1453
1453
1454
1454
1455
- @tf_op (["NonMaxSuppressionV2" , "NonMaxSuppressionV3" ], onnx_op = "NonMaxSuppression" )
1455
+ @tf_op (["NonMaxSuppressionV2" , "NonMaxSuppressionV3" , "NonMaxSuppressionV4" , "NonMaxSuppressionV5" ],
1456
+ onnx_op = "NonMaxSuppression" )
1456
1457
class NonMaxSuppression :
1457
1458
@classmethod
1458
1459
def version_10 (cls , ctx , node , ** kwargs ):
@@ -1464,18 +1465,41 @@ def version_10(cls, ctx, node, **kwargs):
1464
1465
# onnx output is [num_selected_boxes, 3], the meaning of last dim is [batch_index, class_index, box_index]
1465
1466
# while tf's output is [num_selected_boxes]
1466
1467
ctx .insert_new_node_on_input (node , "Unsqueeze" , node .input [0 ], axes = [0 ])
1467
- ctx .insert_new_node_on_input (node , "Unsqueeze" , node .input [1 ], axes = [0 , 1 ])
1468
+ input_score = ctx .insert_new_node_on_input (node , "Unsqueeze" , node .input [1 ], axes = [0 , 1 ])
1468
1469
ctx .insert_new_node_on_input (node , "Cast" , node .input [2 ], to = onnx_pb .TensorProto .INT64 )
1469
1470
# replace original node with nonmaxsurppress + slice + squeeze + cast
1470
- dtypes = [ctx .get_dtype (node .output [0 ])]
1471
- shapes = [ctx .get_shape (node .output [0 ])]
1471
+ dtypes = [[ctx .get_dtype (output )] for output in node .output ]
1472
+ shapes = [[ctx .get_shape (output )] for output in node .output ]
1473
+ max_output_size = node .input [2 ]
1474
+ utils .make_sure (len (node .inputs ) <= 5 or int (node .inputs [5 ].get_tensor_value (False )) == 0 ,
1475
+ "soft_nms_sigma must be 0" )
1472
1476
ctx .remove_node (node .name )
1473
- new_nonmaxsurppress = ctx .make_node (node .type , node .input ).output [0 ]
1477
+ new_nonmaxsurppress = ctx .make_node (node .type , node .input [: 5 ] ).output [0 ]
1474
1478
slice_op = GraphBuilder (ctx ).make_slice ({"data" : new_nonmaxsurppress ,
1475
1479
"axes" : [1 ], "ends" : [3 ], "starts" : [2 ]})
1476
1480
squeeze_op = ctx .make_node ("Squeeze" , [slice_op ], attr = {"axes" : [1 ]})
1477
- ctx .make_node ("Cast" , inputs = squeeze_op .output , attr = {"to" : onnx_pb .TensorProto .INT32 },
1478
- name = node .name , outputs = node .output , dtypes = dtypes , shapes = shapes )
1481
+ if len (node .input ) > 5 : # v5, called by ..._with_scores(), pad_to_max_output_size always False
1482
+ ctx .make_node ("Cast" , inputs = squeeze_op .output , attr = {"to" : onnx_pb .TensorProto .INT32 },
1483
+ outputs = [node .output [0 ]], dtypes = dtypes [0 ], shapes = shapes [0 ])
1484
+ ctx .make_node ("Gather" , inputs = [input_score .input [0 ], squeeze_op .output [0 ]],
1485
+ outputs = [node .output [1 ]], dtypes = dtypes [1 ], shapes = shapes [1 ])
1486
+ elif "pad_to_max_output_size" in node .attr : # V4
1487
+ shape_op = ctx .make_node ("Shape" , inputs = [squeeze_op .output [0 ]])
1488
+ const_zero = ctx .make_const (utils .make_name ("const_zero" ), np .array ([0 ], dtype = np .int64 ))
1489
+ sub_op = ctx .make_node ("Sub" , inputs = [max_output_size , shape_op .output [0 ]])
1490
+ raw_pad = ctx .make_node ("Concat" , inputs = [const_zero .output [0 ], sub_op .output [0 ]], attr = {'axis' : 0 })
1491
+ raw_pad_float = ctx .make_node ("Cast" , inputs = [raw_pad .output [0 ]], attr = {"to" : onnx_pb .TensorProto .FLOAT })
1492
+ relu_op = ctx .make_node ("Relu" , inputs = [raw_pad_float .output [0 ]])
1493
+ pad_val = ctx .make_node ("Cast" , inputs = [relu_op .output [0 ]], attr = {"to" : onnx_pb .TensorProto .INT64 })
1494
+ pad_op = ctx .make_node ("Pad" , inputs = [squeeze_op .output [0 ], pad_val .output [0 ]])
1495
+ ctx .make_node ("Cast" , inputs = pad_op .output , name = "cast_A" , attr = {"to" : onnx_pb .TensorProto .INT32 },
1496
+ outputs = [node .output [0 ]], dtypes = dtypes [0 ], shapes = shapes [0 ])
1497
+ reduce_op = ctx .make_node ("ReduceSum" , inputs = shape_op .output , attr = {"axes" : [0 ], "keepdims" : 0 })
1498
+ ctx .make_node ("Cast" , inputs = [reduce_op .output [0 ]], name = "cast_B" , attr = {"to" : onnx_pb .TensorProto .INT32 },
1499
+ outputs = [node .output [1 ]], dtypes = dtypes [1 ], shapes = shapes [1 ])
1500
+ else :
1501
+ ctx .make_node ("Cast" , inputs = squeeze_op .output , attr = {"to" : onnx_pb .TensorProto .INT32 },
1502
+ name = node .name , outputs = node .output , dtypes = dtypes [0 ], shapes = shapes [0 ])
1479
1503
1480
1504
@classmethod
1481
1505
def version_11 (cls , ctx , node , ** kwargs ):
0 commit comments