@@ -1452,7 +1452,8 @@ def version_10(cls, ctx, node, **kwargs):
1452
1452
raise ValueError ("dtype " + str (node_dtype ) + " is not supported in onnx for now" )
1453
1453
1454
1454
1455
- @tf_op (["NonMaxSuppressionV2" , "NonMaxSuppressionV3" ], onnx_op = "NonMaxSuppression" )
1455
+ @tf_op (["NonMaxSuppressionV2" , "NonMaxSuppressionV3" , "NonMaxSuppressionV4" , "NonMaxSuppressionV5" ],
1456
+ onnx_op = "NonMaxSuppression" )
1456
1457
class NonMaxSuppression :
1457
1458
@classmethod
1458
1459
def version_10 (cls , ctx , node , ** kwargs ):
@@ -1464,18 +1465,40 @@ def version_10(cls, ctx, node, **kwargs):
1464
1465
# onnx output is [num_selected_boxes, 3], the meaning of last dim is [batch_index, class_index, box_index]
1465
1466
# while tf's output is [num_selected_boxes]
1466
1467
ctx .insert_new_node_on_input (node , "Unsqueeze" , node .input [0 ], axes = [0 ])
1467
- ctx .insert_new_node_on_input (node , "Unsqueeze" , node .input [1 ], axes = [0 , 1 ])
1468
+ input_score = ctx .insert_new_node_on_input (node , "Unsqueeze" , node .input [1 ], axes = [0 , 1 ])
1468
1469
ctx .insert_new_node_on_input (node , "Cast" , node .input [2 ], to = onnx_pb .TensorProto .INT64 )
1469
1470
# replace original node with nonmaxsurppress + slice + squeeze + cast
1470
- dtypes = [ctx .get_dtype (node .output [0 ])]
1471
- shapes = [ctx .get_shape (node .output [0 ])]
1471
+ dtypes = [[ctx .get_dtype (output )] for output in node .output ]
1472
+ shapes = [[ctx .get_shape (output )] for output in node .output ]
1473
+ max_output_size = node .input [2 ]
1472
1474
ctx .remove_node (node .name )
1473
- new_nonmaxsurppress = ctx .make_node (node .type , node .input ).output [0 ]
1475
+ new_nonmaxsurppress = ctx .make_node (node .type , node .input [: 5 ] ).output [0 ]
1474
1476
slice_op = GraphBuilder (ctx ).make_slice ({"data" : new_nonmaxsurppress ,
1475
1477
"axes" : [1 ], "ends" : [3 ], "starts" : [2 ]})
1476
1478
squeeze_op = ctx .make_node ("Squeeze" , [slice_op ], attr = {"axes" : [1 ]})
1477
- ctx .make_node ("Cast" , inputs = squeeze_op .output , attr = {"to" : onnx_pb .TensorProto .INT32 },
1478
- name = node .name , outputs = node .output , dtypes = dtypes , shapes = shapes )
1479
+ if len (node .input ) > 5 : # V5
1480
+ logger .warning ("NonMaxSuppressionV5 only parltially supported, soft_nms_sigma must be 0.0" )
1481
+ ctx .make_node ("Cast" , inputs = squeeze_op .output , attr = {"to" : onnx_pb .TensorProto .INT32 },
1482
+ outputs = [node .output [0 ]], dtypes = dtypes [0 ], shapes = shapes [0 ])
1483
+ ctx .make_node ("Gather" , inputs = [input_score .input [0 ], squeeze_op .output [0 ]],
1484
+ outputs = [node .output [1 ]], dtypes = dtypes [1 ], shapes = shapes [1 ])
1485
+ elif "pad_to_max_output_size" in node .attr : # V4
1486
+ shape_op = ctx .make_node ("Shape" , inputs = [squeeze_op .output [0 ]])
1487
+ const_zero = ctx .make_const (utils .make_name ("const_zero" ), np .array ([0 ], dtype = np .int64 ))
1488
+ sub_op = ctx .make_node ("Sub" , inputs = [max_output_size , shape_op .output [0 ]])
1489
+ raw_pad = ctx .make_node ("Concat" , inputs = [const_zero .output [0 ], sub_op .output [0 ]], attr = {'axis' : 0 })
1490
+ raw_pad_float = ctx .make_node ("Cast" , inputs = [raw_pad .output [0 ]], attr = {"to" : onnx_pb .TensorProto .FLOAT })
1491
+ relu_op = ctx .make_node ("Relu" , inputs = [raw_pad_float .output [0 ]])
1492
+ pad_val = ctx .make_node ("Cast" , inputs = [relu_op .output [0 ]], attr = {"to" : onnx_pb .TensorProto .INT64 })
1493
+ pad_op = ctx .make_node ("Pad" , inputs = [squeeze_op .output [0 ], pad_val .output [0 ]])
1494
+ ctx .make_node ("Cast" , inputs = pad_op .output , name = "cast_A" , attr = {"to" : onnx_pb .TensorProto .INT32 },
1495
+ outputs = [node .output [0 ]], dtypes = dtypes [0 ], shapes = shapes [0 ])
1496
+ reduce_op = ctx .make_node ("ReduceSum" , inputs = shape_op .output , attr = {"axes" : [0 ], "keepdims" : 0 })
1497
+ ctx .make_node ("Cast" , inputs = [reduce_op .output [0 ]], name = "cast_B" , attr = {"to" : onnx_pb .TensorProto .INT32 },
1498
+ outputs = [node .output [1 ]], dtypes = dtypes [1 ], shapes = shapes [1 ])
1499
+ else :
1500
+ ctx .make_node ("Cast" , inputs = squeeze_op .output , attr = {"to" : onnx_pb .TensorProto .INT32 },
1501
+ name = node .name , outputs = node .output , dtypes = dtypes [0 ], shapes = shapes [0 ])
1479
1502
1480
1503
@classmethod
1481
1504
def version_11 (cls , ctx , node , ** kwargs ):
0 commit comments