Skip to content

Commit 9223376

Browse files
Implement conversion of set intersection, union, difference (#1556)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent f8a6a86 commit 9223376

File tree

2 files changed

+132
-0
lines changed

2 files changed

+132
-0
lines changed

tests/test_backend.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1772,6 +1772,49 @@ def func(data, indices, segments):
17721772
return tf.identity(x_, name=_TFOUTPUT)
17731773
self._run_test_case(func, [_OUTPUT], {_INPUT: data_val, _INPUT1: indices_val, _INPUT2: segs_val})
17741774

1775+
@check_opset_min_version(11, "CumSum")
1776+
@check_tf_min_version("1.14")
1777+
def test_set_union(self):
1778+
a_val = np.array([[10, 2, 30, 2, 5], [10, 9, 1, 9, 3]], np.int32)
1779+
b_val = np.array([[4, 5, 10, 8, 9], [1, 4, 1, 1, 5]], np.int32)
1780+
def func(a, b):
1781+
s = tf.sets.union(a, b)
1782+
indices, values, shape = s.indices, s.values, s.dense_shape
1783+
indices = tf.identity(indices, name=_TFOUTPUT)
1784+
values = tf.identity(values, name=_TFOUTPUT1)
1785+
shape = tf.identity(shape, name=_TFOUTPUT2)
1786+
return indices, values, shape
1787+
self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: a_val, _INPUT1: b_val})
1788+
1789+
@check_opset_min_version(11, "CumSum")
1790+
@check_tf_min_version("1.14")
1791+
def test_set_intersection(self):
1792+
a_val = np.array([[10, 2, 30, 2, 5], [10, 9, 1, 9, 3]], np.int32)
1793+
b_val = np.array([[4, 5, 10, 8, 9], [1, 4, 1, 1, 5]], np.int32)
1794+
def func(a, b):
1795+
s = tf.sets.intersection(a, b)
1796+
indices, values, shape = s.indices, s.values, s.dense_shape
1797+
indices = tf.identity(indices, name=_TFOUTPUT)
1798+
values = tf.identity(values, name=_TFOUTPUT1)
1799+
shape = tf.identity(shape, name=_TFOUTPUT2)
1800+
return indices, values, shape
1801+
self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: a_val, _INPUT1: b_val})
1802+
1803+
@check_opset_min_version(11, "CumSum")
1804+
@check_tf_min_version("1.14")
1805+
def test_set_difference(self):
1806+
a_val = np.array([[10, 2, 30, 2, 5], [10, 9, 1, 9, 3]], np.int32)
1807+
b_val = np.array([[4, 5, 10, 8, 9], [1, 4, 1, 1, 5]], np.int32)
1808+
for aminusb in [True, False]:
1809+
def func(a, b):
1810+
s = tf.sets.difference(a, b, aminusb)
1811+
indices, values, shape = s.indices, s.values, s.dense_shape
1812+
indices = tf.identity(indices, name=_TFOUTPUT)
1813+
values = tf.identity(values, name=_TFOUTPUT1)
1814+
shape = tf.identity(shape, name=_TFOUTPUT2)
1815+
return indices, values, shape
1816+
self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: a_val, _INPUT1: b_val})
1817+
17751818
@check_onnxruntime_incompatibility("Sqrt")
17761819
def test_sqrt(self):
17771820
x_val = np.array([4.0, 16.0, 4.0, 1.6], dtype=np.float32).reshape((2, 2))

tf2onnx/onnx_opset/tensor.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2768,6 +2768,95 @@ def version_13(cls, ctx, node, **kwargs):
27682768
# Parameters moved to inputs for operator Squeeze, Unsqueeze.
27692769
cls.any_version(13, ctx, node, **kwargs)
27702770

