@@ -2250,6 +2250,77 @@ def version_11(cls, ctx, node, **kwargs):
2250
2250
ctx .remove_node (node .name )
2251
2251
2252
2252
2253
@tf_op("RaggedGather")
class RaggedGather:
    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        """Convert TF RaggedGather into plain ONNX ops (opset 11+).

        TF inputs are ``params_nested_splits..., params_dense_values, indices``
        and TF outputs are ``(output_nested_splits..., output_dense_values)``.
        Only ragged rank 1 (exactly one splits tensor) with 1-D dense values
        is supported; other configurations fail via ``utils.make_sure``.
        Opset 11 is required for CumSum/ScatterElements used below.
        """
        # The last two inputs are the dense values and the gather indices;
        # everything before them is the list of nested row-splits tensors.
        *params_nested_splits, params_dense_values, indices = node.input
        inp_ragged_rank = node.get_attr_value("PARAMS_RAGGED_RANK")
        out_ragged_rank = node.get_attr_value("OUTPUT_RAGGED_RANK")
        err_msg = "RaggedGather conversion only supports ragged rank of 1"
        utils.make_sure(inp_ragged_rank == 1 and out_ragged_rank == 1 and len(params_nested_splits) == 1, err_msg)
        splits = params_nested_splits[0]
        err_msg2 = "RaggedGather conversion only supports tensors with no dense dimensions"
        # Rank None (unknown) is tolerated; rank > 1 would mean extra dense dims.
        utils.make_sure(ctx.get_rank(splits) in [None, 1] and ctx.get_rank(params_dense_values) in [None, 1], err_msg2)
        splits_dtype = ctx.get_dtype(splits)

        # Do all index arithmetic in int64; cast back to the original splits
        # dtype only for the final splits output.
        if splits_dtype != TensorProto.INT64:
            splits_64 = ctx.make_node("Cast", [splits], attr={'to': TensorProto.INT64}).output[0]
        else:
            splits_64 = splits

        # Row lengths of the ragged params: splits[1:] - splits[:-1].
        # max_int64 as "ends" means "slice to the end".
        max_int64 = int(utils.get_max_value(np.int64))
        slice1 = GraphBuilder(ctx).make_slice(
            {"data": splits_64, "ends": [max_int64], "starts": [1], "axes": [0]})
        slice2 = GraphBuilder(ctx).make_slice(
            {"data": splits_64, "ends": [-1], "starts": [0], "axes": [0]})
        ragged_lens = ctx.make_node("Sub", [slice1, slice2]).output[0]

        # Lengths of the selected rows, in output order.
        gathered_lens = ctx.make_node("Gather", [ragged_lens, indices], op_name_scope=node.name).output[0]

        # 1-D (unsqueezed) constants for Concat; scalar constants for
        # CumSum axis / elementwise arithmetic.
        const_zero_unsq = ctx.make_const(utils.make_name("const_zero"), np.array([0], dtype=np.int64)).output[0]
        const_one_unsq = ctx.make_const(utils.make_name("const_one"), np.array([1], dtype=np.int64)).output[0]
        # Prepend 0 so the cumulative sum below yields a proper splits vector
        # (leading 0, trailing total length).
        gathered_lens_w_zero = ctx.make_node("Concat", [const_zero_unsq, gathered_lens], attr={'axis': 0}).output[0]

        const_zero_int64 = ctx.make_const(utils.make_name("const_zero"), np.array(0, dtype=np.int64)).output[0]
        const_one_int64 = ctx.make_const(utils.make_name("const_one"), np.array(1, dtype=np.int64)).output[0]

        # Output row splits = running sum of the gathered row lengths.
        gathered_splits = ctx.make_node("CumSum", [gathered_lens_w_zero, const_zero_int64]).output[0]
        if splits_dtype != TensorProto.INT64:
            output_splits = ctx.make_node("Cast", [gathered_splits], attr={'to': splits_dtype}).output[0]
        else:
            output_splits = gathered_splits

        # Now that we have the splits, we just need to make the list of values.
        # total_length = last element of gathered_splits = size of the flat
        # output values tensor.
        total_length = GraphBuilder(ctx).make_slice(
            {"data": gathered_splits, "ends": [max_int64], "starts": [-1], "axes": [0]})
        # Start offset (into params_dense_values) of each gathered row.
        gathered_starts = ctx.make_node("Gather", [splits_64, indices], op_name_scope=node.name).output[0]
        # We disregard any length 0 segments
        # (they contribute no values and would corrupt the scatter below).
        non_zero_pos = ctx.make_node("Greater", [gathered_lens, const_zero_int64]).output[0]
        non_zero_lens = ctx.make_node("Compress", [gathered_lens, non_zero_pos]).output[0]
        non_zero_lens_shifted = ctx.make_node("Concat", [const_zero_unsq, non_zero_lens], attr={'axis': 0}).output[0]
        # Length of the *previous* surviving segment (0 for the first one).
        non_zero_prev_lens = GraphBuilder(ctx).make_slice(
            {"data": non_zero_lens_shifted, "ends": [-1], "starts": [0], "axes": [0]})
        non_zero_starts = ctx.make_node("Compress", [gathered_starts, non_zero_pos]).output[0]
        # NOTE: non_zero_pos is one element shorter than gathered_splits, so
        # Compress keeps the *start* split of each non-empty segment here.
        non_zero_splits = ctx.make_node("Compress", [gathered_splits, non_zero_pos]).output[0]

        # Index-generation trick: build a "delta" vector that, once CumSum'd,
        # produces the flat gather indices.  Within a segment consecutive
        # indices differ by 1 (the ones below); at each segment boundary the
        # jump is start[k] - (start[k-1] + len[k-1] - 1), i.e.
        # start[k] - prev_start - prev_len + 1.  For the first segment the
        # prepended 1 makes the delta come out to start[0] itself.
        prev_starts = GraphBuilder(ctx).make_slice(
            {"data": non_zero_starts, "ends": [-1], "starts": [0], "axes": [0]})
        prev_starts_concat = ctx.make_node("Concat", [const_one_unsq, prev_starts], attr={'axis': 0}).output[0]
        deltas = ctx.make_node("Sub", [non_zero_starts, prev_starts_concat]).output[0]
        deltas2 = ctx.make_node("Sub", [deltas, non_zero_prev_lens]).output[0]
        deltas3 = ctx.make_node("Add", [deltas2, const_one_int64]).output[0]
        # Vector of ones of length total_length, with the boundary deltas
        # scattered onto each segment's first position; CumSum then yields
        # the final flat indices into params_dense_values.
        one_tensor = helper.make_tensor("value", TensorProto.INT64, dims=[1], vals=[1])
        ones_of_shape = ctx.make_node("ConstantOfShape", [total_length], attr={"value": one_tensor}).output[0]
        full_deltas = ctx.make_node("ScatterElements", [ones_of_shape, non_zero_splits, deltas3], attr={'axis': 0})
        full_indices = ctx.make_node("CumSum", [full_deltas.output[0], const_zero_int64]).output[0]
        output_values = ctx.make_node("Gather", [params_dense_values, full_indices], op_name_scope=node.name).output[0]

        # Rewire consumers of the TF node's (splits, values) outputs to the
        # newly built ONNX subgraph, then drop the original node.
        ctx.replace_all_inputs(node.output[0], output_splits)
        ctx.replace_all_inputs(node.output[1], output_values)
        ctx.remove_node(node.name)
2253
2324
@tf_op ("SparseReshape" )
2254
2325
class SparseReshape :
2255
2326
@classmethod
0 commit comments