@@ -2116,43 +2116,77 @@ def ragged_lengths_to_sparse_indices(ctx, ragged_lens):
2116
2116
return num_rows , num_cols , row_indices , col_indices
2117
2117
2118
2118
2119
def ragged_nested_splits_to_sparse_indices(ctx, nested_splits, op_name_scope):
    """Build SparseTensor-style indices from nested ragged row-splits.

    Each entry of nested_splits is a 1-D row-splits tensor describing one
    ragged dimension (outermost first).  For every level the row lengths are
    recovered as the first differences of the splits, converted to per-element
    (row, col) pairs, and chained onto the index tuples of the previous level
    via Gather, yielding one index row per flat value.

    Args:
        ctx: the tf2onnx graph context used to emit ONNX nodes.
        nested_splits: sequence of row-splits tensor names, outermost first.
        op_name_scope: name scope applied to the emitted Concat nodes.

    Returns:
        Tuple (sparse_indices, dense_shape) of ONNX tensor names: the N-D
        index matrix and the 1-D dense shape of the equivalent dense tensor.
    """
    sparse_indices = None
    dense_shape_dims = []
    # Loop-invariant: sentinel "end" for an open-ended slice.
    int64_max = int(utils.get_max_value(np.int64))
    for splits in nested_splits:
        if ctx.get_dtype(splits) != TensorProto.INT64:
            splits = ctx.make_node("Cast", [splits], attr={'to': TensorProto.INT64}).output[0]
        # Row lengths are the first differences: splits[1:] - splits[:-1].
        upper = GraphBuilder(ctx).make_slice(
            {"data": splits, "ends": [int64_max], "starts": [1], "axes": [0]})
        lower = GraphBuilder(ctx).make_slice(
            {"data": splits, "ends": [-1], "starts": [0], "axes": [0]})
        ragged_lens = ctx.make_node("Sub", [upper, lower]).output[0]
        num_rows, num_cols, row_indices, col_indices = ragged_lengths_to_sparse_indices(ctx, ragged_lens)
        if not dense_shape_dims:
            # Only the outermost level contributes the leading row count.
            dense_shape_dims.append(num_rows)
        dense_shape_dims.append(num_cols)
        if sparse_indices is None:
            # First level: row indices become the initial index column.
            row_indices = GraphBuilder(ctx).make_unsqueeze({"data": row_indices, "axes": [1]})
        else:
            # Deeper level: look up the full index tuple built so far for each row.
            row_indices = ctx.make_node("Gather", [sparse_indices, row_indices]).output[0]
        col_indices = GraphBuilder(ctx).make_unsqueeze({"data": col_indices, "axes": [1]})
        sparse_indices = ctx.make_node("Concat", [row_indices, col_indices], attr={'axis': 1},
                                       op_name_scope=op_name_scope).output[0]
    dense_shape = ctx.make_node("Concat", dense_shape_dims, attr={'axis': 0}, op_name_scope=op_name_scope).output[0]
    return sparse_indices, dense_shape
2119
2146
@tf_op("RaggedTensorToSparse")
class RaggedTensorToSparse:
    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        """Rewrite tf.RaggedTensorToSparse into pure index arithmetic.

        https://www.tensorflow.org/guide/ragged_tensor#multiple_ragged_dimensions
        """
        # Inputs are [splits_0, ..., splits_k, flat_values].
        *nested_splits, dense_values = node.input
        indices, shape = ragged_nested_splits_to_sparse_indices(ctx, nested_splits, node.name)
        # Rewire the three SparseTensor outputs (indices, values, dense_shape)
        # to the computed tensors, then drop the original op.
        ctx.replace_all_inputs(node.output[0], indices)
        ctx.replace_all_inputs(node.output[1], dense_values)
        ctx.replace_all_inputs(node.output[2], shape)
        ctx.remove_node(node.name)
2160
@tf_op("RaggedTensorToTensor")
class RaggedTensorToTensor:
    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        """Convert tf.RaggedTensorToTensor to an ONNX ScatterND.

        The ragged values are scattered into a tensor pre-filled with
        default_value; indices falling outside an explicitly requested
        shape are filtered out first.
        """
        shape, values, default_value, *row_partition_tensors = node.input
        partition_types = node.get_attr_value("row_partition_types")
        error_msg = "Only ROW_SPLITS partition type is supported for RaggedTensorToTensor. types: %r"
        utils.make_sure(all(t == b'ROW_SPLITS' for t in partition_types), error_msg, partition_types)
        nested_splits = row_partition_tensors
        sparse_indices, dense_shape = ragged_nested_splits_to_sparse_indices(ctx, nested_splits, node.name)
        # A shape of rank 0 means the natural shape should be used.
        if ctx.get_rank(shape) != 0:
            if ctx.get_dtype(shape) != TensorProto.INT64:
                shape = ctx.make_node("Cast", [shape], attr={'to': TensorProto.INT64}).output[0]
            const_zero_int64 = ctx.make_const(utils.make_name("const_zero"), np.array(0, dtype=np.int64)).output[0]
            # Negative entries in the requested shape mean "unspecified":
            # fall back to the natural dense_shape for those dimensions.
            unspec_dims = ctx.make_node("Less", [shape, const_zero_int64]).output[0]
            out_shape = ctx.make_node("Where", [unspec_dims, dense_shape, shape]).output[0]
            # Keep only indices strictly inside out_shape: a row is in bounds
            # iff min over its dims of (out_shape - index) is > 0.
            out_shape_unsq = GraphBuilder(ctx).make_unsqueeze({'data': out_shape, 'axes': [0]})
            amt_idx_in_bounds = ctx.make_node("Sub", [out_shape_unsq, sparse_indices]).output[0]
            amt_in_bounds_flat = ctx.make_node("ReduceMin", [amt_idx_in_bounds], attr={'axes': [1], 'keepdims': False})
            idx_in_bounds = ctx.make_node("Greater", [amt_in_bounds_flat.output[0], const_zero_int64]).output[0]
            # Drop out-of-bounds index rows and their matching values together.
            sparse_indices = ctx.make_node("Compress", [sparse_indices, idx_in_bounds], attr={'axis': 0}).output[0]
            values = ctx.make_node("Compress", [values, idx_in_bounds], attr={'axis': 0}).output[0]
        else:
            out_shape = dense_shape
        # Broadcast default_value to the output shape, then scatter the ragged
        # values into it by reusing this node as the ScatterND.
        expand_node = ctx.make_node("Expand", [default_value, out_shape])
        node.type = "ScatterND"
        ctx.replace_inputs(node, [expand_node.output[0], sparse_indices, values])
2156
2190
@tf_op ("RaggedRange" )
2157
2191
class RaggedRange :
2158
2192
@classmethod
0 commit comments