Skip to content

Commit 99e2706

Browse files
Added conversion for DenseBincount (#1401)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 77c6c04 commit 99e2706

File tree

3 files changed

+61
-12
lines changed

3 files changed

+61
-12
lines changed

tests/test_backend.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4041,6 +4041,20 @@ def func(x):
40414041
return y_
40424042
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
40434043

4044+
@skip_tflite("Bug in tflite output shapes")
4045+
@check_opset_min_version(11, "Unique")
4046+
@check_tf_min_version("2.3", "needs tf.math.bincount with axis attr")
4047+
def test_dense_bincount(self):
4048+
x_val = np.array([[5, 2, 3, 1, 3], [2, 7, 5, 9, 10]], dtype=np.int32)
4049+
y_val = np.array([[2.0, 1.5, 3.5, 4.5, 5.5], [6.5, 7.5, 8.5, 9.5, 10.5]], dtype=np.float32)
4050+
for a in [0, -1]:
4051+
for b in [True, False]:
4052+
def func(x, y):
4053+
x_ = tf.math.bincount(x, axis=a, binary_output=b)
4054+
y_ = tf.identity(x_, name=_TFOUTPUT)
4055+
return y_
4056+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
4057+
40444058
@check_opset_min_version(11, "ScatterND")
40454059
def test_sparse_to_dense(self):
40464060
i_val = np.array([[0, 0, 0], [0, 0, 2], [0, 1, 3], [1, 2, 2], [1, 2, 3]], dtype=np.int64)

tf2onnx/onnx_opset/reduction.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ def any_version(cls, opset, ctx, node, **kwargs):
179179
data_shape = ctx.get_shape(data_inp)
180180
data_rank = len(data_shape) if data_shape is not None else None
181181
data_dtype = ctx.get_dtype(data_inp)
182+
seg_rank = ctx.get_rank(segment_inp)
183+
utils.make_sure(seg_rank == 1, "Segment ops only supported for segments of rank 1, not %s", seg_rank)
182184
data_np_dtype = utils.map_onnx_to_numpy_type(data_dtype)
183185
seg_np_dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(segment_inp))
184186

tf2onnx/onnx_opset/tensor.py

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2238,13 +2238,13 @@ def version_11(cls, ctx, node, **kwargs):
22382238
ctx.copy_shape(new_node.output[2], cast_node.output[0])
22392239

22402240

2241-
@tf_op("Bincount")
2241+
@tf_op(["Bincount", "DenseBincount"])
22422242
class Bincount:
22432243
@classmethod
22442244
def any_version(cls, opset, ctx, node, **kwargs):
22452245
# arr, size are int32
22462246
arr_inp, size_inp, weights_inp = node.input
2247-
2247+
binary_output = node.get_attr_value("binary_output", False)
22482248
arr_int64 = ctx.make_node("Cast", [arr_inp], attr={'to': TensorProto.INT64}).output[0]
22492249
size_int64 = ctx.make_node("Cast", [size_inp], attr={'to': TensorProto.INT64}).output[0]
22502250

@@ -2253,22 +2253,55 @@ def any_version(cls, opset, ctx, node, **kwargs):
22532253
weights_is_zero = weights_shape is not None and 0 in weights_shape
22542254
utils.make_sure(weights_is_zero, "Non-empty weights not yet supported for bincount")
22552255

