@@ -2385,6 +2385,25 @@ def ragged_lengths_to_sparse_indices(ctx, ragged_lens):
2385
2385
return num_rows , num_cols , row_indices , col_indices
2386
2386
2387
2387
2388
def ragged_row_ids_to_sparse_indices(ctx, row_ids):
    """Emit ONNX nodes turning a 1-D value_rowids tensor into column indices.

    For every element of row_ids (assumed int64, non-decreasing — the TF
    VALUE_ROWIDS contract; confirm at call sites), computes its 0-based
    position inside its own row, and the maximum row length.

    Args:
        ctx: the tf2onnx graph being built.
        row_ids: output name of the 1-D int64 row-id tensor.

    Returns:
        Tuple (num_cols, col_indices): output name of a shape-[1] tensor
        holding the widest row's length, and output name of the per-element
        column-index tensor (same shape as row_ids).
    """
    # Unique yields, per distinct row id, the position of its first
    # occurrence in row_ids and how many values that row owns.
    unique_outs = ctx.make_node("Unique", [row_ids], attr={'axis': 0}, output_count=4).output
    first_occurrences, counts = unique_outs[1], unique_outs[3]
    # The dense column dimension is the longest row.
    num_cols = ctx.make_node("ReduceMax", [counts], attr={'axes': [0], 'keepdims': True}).output[0]
    one_scalar = ctx.make_const(utils.make_name("const_one"), np.array(1, np.int64)).output[0]
    zero_vec = ctx.make_const(utils.make_name("const_zero"), np.array([0], np.int64)).output[0]
    zero_scalar = ctx.make_const(utils.make_name("const_zero"), np.array(0, np.int64)).output[0]
    neg_one_vec = ctx.make_const(utils.make_name("const_neg_one"), np.array([-1], np.int64)).output[0]
    # Build a delta tensor that is 1 everywhere except at each row's first
    # element, where it is (1 - previous_row_count): a running sum over these
    # deltas then counts 0,1,2,... and restarts at 0 on every row boundary.
    one_minus_counts = ctx.make_node("Sub", [one_scalar, counts]).output[0]
    padded_counts = ctx.make_node("Concat", [zero_vec, one_minus_counts], attr={'axis': 0}).output[0]
    shifted_counts = GraphBuilder(ctx).make_slice(
        {'data': padded_counts, 'starts': zero_vec, 'ends': neg_one_vec, 'axes': [0]})
    ids_shape = ctx.make_node("Shape", [row_ids]).output[0]
    one_tensor = helper.make_tensor("value", onnx_pb.TensorProto.INT64, dims=[1], vals=[1])
    all_ones = ctx.make_node("ConstantOfShape", [ids_shape], attr={'value': one_tensor}).output[0]
    deltas = ctx.make_node("ScatterElements", [all_ones, first_occurrences, shifted_counts],
                           attr={'axis': 0}).output[0]
    col_indices = ctx.make_node("CumSum", [deltas, zero_scalar]).output[0]
    return num_cols, col_indices
