@@ -2238,13 +2238,13 @@ def version_11(cls, ctx, node, **kwargs):
2238
2238
ctx .copy_shape (new_node .output [2 ], cast_node .output [0 ])
2239
2239
2240
2240
2241
- @tf_op ("Bincount" )
2241
+ @tf_op ([ "Bincount" , "DenseBincount" ] )
2242
2242
class Bincount :
2243
2243
@classmethod
2244
2244
def any_version (cls , opset , ctx , node , ** kwargs ):
2245
2245
# arr, size are int32
2246
2246
arr_inp , size_inp , weights_inp = node .input
2247
-
2247
+ binary_output = node . get_attr_value ( "binary_output" , False )
2248
2248
arr_int64 = ctx .make_node ("Cast" , [arr_inp ], attr = {'to' : TensorProto .INT64 }).output [0 ]
2249
2249
size_int64 = ctx .make_node ("Cast" , [size_inp ], attr = {'to' : TensorProto .INT64 }).output [0 ]
2250
2250
@@ -2253,22 +2253,55 @@ def any_version(cls, opset, ctx, node, **kwargs):
2253
2253
weights_is_zero = weights_shape is not None and 0 in weights_shape
2254
2254
utils .make_sure (weights_is_zero , "Non-empty weights not yet supported for bincount" )
2255
2255
2256
- values , _ , _ , counts = ctx .make_node ("Unique" , [arr_int64 ], attr = {'sorted' : 1 }, output_count = 4 ,
2257
- op_name_scope = node .name ).output
2256
+ if ctx .get_rank (arr_inp ) == 2 :
2257
+ zero_const = ctx .make_const (utils .make_name ("zero_const" ), np .array (0 , np .int64 )).output [0 ]
2258
+ one_const = ctx .make_const (utils .make_name ("one_const" ), np .array (1 , np .int64 )).output [0 ]
2259
+ inp_shape = ctx .make_node ("Shape" , [arr_inp ]).output [0 ]
2260
+ num_rows = GraphBuilder (ctx ).make_slice ({"data" : inp_shape , "starts" : [0 ], "ends" : [1 ], "axes" : [0 ]})
2261
+ num_rows_sq = GraphBuilder (ctx ).make_squeeze ({"data" : num_rows , "axes" : [0 ]})
2262
+ row_idx = ctx .make_node ("Range" , [zero_const , num_rows_sq , one_const ]).output [0 ]
2263
+ row_idx_unsq = GraphBuilder (ctx ).make_unsqueeze ({"data" : row_idx , "axes" : [1 ]})
2264
+ row_idx_expand = ctx .make_node ("Expand" , [row_idx_unsq , inp_shape ]).output [0 ]
2265
+ arr_int64_unsq = GraphBuilder (ctx ).make_unsqueeze ({"data" : arr_int64 , "axes" : [2 ]})
2266
+ row_idx_expand_unsq = GraphBuilder (ctx ).make_unsqueeze ({"data" : row_idx_expand , "axes" : [2 ]})
2267
+ concat = ctx .make_node ("Concat" , [row_idx_expand_unsq , arr_int64_unsq ], {"axis" : 2 }).output [0 ]
2268
+ reshape_const = ctx .make_const (utils .make_name ("reshape_const" ), np .array ([- 1 , 2 ], np .int64 )).output [0 ]
2269
+ reshaped = ctx .make_node ("Reshape" , [concat , reshape_const ]).output [0 ]
2270
+ values , _ , _ , counts = ctx .make_node ("Unique" , [reshaped ], attr = {'sorted' : 1 , 'axis' : 0 }, output_count = 4 ,
2271
+ op_name_scope = node .name ).output
2272
+ values_to_check_unsq = GraphBuilder (ctx ).make_slice (
2273
+ {"data" : values , "starts" : [1 ], "ends" : [2 ], "axes" : [1 ]})
2274
+ values_to_check = GraphBuilder (ctx ).make_squeeze ({"data" : values_to_check_unsq , "axes" : [1 ]})
2275
+ size_unsq = GraphBuilder (ctx ).make_unsqueeze ({'data' : size_int64 , "axes" : [0 ]})
2276
+ output_shape = ctx .make_node ("Concat" , [num_rows , size_unsq ], attr = {"axis" : 0 }).output [0 ]
2277
+ else :
2278
+ values , _ , _ , counts = ctx .make_node ("Unique" , [arr_int64 ], attr = {'sorted' : 1 }, output_count = 4 ,
2279
+ op_name_scope = node .name ).output
2280
+ values_to_check = values
2281
+ output_shape = GraphBuilder (ctx ).make_unsqueeze ({'data' : size_int64 , "axes" : [0 ]})
2282
+
2258
2283
neg_one_const = ctx .make_const (utils .make_name ("neg_one_const" ), np .array (- 1 , np .int64 )).output [0 ]
2259
- non_neg_val_locs = ctx .make_node ("Greater" , [values , neg_one_const ]).output [0 ]
2260
- small_val_locs = ctx .make_node ("Less" , [values , size_int64 ]).output [0 ]
2284
+ non_neg_val_locs = ctx .make_node ("Greater" , [values_to_check , neg_one_const ]).output [0 ]
2285
+ small_val_locs = ctx .make_node ("Less" , [values_to_check , size_int64 ]).output [0 ]
2261
2286
valid_val_locs = ctx .make_node ("And" , [non_neg_val_locs , small_val_locs ]).output [0 ]
2262
2287
2263
2288
valid_values = ctx .make_node ("Compress" , [values , valid_val_locs ], attr = {'axis' : 0 }).output [0 ]
2264
- valid_counts = ctx .make_node ("Compress" , [counts , valid_val_locs ], attr = {'axis' : 0 }).output [0 ]
2265
-
2266
- output_shape = GraphBuilder (ctx ).make_unsqueeze ({'data' : size_int64 , "axes" : [0 ]})
2289
+ if binary_output :
2290
+ counts_shape = ctx .make_node ("Shape" , [valid_values ]).output [0 ]
2291
+ counts_shape_1d = GraphBuilder (ctx ).make_slice (
2292
+ {"data" : counts_shape , "starts" : [0 ], "ends" : [1 ], "axes" : [0 ]})
2293
+ ones_tensor = helper .make_tensor ("value" , TensorProto .INT64 , dims = [1 ], vals = [1 ])
2294
+ valid_counts = ctx .make_node ("ConstantOfShape" , [counts_shape_1d ], attr = {'value' : ones_tensor }).output [0 ]
2295
+ else :
2296
+ valid_counts = ctx .make_node ("Compress" , [counts , valid_val_locs ], attr = {'axis' : 0 }).output [0 ]
2267
2297
2268
- false_tensor = helper .make_tensor ("value" , TensorProto .INT64 , dims = [1 ], vals = [0 ])
2269
- zeros = ctx .make_node ("ConstantOfShape" , [output_shape ], attr = {'value' : false_tensor }).output [0 ]
2298
+ zero_tensor = helper .make_tensor ("value" , TensorProto .INT64 , dims = [1 ], vals = [0 ])
2299
+ zeros = ctx .make_node ("ConstantOfShape" , [output_shape ], attr = {'value' : zero_tensor }).output [0 ]
2270
2300
2271
- result = ctx .make_node ("ScatterElements" , [zeros , valid_values , valid_counts ], attr = {'axis' : 0 }).output [0 ]
2301
+ if ctx .get_rank (arr_inp ) == 2 :
2302
+ result = ctx .make_node ("ScatterND" , [zeros , valid_values , valid_counts ]).output [0 ]
2303
+ else :
2304
+ result = ctx .make_node ("ScatterElements" , [zeros , valid_values , valid_counts ], attr = {'axis' : 0 }).output [0 ]
2272
2305
result_cast = result
2273
2306
if res_dtype != TensorProto .INT64 :
2274
2307
result_cast = ctx .make_node ("Cast" , [result ], attr = {'to' : res_dtype }).output [0 ]
0 commit comments