2256-
values, _, _, counts = ctx.make_node("Unique", [arr_int64], attr={'sorted': 1}, output_count=4,
2257-
op_name_scope=node.name).output
2256+
if ctx.get_rank(arr_inp) == 2:
2257+
zero_const = ctx.make_const(utils.make_name("zero_const"), np.array(0, np.int64)).output[0]
2258+
one_const = ctx.make_const(utils.make_name("one_const"), np.array(1, np.int64)).output[0]
2259+
inp_shape = ctx.make_node("Shape", [arr_inp]).output[0]
2260+
num_rows = GraphBuilder(ctx).make_slice({"data": inp_shape, "starts": [0], "ends": [1], "axes": [0]})
2261+
num_rows_sq = GraphBuilder(ctx).make_squeeze({"data": num_rows, "axes": [0]})
2262+
row_idx = ctx.make_node("Range", [zero_const, num_rows_sq, one_const]).output[0]
2263+
row_idx_unsq = GraphBuilder(ctx).make_unsqueeze({"data": row_idx, "axes": [1]})
2264+
row_idx_expand = ctx.make_node("Expand", [row_idx_unsq, inp_shape]).output[0]
2265+
arr_int64_unsq = GraphBuilder(ctx).make_unsqueeze({"data": arr_int64, "axes": [2]})
2266+
row_idx_expand_unsq = GraphBuilder(ctx).make_unsqueeze({"data": row_idx_expand, "axes": [2]})
2267+
concat = ctx.make_node("Concat", [row_idx_expand_unsq, arr_int64_unsq], {"axis": 2}).output[0]
2268+
reshape_const = ctx.make_const(utils.make_name("reshape_const"), np.array([-1, 2], np.int64)).output[0]
2269+
reshaped = ctx.make_node("Reshape", [concat, reshape_const]).output[0]
2270+
values, _, _, counts = ctx.make_node("Unique", [reshaped], attr={'sorted': 1, 'axis': 0}, output_count=4,
2271+
op_name_scope=node.name).output
2272+
values_to_check_unsq = GraphBuilder(ctx).make_slice(
2273+
{"data": values, "starts": [1], "ends": [2], "axes": [1]})
2274+
values_to_check = GraphBuilder(ctx).make_squeeze({"data": values_to_check_unsq, "axes": [1]})
2275+
size_unsq = GraphBuilder(ctx).make_unsqueeze({'data': size_int64, "axes": [0]})
2276+
output_shape = ctx.make_node("Concat", [num_rows, size_unsq], attr={"axis": 0}).output[0]
2277+
else:
2278+
values, _, _, counts = ctx.make_node("Unique", [arr_int64], attr={'sorted': 1}, output_count=4,
2279+
op_name_scope=node.name).output
2280+
values_to_check = values
2281+
output_shape = GraphBuilder(ctx).make_unsqueeze({'data': size_int64, "axes": [0]})
2282+
22582283
neg_one_const = ctx.make_const(utils.make_name("neg_one_const"), np.array(-1, np.int64)).output[0]
2259-
non_neg_val_locs = ctx.make_node("Greater", [values, neg_one_const]).output[0]
2260-
small_val_locs = ctx.make_node("Less", [values, size_int64]).output[0]
2284+
non_neg_val_locs = ctx.make_node("Greater", [values_to_check, neg_one_const]).output[0]
2285+
small_val_locs = ctx.make_node("Less", [values_to_check, size_int64]).output[0]
22612286
valid_val_locs = ctx.make_node("And", [non_neg_val_locs, small_val_locs]).output[0]
22622287

22632288
valid_values = ctx.make_node("Compress", [values, valid_val_locs], attr={'axis': 0}).output[0]
2264-
valid_counts = ctx.make_node("Compress", [counts, valid_val_locs], attr={'axis': 0}).output[0]
2265-
2266-
output_shape = GraphBuilder(ctx).make_unsqueeze({'data': size_int64, "axes": [0]})
2289+
if binary_output:
2290+
counts_shape = ctx.make_node("Shape", [valid_values]).output[0]
2291+
counts_shape_1d = GraphBuilder(ctx).make_slice(
2292+
{"data": counts_shape, "starts": [0], "ends": [1], "axes": [0]})
2293+
ones_tensor = helper.make_tensor("value", TensorProto.INT64, dims=[1], vals=[1])
2294+
valid_counts = ctx.make_node("ConstantOfShape", [counts_shape_1d], attr={'value': ones_tensor}).output[0]
2295+
else:
2296+
valid_counts = ctx.make_node("Compress", [counts, valid_val_locs], attr={'axis': 0}).output[0]
22672297

2268-
false_tensor = helper.make_tensor("value", TensorProto.INT64, dims=[1], vals=[0])
2269-
zeros = ctx.make_node("ConstantOfShape", [output_shape], attr={'value': false_tensor}).output[0]
2298+
zero_tensor = helper.make_tensor("value", TensorProto.INT64, dims=[1], vals=[0])
2299+
zeros = ctx.make_node("ConstantOfShape", [output_shape], attr={'value': zero_tensor}).output[0]
22702300

2271-
result = ctx.make_node("ScatterElements", [zeros, valid_values, valid_counts], attr={'axis': 0}).output[0]
2301+
if ctx.get_rank(arr_inp) == 2:
2302+
result = ctx.make_node("ScatterND", [zeros, valid_values, valid_counts]).output[0]
2303+
else:
2304+
result = ctx.make_node("ScatterElements", [zeros, valid_values, valid_counts], attr={'axis': 0}).output[0]
22722305
result_cast = result
22732306
if res_dtype != TensorProto.INT64:
22742307
result_cast = ctx.make_node("Cast", [result], attr={'to': res_dtype}).output[0]

0 commit comments

Comments
 (0)