@@ -1960,9 +1960,15 @@ def version_11(cls, ctx, node, **kwargs):
1960
1960
node_inputs = node .input
1961
1961
node_outputs = node .output
1962
1962
ctx .remove_node (node_name )
1963
+ if dtypes [0 ] in [TensorProto .INT32 , TensorProto .INT16 , TensorProto .UINT8 , TensorProto .UINT16 ]:
1964
+ inp_cast = ctx .make_node ("Cast" , [node_inputs [0 ]], attr = {'to' : TensorProto .INT64 }).output [0 ]
1965
+ node_inputs [0 ] = inp_cast
1963
1966
new_node = ctx .make_node ("Unique" , node_inputs , name = node_name , output_count = 3 , attr = {'sorted' : 0 })
1964
1967
ctx .replace_all_inputs (node_outputs [0 ], new_node .output [0 ])
1965
1968
ctx .replace_all_inputs (node_outputs [1 ], new_node .output [2 ])
1969
+ if ctx .get_dtype (new_node .output [0 ]) != dtypes [0 ]:
1970
+ ctx .insert_new_node_on_output ("Cast" , new_node .output [0 ], name = utils .make_name (node .name ) + "_cast" ,
1971
+ to = dtypes [0 ])
1966
1972
if len (node_outputs ) > 1 :
1967
1973
# cast to int64 if needed
1968
1974
if dtypes [1 ] != onnx_pb .TensorProto .INT64 :
@@ -2064,18 +2070,18 @@ class RaggedTensorToSparse:
2064
2070
@classmethod
2065
2071
def version_11 (cls , ctx , node , ** kwargs ):
2066
2072
# https://www.tensorflow.org/guide/ragged_tensor#multiple_ragged_dimensions
2067
- dense_values = node .inputs [- 1 ]
2068
- nested_splits = node .inputs [:- 1 ]
2073
+ dense_values = node .input [- 1 ]
2074
+ nested_splits = node .input [:- 1 ]
2069
2075
sparse_indices = None
2070
2076
dense_shape_dims = []
2071
2077
for split in nested_splits :
2072
- if ctx .get_dtype (split . output [ 0 ] ) != TensorProto .INT64 :
2073
- split = ctx .make_node ("Cast" , [split . output [ 0 ]] , attr = {'to' : TensorProto .INT64 })
2078
+ if ctx .get_dtype (split ) != TensorProto .INT64 :
2079
+ split = ctx .make_node ("Cast" , [split ] , attr = {'to' : TensorProto .INT64 }). output [ 0 ]
2074
2080
max_int64 = int (utils .get_max_value (np .int64 ))
2075
2081
slice1 = GraphBuilder (ctx ).make_slice (
2076
- {"data" : split . output [ 0 ] , "ends" : [max_int64 ], "starts" : [1 ], "axes" : [0 ]})
2082
+ {"data" : split , "ends" : [max_int64 ], "starts" : [1 ], "axes" : [0 ]})
2077
2083
slice2 = GraphBuilder (ctx ).make_slice (
2078
- {"data" : split . output [ 0 ] , "ends" : [- 1 ], "starts" : [0 ], "axes" : [0 ]})
2084
+ {"data" : split , "ends" : [- 1 ], "starts" : [0 ], "axes" : [0 ]})
2079
2085
ragged_lens = ctx .make_node ("Sub" , [slice1 , slice2 ]).output [0 ]
2080
2086
num_rows , num_cols , row_indices , col_indices = ragged_lengths_to_sparse_indices (ctx , ragged_lens )
2081
2087
if not dense_shape_dims :
@@ -2091,7 +2097,7 @@ def version_11(cls, ctx, node, **kwargs):
2091
2097
dense_shape = ctx .make_node ("Concat" , dense_shape_dims , attr = {'axis' : 0 }, op_name_scope = node .name ).output [0 ]
2092
2098
2093
2099
ctx .replace_all_inputs (node .output [0 ], sparse_indices )
2094
- ctx .replace_all_inputs (node .output [1 ], dense_values . output [ 0 ] )
2100
+ ctx .replace_all_inputs (node .output [1 ], dense_values )
2095
2101
ctx .replace_all_inputs (node .output [2 ], dense_shape )
2096
2102
ctx .remove_node (node .name )
2097
2103
0 commit comments