@@ -2036,6 +2036,65 @@ def version_11(cls, ctx, node, **kwargs):
2036
2036
ctx .replace_inputs (node , [expand_node .output [0 ], sparse_indices , sparse_vals ])
2037
2037
2038
2038
2039
+ def ragged_lengths_to_sparse_indices (ctx , ragged_lens ):
2040
+ const_zero_int64 = ctx .make_const (utils .make_name ("const_zero" ), np .array (0 , dtype = np .int64 )).output [0 ]
2041
+ num_cols = ctx .make_node ("ReduceMax" , [ragged_lens ], attr = {'axes' : [0 ], 'keeepdims' : True }).output [0 ]
2042
+ num_rows = ctx .make_node ("Shape" , [ragged_lens ]).output [0 ]
2043
+ range_len = ctx .make_node ("Mul" , [num_cols , num_rows ]).output [0 ]
2044
+
2045
+ # ORT seems to have a shape inference bug for the Range node. Use CumSum instead.
2046
+ one_tensor = helper .make_tensor ("value" , TensorProto .INT64 , dims = [1 ], vals = [1 ])
2047
+ ones_of_shape = ctx .make_node ("ConstantOfShape" , [range_len ], attr = {"value" : one_tensor }).output [0 ]
2048
+ range_node = ctx .make_node ("CumSum" , [ones_of_shape , const_zero_int64 ], attr = {'exclusive' : True }).output [0 ]
2049
+ #const_one_int64 = ctx.make_const(utils.make_name("const_one"), np.array(1, dtype=np.int64)).output[0]
2050
+ #range_node = ctx.make_node("Range", [const_zero_int64, range_len, const_one_int64]).output[0]
2051
+
2052
+ col_indices_dense = ctx .make_node ("Mod" , [range_node , num_cols ]).output [0 ]
2053
+ row_indices_dense = ctx .make_node ("Div" , [range_node , num_cols ]).output [0 ]
2054
+ row_lens_dense = ctx .make_node ("Gather" , [ragged_lens , row_indices_dense ]).output [0 ]
2055
+ indices_to_keep = ctx .make_node ("Less" , [col_indices_dense , row_lens_dense ]).output [0 ]
2056
+ col_indices = ctx .make_node ("Compress" , [col_indices_dense , indices_to_keep ]).output [0 ]
2057
+ row_indices = ctx .make_node ("Compress" , [row_indices_dense , indices_to_keep ]).output [0 ]
2058
+ return num_rows , num_cols , row_indices , col_indices
2059
+
2060
+
2061
+ @tf_op ("RaggedTensorToSparse" )
2062
+ class RaggedTensorToSparse :
2063
+ @classmethod
2064
+ def version_11 (cls , ctx , node , ** kwargs ):
2065
+ # https://www.tensorflow.org/guide/ragged_tensor#multiple_ragged_dimensions
2066
+ dense_values = node .inputs [- 1 ]
2067
+ nested_splits = node .inputs [:- 1 ]
2068
+ sparse_indices = None
2069
+ dense_shape_dims = []
2070
+ for split in nested_splits :
2071
+ if ctx .get_dtype (split .output [0 ]) != TensorProto .INT64 :
2072
+ split = ctx .make_node ("Cast" , [split .output [0 ]], attr = {'to' : TensorProto .INT64 })
2073
+ max_int64 = int (utils .get_max_value (np .int64 ))
2074
+ slice1 = GraphBuilder (ctx ).make_slice (
2075
+ {"data" : split .output [0 ], "ends" : [max_int64 ], "starts" : [1 ], "axes" : [0 ]})
2076
+ slice2 = GraphBuilder (ctx ).make_slice (
2077
+ {"data" : split .output [0 ], "ends" : [- 1 ], "starts" : [0 ], "axes" : [0 ]})
2078
+ ragged_lens = ctx .make_node ("Sub" , [slice1 , slice2 ]).output [0 ]
2079
+ num_rows , num_cols , row_indices , col_indices = ragged_lengths_to_sparse_indices (ctx , ragged_lens )
2080
+ if not dense_shape_dims :
2081
+ dense_shape_dims .append (num_rows )
2082
+ dense_shape_dims .append (num_cols )
2083
+ if sparse_indices is None :
2084
+ row_indices = GraphBuilder (ctx ).make_unsqueeze ({"data" : row_indices , "axes" : [1 ]})
2085
+ else :
2086
+ row_indices = ctx .make_node ("Gather" , [sparse_indices , row_indices ]).output [0 ]
2087
+ col_indices = GraphBuilder (ctx ).make_unsqueeze ({"data" : col_indices , "axes" : [1 ]})
2088
+ sparse_indices = ctx .make_node ("Concat" , [row_indices , col_indices ], attr = {'axis' : 1 },
2089
+ op_name_scope = node .name ).output [0 ]
2090
+ dense_shape = ctx .make_node ("Concat" , dense_shape_dims , attr = {'axis' : 0 }, op_name_scope = node .name ).output [0 ]
2091
+
2092
+ ctx .replace_all_inputs (node .output [0 ], sparse_indices )
2093
+ ctx .replace_all_inputs (node .output [1 ], dense_values .output [0 ])
2094
+ ctx .replace_all_inputs (node .output [2 ], dense_shape )
2095
+ ctx .remove_node (node .name )
2096
+
2097
+
2039
2098
@tf_op ("RaggedRange" )
2040
2099
class RaggedRange :
2041
2100
@classmethod
@@ -2076,34 +2135,17 @@ def version_11(cls, ctx, node, **kwargs):
2076
2135
2077
2136
const_zero_list = ctx .make_const (utils .make_name ("const_zero_list" ), np .array ([0 ], dtype = np .int64 )).output [0 ]
2078
2137
2079
- max_row_len = ctx .make_node ("ReduceMax" , [row_lens ], attr = {'axes' : [0 ], 'keeepdims' : False }).output [0 ]
2080
- inp_shape = ctx .make_node ("Shape" , [row_lens ]).output [0 ]
2081
- range_len = ctx .make_node ("Mul" , [max_row_len , inp_shape ]).output [0 ]
2082
-
2083
- # ORT seems to have a shape inference bug for the Range node. Use CumSum instead.
2084
- one_tensor = helper .make_tensor ("value" , TensorProto .INT64 , dims = [1 ], vals = [1 ])
2085
- ones_of_shape = ctx .make_node ("ConstantOfShape" , [range_len ], attr = {"value" : one_tensor }).output [0 ]
2086
- range_node = ctx .make_node ("CumSum" , [ones_of_shape , const_zero_int64 ], attr = {'exclusive' : True }).output [0 ]
2087
- #const_one_int64 = ctx.make_const(utils.make_name("const_one"), np.array(1, dtype=np.int64)).output[0]
2088
- #range_node = ctx.make_node("Range", [const_zero_int64, range_len, const_one_int64]).output[0]
2089
-
2090
- col_indices_dense = ctx .make_node ("Mod" , [range_node , max_row_len ]).output [0 ]
2091
- row_indices_dense = ctx .make_node ("Div" , [range_node , max_row_len ]).output [0 ]
2092
- row_lens_dense = ctx .make_node ("Gather" , [row_lens , row_indices_dense ]).output [0 ]
2093
- indices_to_keep = ctx .make_node ("Less" , [col_indices_dense , row_lens_dense ]).output [0 ]
2094
- col_indices = ctx .make_node ("Compress" , [col_indices_dense , indices_to_keep ]).output [0 ]
2095
- row_indices = ctx .make_node ("Compress" , [row_indices_dense , indices_to_keep ]).output [0 ]
2096
-
2138
+ num_rows , _ , row_indices , col_indices = ragged_lengths_to_sparse_indices (ctx , row_lens )
2097
2139
2098
2140
split_ends = ctx .make_node ("CumSum" , [row_lens , const_zero_int64 ]).output [0 ]
2099
2141
splits_out = ctx .make_node ("Concat" , [const_zero_list , split_ends ], attr = {'axis' : 0 }).output [0 ]
2100
2142
col_indices_cast = ctx .make_node ("Cast" , [col_indices ], attr = {'to' : data_dtype }).output [0 ]
2101
2143
2102
2144
if ctx .get_rank (starts ) != 1 :
2103
- starts = ctx .make_node ("Expand" , [starts , inp_shape ]).output [0 ]
2145
+ starts = ctx .make_node ("Expand" , [starts , num_rows ]).output [0 ]
2104
2146
2105
2147
if ctx .get_rank (deltas ) != 1 :
2106
- deltas = ctx .make_node ("Expand" , [deltas , inp_shape ]).output [0 ]
2148
+ deltas = ctx .make_node ("Expand" , [deltas , num_rows ]).output [0 ]
2107
2149
2108
2150
gather_starts = ctx .make_node ("Gather" , [starts , row_indices ]).output [0 ]
2109
2151
gather_deltas = ctx .make_node ("Gather" , [deltas , row_indices ]).output [0 ]
0 commit comments