@@ -1313,8 +1313,11 @@ def version_1(cls, ctx, node, **kwargs):
1313
1313
if axis .i == 0 :
1314
1314
# TODO: revisit for rank > 1
1315
1315
name = utils .make_name (node .name )
1316
- transpose_node = ctx .insert_new_node_on_output ("Transpose" , node .output [0 ], name )
1317
- ctx .copy_shape (node .output [0 ], transpose_node .output [0 ])
1316
+ shape = ctx .get_shape (node .output [0 ])
1317
+ transpose_node = ctx .make_node ("Transpose" , [node .output [0 ]], name = name , shapes = [shape ])
1318
+ ctx .insert_node_on_output (transpose_node , node .output [0 ])
1319
+ if shape is not None :
1320
+ ctx .set_shape (node .output [0 ], shape [::- 1 ])
1318
1321
1319
1322
@classmethod
1320
1323
def any_version_after9 (cls , opset , ctx , node , ** kwargs ):
@@ -1323,9 +1326,11 @@ def any_version_after9(cls, opset, ctx, node, **kwargs):
1323
1326
# in ONNX, op's schema is (input, depth, value, @int axis), meaning of "value" is [off-value, on-value]
1324
1327
# onnxruntime only supports int64
1325
1328
output_dtype = ctx .get_dtype (node .input [2 ])
1326
- if ctx .is_target (constants .TARGET_RS6 ) \
1327
- and output_dtype not in [onnx_pb .TensorProto .INT64 , onnx_pb .TensorProto .INT32 ]:
1328
- logger .warning ("unsupported dtype in onnxruntime, onehot-9 can't be used directly" )
1329
+ supported_dtypes = [onnx_pb .TensorProto .FLOAT ]
1330
+ if ctx .is_target (constants .TARGET_RS6 ):
1331
+ supported_dtypes = [onnx_pb .TensorProto .INT64 , onnx_pb .TensorProto .INT32 ]
1332
+ if output_dtype not in supported_dtypes :
1333
+ logger .warning ("unsupported dtype in target runtime, OneHot op can't be used directly" )
1329
1334
cls .version_1 (ctx , node , ** kwargs )
1330
1335
return
1331
1336
0 commit comments