2771+
@tf_op("DenseToDenseSetOperation")
2772+
class DenseToDenseSetOperation:
2773+
@classmethod
2774+
def version_11(cls, ctx, node, **kwargs):
2775+
inp_a, inp_b = node.input
2776+
dtype = ctx.get_dtype(node.output[1])
2777+
if dtype != TensorProto.INT64:
2778+
inp_a = ctx.make_node("Cast", [inp_a], attr={'to': TensorProto.INT64}).output[0]
2779+
inp_b = ctx.make_node("Cast", [inp_b], attr={'to': TensorProto.INT64}).output[0]
2780+
set_op = node.get_attr_value('set_operation')
2781+
if set_op == b'b-a':
2782+
set_op = b'a-b'
2783+
inp_a, inp_b = inp_b, inp_a
2784+
2785+
one_tensor = helper.make_tensor("value", TensorProto.INT64, dims=[1], vals=[1])
2786+
const_one = ctx.make_const(utils.make_name("const_one"), np.array(1, np.int64)).output[0]
2787+
const_zero = ctx.make_const(utils.make_name("const_zero"), np.array(0, np.int64)).output[0]
2788+
const_zero_unsq = ctx.make_const(utils.make_name("const_zero"), np.array([0], np.int64)).output[0]
2789+
const_neg_one_unsq = ctx.make_const(utils.make_name("const_neg_one"), np.array([-1], np.int64)).output[0]
2790+
max_int64 = int(utils.get_max_value(np.int64))
2791+
const_two = ctx.make_const(utils.make_name("const_two"), np.array(2, np.int64)).output[0]
2792+
2793+
def concat_indices(tensor):
2794+
shape = ctx.make_node("Shape", [tensor]).output[0]
2795+
tensor_flat = ctx.make_node("Reshape", [tensor, const_neg_one_unsq]).output[0]
2796+
tensor_flat_unsq = GraphBuilder(ctx).make_unsqueeze({'data': tensor_flat, 'axes': [1]})
2797+
ones_of_shape = ctx.make_node("ConstantOfShape", [shape], attr={'value': one_tensor}).output[0]
2798+
indices = ctx.make_node("NonZero", [ones_of_shape]).output[0]
2799+
sliced_indices = GraphBuilder(ctx).make_slice({'data': indices, 'starts': [0], 'ends': [-1], 'axes': [0]})
2800+
sliced_indices_trans = ctx.make_node("Transpose", [sliced_indices], attr={'perm': [1, 0]}).output[0]
2801+
return ctx.make_node("Concat", [sliced_indices_trans, tensor_flat_unsq], attr={'axis': 1}).output[0]
2802+
2803+
if set_op == b'union':
2804+
combined = ctx.make_node("Concat", [inp_a, inp_b], attr={'axis': -1}).output[0]
2805+
shape = ctx.make_node("Shape", [combined]).output[0]
2806+
shape_prefix = GraphBuilder(ctx).make_slice({'data': shape, 'starts': [0], 'ends': [-1], 'axes': [0]})
2807+
indices_and_vals = concat_indices(combined)
2808+
res_idx_and_vals = ctx.make_node("Unique", [indices_and_vals], attr={'axis': 0}).output[0]
2809+
else:
2810+
shape = ctx.make_node("Shape", [inp_a]).output[0]
2811+
shape_prefix = GraphBuilder(ctx).make_slice({'data': shape, 'starts': [0], 'ends': [-1], 'axes': [0]})
2812+
a_idx_and_vals = concat_indices(inp_a)
2813+
b_idx_and_vals = concat_indices(inp_b)
2814+
a_unique = ctx.make_node("Unique", [a_idx_and_vals], attr={'axis': 0}).output[0]
2815+
b_unique = ctx.make_node("Unique", [b_idx_and_vals], attr={'axis': 0}).output[0]
2816+
if set_op == b'intersection':
2817+
combined = ctx.make_node("Concat", [a_unique, b_unique], attr={'axis': 0}).output[0]
2818+
desired_cnt = const_two
2819+
else:
2820+
utils.make_sure(set_op == b'a-b', "Unsupported set operation: %s", set_op)
2821+
combined = ctx.make_node("Concat", [a_unique, b_unique, b_unique], attr={'axis': 0}).output[0]
2822+
# cnt will be 1 if and only if element is in only set A
2823+
desired_cnt = const_one
2824+
unique_rows, _, _, row_cnts = ctx.make_node("Unique", [combined], attr={'axis': 0}, output_count=4).output
2825+
keep = ctx.make_node("Equal", [row_cnts, desired_cnt]).output[0]
2826+
compress_shape = None
2827+
rows_shape = ctx.get_shape(unique_rows)
2828+
if rows_shape is not None:
2829+
compress_shape = rows_shape.copy()
2830+
compress_shape[0] = -1
2831+
res_idx_and_vals = ctx.make_node("Compress", [unique_rows, keep], attr={'axis': 0},
2832+
shapes=[compress_shape]).output[0]
2833+
2834+
merged_indices = GraphBuilder(ctx).make_slice(
2835+
{'data': res_idx_and_vals, 'starts': [0], 'ends': [-1], 'axes': [1]})
2836+
merged_values = GraphBuilder(ctx).make_slice(
2837+
{'data': res_idx_and_vals, 'starts': [-1], 'ends': [max_int64], 'axes': [1]})
2838+
merged_values_sq = GraphBuilder(ctx).make_squeeze({'data': merged_values, 'axes': [1]})
2839+
merged_values_sq_cast = ctx.make_node("Cast", [merged_values_sq], attr={'to': dtype}).output[0]
2840+
2841+
_, idx_loc, _, idx_cnts, = ctx.make_node("Unique", [merged_indices], attr={'axis': 0},
2842+
output_count=4, op_name_scope=node.name).output
2843+
2844+
max_cnt = ctx.make_node("ReduceMax", [idx_cnts], attr={'axes': [0], 'keepdims': True}).output[0]
2845+
final_shape = ctx.make_node("Concat", [shape_prefix, max_cnt], attr={'axis': 0}).output[0]
2846+
one_minus_cnts = ctx.make_node("Sub", [const_one, idx_cnts]).output[0]
2847+
cnts_sliced = GraphBuilder(ctx).make_slice(
2848+
{"data": one_minus_cnts, "starts": [0], "ends": [-1], "axes": [0]})
2849+
cnts_shifted = ctx.make_node("Concat", [const_zero_unsq, cnts_sliced], attr={'axis': 0}).output[0]
2850+
values_shape = ctx.make_node("Shape", [merged_values_sq_cast]).output[0]
2851+
ones_of_shape = ctx.make_node("ConstantOfShape", [values_shape], attr={'value': one_tensor}).output[0]
2852+
idx_deltas = ctx.make_node("ScatterElements", [ones_of_shape, idx_loc, cnts_shifted]).output[0]
2853+
last_dim_idx = ctx.make_node("CumSum", [idx_deltas, const_zero]).output[0]
2854+
last_dim_idx_unsq = GraphBuilder(ctx).make_unsqueeze({"data": last_dim_idx, "axes": [1]})
2855+
full_indices = ctx.make_node("Concat", [merged_indices, last_dim_idx_unsq], attr={'axis': 1}).output[0]
2856+
2857+
ctx.replace_all_inputs(node.output[0], full_indices)
2858+
ctx.replace_all_inputs(node.output[1], merged_values_sq_cast)
2859+
ctx.replace_all_inputs(node.output[2], final_shape)
27712860

27722861
@tf_op("DynamicPartition")
27732862
class DynamicPartition:

0 commit comments

Comments
 (0)