@@ -2768,6 +2768,95 @@ def version_13(cls, ctx, node, **kwargs):
2768
2768
# Parameters moved to inputs for operator Squeeze, Unsqueeze.
2769
2769
cls .any_version (13 , ctx , node , ** kwargs )
2770
2770
2771
+ @tf_op ("DenseToDenseSetOperation" )
2772
+ class DenseToDenseSetOperation :
2773
+ @classmethod
2774
+ def version_11 (cls , ctx , node , ** kwargs ):
2775
+ inp_a , inp_b = node .input
2776
+ dtype = ctx .get_dtype (node .output [1 ])
2777
+ if dtype != TensorProto .INT64 :
2778
+ inp_a = ctx .make_node ("Cast" , [inp_a ], attr = {'to' : TensorProto .INT64 }).output [0 ]
2779
+ inp_b = ctx .make_node ("Cast" , [inp_b ], attr = {'to' : TensorProto .INT64 }).output [0 ]
2780
+ set_op = node .get_attr_value ('set_operation' )
2781
+ if set_op == b'b-a' :
2782
+ set_op = b'a-b'
2783
+ inp_a , inp_b = inp_b , inp_a
2784
+
2785
+ one_tensor = helper .make_tensor ("value" , TensorProto .INT64 , dims = [1 ], vals = [1 ])
2786
+ const_one = ctx .make_const (utils .make_name ("const_one" ), np .array (1 , np .int64 )).output [0 ]
2787
+ const_zero = ctx .make_const (utils .make_name ("const_zero" ), np .array (0 , np .int64 )).output [0 ]
2788
+ const_zero_unsq = ctx .make_const (utils .make_name ("const_zero" ), np .array ([0 ], np .int64 )).output [0 ]
2789
+ const_neg_one_unsq = ctx .make_const (utils .make_name ("const_neg_one" ), np .array ([- 1 ], np .int64 )).output [0 ]
2790
+ max_int64 = int (utils .get_max_value (np .int64 ))
2791
+ const_two = ctx .make_const (utils .make_name ("const_two" ), np .array (2 , np .int64 )).output [0 ]
2792
+
2793
+ def concat_indices (tensor ):
2794
+ shape = ctx .make_node ("Shape" , [tensor ]).output [0 ]
2795
+ tensor_flat = ctx .make_node ("Reshape" , [tensor , const_neg_one_unsq ]).output [0 ]
2796
+ tensor_flat_unsq = GraphBuilder (ctx ).make_unsqueeze ({'data' : tensor_flat , 'axes' : [1 ]})
2797
+ ones_of_shape = ctx .make_node ("ConstantOfShape" , [shape ], attr = {'value' : one_tensor }).output [0 ]
2798
+ indices = ctx .make_node ("NonZero" , [ones_of_shape ]).output [0 ]
2799
+ sliced_indices = GraphBuilder (ctx ).make_slice ({'data' : indices , 'starts' : [0 ], 'ends' : [- 1 ], 'axes' : [0 ]})
2800
+ sliced_indices_trans = ctx .make_node ("Transpose" , [sliced_indices ], attr = {'perm' : [1 , 0 ]}).output [0 ]
2801
+ return ctx .make_node ("Concat" , [sliced_indices_trans , tensor_flat_unsq ], attr = {'axis' : 1 }).output [0 ]
2802
+
2803
+ if set_op == b'union' :
2804
+ combined = ctx .make_node ("Concat" , [inp_a , inp_b ], attr = {'axis' : - 1 }).output [0 ]
2805
+ shape = ctx .make_node ("Shape" , [combined ]).output [0 ]
2806
+ shape_prefix = GraphBuilder (ctx ).make_slice ({'data' : shape , 'starts' : [0 ], 'ends' : [- 1 ], 'axes' : [0 ]})
2807
+ indices_and_vals = concat_indices (combined )
2808
+ res_idx_and_vals = ctx .make_node ("Unique" , [indices_and_vals ], attr = {'axis' : 0 }).output [0 ]
2809
+ else :
2810
+ shape = ctx .make_node ("Shape" , [inp_a ]).output [0 ]
2811
+ shape_prefix = GraphBuilder (ctx ).make_slice ({'data' : shape , 'starts' : [0 ], 'ends' : [- 1 ], 'axes' : [0 ]})
2812
+ a_idx_and_vals = concat_indices (inp_a )
2813
+ b_idx_and_vals = concat_indices (inp_b )
2814
+ a_unique = ctx .make_node ("Unique" , [a_idx_and_vals ], attr = {'axis' : 0 }).output [0 ]
2815
+ b_unique = ctx .make_node ("Unique" , [b_idx_and_vals ], attr = {'axis' : 0 }).output [0 ]
2816
+ if set_op == b'intersection' :
2817
+ combined = ctx .make_node ("Concat" , [a_unique , b_unique ], attr = {'axis' : 0 }).output [0 ]
2818
+ desired_cnt = const_two
2819
+ else :
2820
+ utils .make_sure (set_op == b'a-b' , "Unsupported set operation: %s" , set_op )
2821
+ combined = ctx .make_node ("Concat" , [a_unique , b_unique , b_unique ], attr = {'axis' : 0 }).output [0 ]
2822
+ # cnt will be 1 if and only if element is in only set A
2823
+ desired_cnt = const_one
2824
+ unique_rows , _ , _ , row_cnts = ctx .make_node ("Unique" , [combined ], attr = {'axis' : 0 }, output_count = 4 ).output
2825
+ keep = ctx .make_node ("Equal" , [row_cnts , desired_cnt ]).output [0 ]
2826
+ compress_shape = None
2827
+ rows_shape = ctx .get_shape (unique_rows )
2828
+ if rows_shape is not None :
2829
+ compress_shape = rows_shape .copy ()
2830
+ compress_shape [0 ] = - 1
2831
+ res_idx_and_vals = ctx .make_node ("Compress" , [unique_rows , keep ], attr = {'axis' : 0 },
2832
+ shapes = [compress_shape ]).output [0 ]
2833
+
2834
+ merged_indices = GraphBuilder (ctx ).make_slice (
2835
+ {'data' : res_idx_and_vals , 'starts' : [0 ], 'ends' : [- 1 ], 'axes' : [1 ]})
2836
+ merged_values = GraphBuilder (ctx ).make_slice (
2837
+ {'data' : res_idx_and_vals , 'starts' : [- 1 ], 'ends' : [max_int64 ], 'axes' : [1 ]})
2838
+ merged_values_sq = GraphBuilder (ctx ).make_squeeze ({'data' : merged_values , 'axes' : [1 ]})
2839
+ merged_values_sq_cast = ctx .make_node ("Cast" , [merged_values_sq ], attr = {'to' : dtype }).output [0 ]
2840
+
2841
+ _ , idx_loc , _ , idx_cnts , = ctx .make_node ("Unique" , [merged_indices ], attr = {'axis' : 0 },
2842
+ output_count = 4 , op_name_scope = node .name ).output
2843
+
2844
+ max_cnt = ctx .make_node ("ReduceMax" , [idx_cnts ], attr = {'axes' : [0 ], 'keepdims' : True }).output [0 ]
2845
+ final_shape = ctx .make_node ("Concat" , [shape_prefix , max_cnt ], attr = {'axis' : 0 }).output [0 ]
2846
+ one_minus_cnts = ctx .make_node ("Sub" , [const_one , idx_cnts ]).output [0 ]
2847
+ cnts_sliced = GraphBuilder (ctx ).make_slice (
2848
+ {"data" : one_minus_cnts , "starts" : [0 ], "ends" : [- 1 ], "axes" : [0 ]})
2849
+ cnts_shifted = ctx .make_node ("Concat" , [const_zero_unsq , cnts_sliced ], attr = {'axis' : 0 }).output [0 ]
2850
+ values_shape = ctx .make_node ("Shape" , [merged_values_sq_cast ]).output [0 ]
2851
+ ones_of_shape = ctx .make_node ("ConstantOfShape" , [values_shape ], attr = {'value' : one_tensor }).output [0 ]
2852
+ idx_deltas = ctx .make_node ("ScatterElements" , [ones_of_shape , idx_loc , cnts_shifted ]).output [0 ]
2853
+ last_dim_idx = ctx .make_node ("CumSum" , [idx_deltas , const_zero ]).output [0 ]
2854
+ last_dim_idx_unsq = GraphBuilder (ctx ).make_unsqueeze ({"data" : last_dim_idx , "axes" : [1 ]})
2855
+ full_indices = ctx .make_node ("Concat" , [merged_indices , last_dim_idx_unsq ], attr = {'axis' : 1 }).output [0 ]
2856
+
2857
+ ctx .replace_all_inputs (node .output [0 ], full_indices )
2858
+ ctx .replace_all_inputs (node .output [1 ], merged_values_sq_cast )
2859
+ ctx .replace_all_inputs (node .output [2 ], final_shape )
2771
2860
2772
2861
@tf_op ("DynamicPartition" )
2773
2862
class DynamicPartition :
0 commit comments