2405
+
2406
+
2388
2407
def ragged_nested_splits_to_sparse_indices (ctx , nested_splits , op_name_scope ):
2389
2408
sparse_indices = None
2390
2409
dense_shape_dims = []
@@ -2412,6 +2431,28 @@ def ragged_nested_splits_to_sparse_indices(ctx, nested_splits, op_name_scope):
2412
2431
return sparse_indices , dense_shape
2413
2432
2414
2433
2434
def ragged_nested_row_ids_to_sparse_indices(ctx, num_rows, nested_row_ids, op_name_scope):
    """Emit ONNX nodes converting nested VALUE_ROWIDS partitions to COO form.

    Walks the partition levels from outermost to innermost, accumulating one
    coordinate column per level, so the result matches what the ROW_SPLITS
    path produces for the same ragged tensor.

    Args:
        ctx: the tf2onnx graph being built.
        num_rows: output name of the scalar outermost-dimension size
            (FIRST_DIM_SIZE); cast to int64 if needed.
        nested_row_ids: list of row-id tensor output names, outermost first.
        op_name_scope: scope string used when naming the Concat nodes.

    Returns:
        Tuple (sparse_indices, dense_shape): output names of the
        [num_values, num_dims] coordinate tensor and the 1-D dense shape.
    """
    coords = None
    if ctx.get_dtype(num_rows) != TensorProto.INT64:
        num_rows = ctx.make_node("Cast", [num_rows], attr={'to': TensorProto.INT64}).output[0]
    # Dense shape dims are collected as shape-[1] tensors and concatenated.
    shape_dims = [GraphBuilder(ctx).make_unsqueeze({"data": num_rows, "axes": [0]})]
    for level_ids in nested_row_ids:
        if ctx.get_dtype(level_ids) != TensorProto.INT64:
            level_ids = ctx.make_node("Cast", [level_ids], attr={'to': TensorProto.INT64}).output[0]
        level_cols, level_col_indices = ragged_row_ids_to_sparse_indices(ctx, level_ids)
        shape_dims.append(level_cols)
        if coords is None:
            # Outermost level: the row ids themselves are the row coordinates.
            level_rows = GraphBuilder(ctx).make_unsqueeze({"data": level_ids, "axes": [1]})
        else:
            # Deeper level: each value inherits the full coordinate row
            # already built for its parent entry.
            level_rows = ctx.make_node("Gather", [coords, level_ids]).output[0]
        level_col_indices = GraphBuilder(ctx).make_unsqueeze({"data": level_col_indices, "axes": [1]})
        coords = ctx.make_node("Concat", [level_rows, level_col_indices], attr={'axis': 1},
                               op_name_scope=op_name_scope).output[0]
    dense_shape = ctx.make_node("Concat", shape_dims, attr={'axis': 0},
                                op_name_scope=op_name_scope).output[0]
    return coords, dense_shape
2454
+
2455
+
2415
2456
@tf_op ("RaggedTensorToSparse" )
2416
2457
class RaggedTensorToSparse :
2417
2458
@classmethod
@@ -2432,10 +2473,23 @@ class RaggedTensorToTensor:
2432
2473
def version_11 (cls , ctx , node , ** kwargs ):
2433
2474
shape , values , default_value , * row_partition_tensors = node .input
2434
2475
partition_types = node .get_attr_value ("row_partition_types" )
2435
- error_msg = "Only ROW_SPLITS partition type is supported for RaggedTensorToTensor. types: %r"
2436
- utils .make_sure (all (t == b'ROW_SPLITS' for t in partition_types ), error_msg , partition_types )
2437
- nested_splits = row_partition_tensors
2438
- sparse_indices , dense_shape = ragged_nested_splits_to_sparse_indices (ctx , nested_splits , node .name )
2476
+ layout_type = None
2477
+ if len (partition_types ) >= 2 and partition_types [0 ] == b'FIRST_DIM_SIZE' and \
2478
+ all (t == b'VALUE_ROWIDS' for t in partition_types [1 :]):
2479
+ layout_type = 'VALUE_ROWIDS'
2480
+ elif all (t == b'ROW_SPLITS' for t in partition_types ):
2481
+ layout_type = 'ROW_SPLITS'
2482
+ error_msg = "Only ROW_SPLITS partition and VALUE_ROWIDS types supported for RaggedTensorToTensor. types: %r"
2483
+
2484
+ if layout_type == 'ROW_SPLITS' :
2485
+ nested_splits = row_partition_tensors
2486
+ sparse_indices , dense_shape = ragged_nested_splits_to_sparse_indices (ctx , nested_splits , node .name )
2487
+ else :
2488
+ utils .make_sure (layout_type == 'VALUE_ROWIDS' , error_msg , partition_types )
2489
+ first_dim = row_partition_tensors [0 ]
2490
+ row_ids = row_partition_tensors [1 :]
2491
+ sparse_indices , dense_shape = ragged_nested_row_ids_to_sparse_indices (ctx , first_dim , row_ids , node .name )
2492
+
2439
2493
# A shape of rank 0 means the natural shape should be used.
2440
2494
if ctx .get_rank (shape ) != 0 :
2441
2495
if ctx .get_dtype (shape ) != TensorProto .INT64 :
0 commit comments