@@ -2465,6 +2465,8 @@ def version_11(cls, ctx, node, **kwargs):
2465
2465
# https://www.tensorflow.org/guide/ragged_tensor#multiple_ragged_dimensions
2466
2466
dense_values = node .input [- 1 ]
2467
2467
nested_splits = node .input [:- 1 ]
2468
+ err_msg2 = "RaggedTensorToSparse conversion only supports tensors with no dense dimensions"
2469
+ utils .make_sure (ctx .get_rank (dense_values ) in [None , 1 ], err_msg2 )
2468
2470
sparse_indices , dense_shape = ragged_nested_splits_to_sparse_indices (ctx , nested_splits , node .name )
2469
2471
ctx .replace_all_inputs (node .output [0 ], sparse_indices )
2470
2472
ctx .replace_all_inputs (node .output [1 ], dense_values )
@@ -2477,6 +2479,7 @@ class RaggedTensorToTensor:
2477
2479
@classmethod
2478
2480
def version_11 (cls , ctx , node , ** kwargs ):
2479
2481
shape , values , default_value , * row_partition_tensors = node .input
2482
+ has_uniform_dims = ctx .get_rank (values ) != 1
2480
2483
partition_types = node .get_attr_value ("row_partition_types" )
2481
2484
layout_type = None
2482
2485
if len (partition_types ) >= 2 and partition_types [0 ] == b'FIRST_DIM_SIZE' and \
@@ -2488,17 +2491,25 @@ def version_11(cls, ctx, node, **kwargs):
2488
2491
2489
2492
if layout_type == 'ROW_SPLITS' :
2490
2493
nested_splits = row_partition_tensors
2494
+ n_dims = len (nested_splits ) + 1
2491
2495
sparse_indices , dense_shape = ragged_nested_splits_to_sparse_indices (ctx , nested_splits , node .name )
2492
2496
else :
2493
2497
utils .make_sure (layout_type == 'VALUE_ROWIDS' , error_msg , partition_types )
2494
2498
first_dim = row_partition_tensors [0 ]
2495
2499
row_ids = row_partition_tensors [1 :]
2500
+ n_dims = len (row_ids ) + 1
2496
2501
sparse_indices , dense_shape = ragged_nested_row_ids_to_sparse_indices (ctx , first_dim , row_ids , node .name )
2497
2502
2498
2503
# A shape of rank 0 means the natural shape should be used.
2499
2504
if ctx .get_rank (shape ) != 0 :
2500
2505
if ctx .get_dtype (shape ) != TensorProto .INT64 :
2501
2506
shape = ctx .make_node ("Cast" , [shape ], attr = {'to' : TensorProto .INT64 }).output [0 ]
2507
+ if has_uniform_dims :
2508
+ const_zero_unsq = ctx .make_const (utils .make_name ("const_zero" ), np .array ([0 ], dtype = np .int64 )).output [0 ]
2509
+ const_n_unsq = ctx .make_const (utils .make_name ("const_num_dims" ),
2510
+ np .array ([n_dims ], dtype = np .int64 )).output [0 ]
2511
+ shape = GraphBuilder (ctx ).make_slice (
2512
+ {'data' : shape , 'starts' : const_zero_unsq , 'ends' : const_n_unsq , 'axes' : [0 ]})
2502
2513
const_zero_int64 = ctx .make_const (utils .make_name ("const_zero" ), np .array (0 , dtype = np .int64 )).output [0 ]
2503
2514
unspec_dims = ctx .make_node ("Less" , [shape , const_zero_int64 ]).output [0 ]
2504
2515
out_shape = ctx .make_node ("Where" , [unspec_dims , dense_shape , shape ]).output [0 ]
@@ -2510,6 +2521,16 @@ def version_11(cls, ctx, node, **kwargs):
2510
2521
values = ctx .make_node ("Compress" , [values , idx_in_bounds ], attr = {'axis' : 0 }).output [0 ]
2511
2522
else :
2512
2523
out_shape = dense_shape
2524
+
2525
+ if has_uniform_dims :
2526
+ values_shape = ctx .make_node ("Shape" , [values ]).output [0 ]
2527
+ const_one_unsq = ctx .make_const (utils .make_name ("const_one" ), np .array ([1 ], dtype = np .int64 )).output [0 ]
2528
+ max_int64 = np .array ([utils .get_max_value (np .int64 )], dtype = np .int64 )
2529
+ const_max_val_unsq = ctx .make_const (utils .make_name ("max_int" ), max_int64 ).output [0 ]
2530
+ uniform_dims = GraphBuilder (ctx ).make_slice (
2531
+ {'data' : values_shape , 'starts' : const_one_unsq , 'ends' : const_max_val_unsq , 'axes' : [0 ]})
2532
+ out_shape = ctx .make_node ("Concat" , [out_shape , uniform_dims ], attr = {'axis' : 0 }).output [0 ]
2533
+
2513
2534
expand_node = ctx .make_node ("Expand" , [default_value , out_shape ])
2514
2535
node .type = "ScatterND"
2515
2536
ctx .replace_inputs (node , [expand_node .output [0 ], sparse_indices , values ])
0 commit comments