@@ -2034,6 +2034,86 @@ def version_11(cls, ctx, node, **kwargs):
         ctx.replace_inputs(node, [expand_node.output[0], sparse_indices, sparse_vals])
 
 
+@tf_op("RaggedRange")
+class RaggedRange:
+    @classmethod
+    def version_11(cls, ctx, node, **kwargs):
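+        # tf.ragged.range(starts, limits, deltas) produces a RaggedTensor; this
+        # op's two outputs are its row splits and flat values, e.g.
+        # starts=[0, 5], limits=[4, 6], deltas=[1, 1] ->
+        # splits=[0, 4, 5], values=[0, 1, 2, 3, 5].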
+        starts, limits, deltas = node.input
+        data_dtype = ctx.get_dtype(starts)
+        data_np_dtype = utils.map_onnx_to_numpy_type(data_dtype)
+        data_is_float = np.dtype(data_np_dtype).kind == 'f'
+
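+        # Per-row lengths: row_lens[i] = ceil((limits[i] - starts[i]) / deltas[i]),
+        # matching tf.ragged.range semantics.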
+        if data_is_float:
+            sub_node = ctx.make_node("Sub", [limits, starts]).output[0]
+            div_node = ctx.make_node("Div", [sub_node, deltas]).output[0]
+            ceil_node = ctx.make_node("Ceil", [div_node]).output[0]
+            row_lens = ctx.make_node("Cast", [ceil_node], attr={'to': TensorProto.INT64}).output[0]
+        else:
+            # Integer ceil(a/b): truncating Div, then add 1 wherever the
+            # division left a remainder.
+            starts_cast = ctx.make_node("Cast", [starts], attr={'to': TensorProto.INT64}).output[0]
+            limits_cast = ctx.make_node("Cast", [limits], attr={'to': TensorProto.INT64}).output[0]
+            deltas_cast = ctx.make_node("Cast", [deltas], attr={'to': TensorProto.INT64}).output[0]
+            sub_node = ctx.make_node("Sub", [limits_cast, starts_cast]).output[0]
+            div_node = ctx.make_node("Div", [sub_node, deltas_cast]).output[0]
+            mul_node = ctx.make_node("Mul", [div_node, deltas_cast]).output[0]
+            eq_node = ctx.make_node("Equal", [mul_node, sub_node]).output[0]
+            ne_node = ctx.make_node("Not", [eq_node]).output[0]
+            # Round up when the division wasn't exact.
+            offset = ctx.make_node("Cast", [ne_node], attr={'to': TensorProto.INT64}).output[0]
+            row_lens = ctx.make_node("Add", [div_node, offset]).output[0]
+
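+        # A row can compute a negative length (e.g. limits < starts with a
+        # positive delta); clamp to zero so such rows come out empty.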
+        const_zero_int64 = ctx.make_const(utils.make_name("const_zero"), np.array(0, dtype=np.int64)).output[0]
+        if ctx.opset <= 11:
+            # Max only supports float types before opset 12, so clamp in double.
+            const_zero_double = ctx.make_const(utils.make_name("const_zero"), np.array(0, dtype=np.float64)).output[0]
+            row_lens = ctx.make_node("Cast", [row_lens], attr={'to': TensorProto.DOUBLE}).output[0]
+            row_lens = ctx.make_node("Max", [row_lens, const_zero_double]).output[0]
+            row_lens = ctx.make_node("Cast", [row_lens], attr={'to': TensorProto.INT64}).output[0]
+        else:
+            row_lens = ctx.make_node("Max", [row_lens, const_zero_int64]).output[0]
+
+        const_zero_list = ctx.make_const(utils.make_name("const_zero_list"), np.array([0], dtype=np.int64)).output[0]
+
+        max_row_len = ctx.make_node("ReduceMax", [row_lens], attr={'axes': [0], 'keepdims': False}).output[0]
+        inp_shape = ctx.make_node("Shape", [row_lens]).output[0]
+        range_len = ctx.make_node("Mul", [max_row_len, inp_shape]).output[0]
+
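+        # Enumerate a flattened dense grid of shape [num_rows, max_row_len];
+        # positions that fall outside their ragged row are filtered out below.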
+        # ORT seems to have a shape inference bug for the Range node. Use CumSum instead:
+        # an exclusive CumSum over a tensor of ones yields [0, 1, ..., range_len - 1].
+        one_tensor = helper.make_tensor("value", TensorProto.INT64, dims=[1], vals=[1])
+        ones_of_shape = ctx.make_node("ConstantOfShape", [range_len], attr={"value": one_tensor}).output[0]
+        range_node = ctx.make_node("CumSum", [ones_of_shape, const_zero_int64], attr={'exclusive': True}).output[0]
+        # Equivalent Range-based version, kept for when the ORT bug is fixed:
+        # const_one_int64 = ctx.make_const(utils.make_name("const_one"), np.array(1, dtype=np.int64)).output[0]
+        # range_node = ctx.make_node("Range", [const_zero_int64, range_len, const_one_int64]).output[0]
+
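+        # Position k in the flattened grid maps to row k // max_row_len and
+        # column k % max_row_len; keep it only if the column is inside the row.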
+        col_indices_dense = ctx.make_node("Mod", [range_node, max_row_len]).output[0]
+        row_indices_dense = ctx.make_node("Div", [range_node, max_row_len]).output[0]
+        row_lens_dense = ctx.make_node("Gather", [row_lens, row_indices_dense]).output[0]
+        indices_to_keep = ctx.make_node("Less", [col_indices_dense, row_lens_dense]).output[0]
+        col_indices = ctx.make_node("Compress", [col_indices_dense, indices_to_keep]).output[0]
+        row_indices = ctx.make_node("Compress", [row_indices_dense, indices_to_keep]).output[0]
+
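+        # Row splits are the cumulative row lengths with a leading zero:
+        # [0, len0, len0 + len1, ...].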
+        split_ends = ctx.make_node("CumSum", [row_lens, const_zero_int64]).output[0]
+        splits_out = ctx.make_node("Concat", [const_zero_list, split_ends], attr={'axis': 0}).output[0]
+        col_indices_cast = ctx.make_node("Cast", [col_indices], attr={'to': data_dtype}).output[0]
+
+        # Broadcast scalar starts/deltas so there is one entry per row.
+        if ctx.get_rank(starts) != 1:
+            starts = ctx.make_node("Expand", [starts, inp_shape]).output[0]
+        if ctx.get_rank(deltas) != 1:
+            deltas = ctx.make_node("Expand", [deltas, inp_shape]).output[0]
+
+        gather_starts = ctx.make_node("Gather", [starts, row_indices]).output[0]
+        gather_deltas = ctx.make_node("Gather", [deltas, row_indices]).output[0]
+
+        # values[k] = starts[row_k] + col_k * deltas[row_k]
+        mul_node = ctx.make_node("Mul", [col_indices_cast, gather_deltas], op_name_scope=node.name).output[0]
+        dense_vals_out = ctx.make_node("Add", [gather_starts, mul_node], op_name_scope=node.name).output[0]
+
+        ctx.replace_all_inputs(node.output[0], splits_out)
+        ctx.replace_all_inputs(node.output[1], dense_vals_out)
+        ctx.remove_node(node.name)
+
+
 @tf_op("SparseReshape")
 class SparseReshape:
     @classmethod