@@ -729,20 +729,25 @@ def version_9(cls, ctx, node, **kwargs):
729
729
logit_dtype = ctx .get_dtype (node .input [0 ])
730
730
731
731
label_name = node .input [1 ]
732
- label_dtype = ctx .get_dtype (label_name )
733
732
734
- num_class = logit_shape [- 1 ]
735
- utils .make_sure (num_class != - 1 ,
736
- "number of class should be known, otherwise subgraph to get the info is needed" )
737
- # int64 is used because of onnxruntime "onehot" only supports this dtype
738
- depth_node = ctx .make_const (utils .make_name ("onehot_depth" ), np .array ([num_class ]).astype (np .int64 ))
739
- values_node = ctx .make_const (utils .make_name ("onehot_values" ), np .array ([0 , 1 ]).astype (np .int64 ))
733
+ if logit_shape is not None and logit_shape [- 1 ] != - 1 :
734
+ num_class = logit_shape [- 1 ]
735
+ node_nme = utils .make_name ("onehot_depth" )
736
+ depth_node = ctx .make_const (node_nme , np .array ([num_class ]).astype (np .int64 )).output [0 ]
737
+ else :
738
+ logit_shape = ctx .make_node ("Shape" , [node .input [0 ]]).output [0 ]
739
+ slice_args = {"data" : logit_shape ,
740
+ "starts" : [- 1 ], "ends" : [int (utils .get_max_value (np .int32 ))]}
741
+ num_class = GraphBuilder (ctx ).make_slice (kwargs = slice_args )
742
+ depth_node = num_class
743
+ values_node = ctx .make_const (utils .make_name ("onehot_values" ), np .array ([0 , 1 ]).astype (np .int64 )).output [0 ]
744
+ label_dtype = ctx .get_dtype (label_name )
740
745
if label_dtype != TensorProto .INT64 :
741
746
onehot_indice = ctx .make_node ("Cast" , [label_name ], attr = {"to" : TensorProto .INT64 }).output [0 ]
742
747
else :
743
748
onehot_indice = label_name
744
749
label_node = ctx .make_node (op_type = "OneHot" ,
745
- inputs = [onehot_indice , depth_node . output [ 0 ] , values_node . output [ 0 ] ])
750
+ inputs = [onehot_indice , depth_node , values_node ])
746
751
# the above logic makes output dtype of label_node now always int64
747
752
# make sure label has same dtype as logit
748
753
if logit_dtype != TensorProto .INT64 :
0 commit comments