@@ -1812,11 +1812,20 @@ def version_9(cls, ctx, node, **kwargs):
1812
1812
onehot_indice = ctx .make_node ("Cast" , [label_name ], attr = {"to" : TensorProto .INT64 }).output [0 ]
1813
1813
else :
1814
1814
onehot_indice = label_name
1815
- label_node = ctx .make_node (op_type = "OneHot" ,
1816
- inputs = [onehot_indice , depth_node , values_node ])
1815
+ if ctx .opset < 11 :
1816
+ label_node = ctx .make_node (op_type = "OneHot" ,
1817
+ inputs = [onehot_indice , depth_node , values_node ])
1818
+ else :
1819
+ # OneHot is very slow; this faster Range/Equal workaround requires opset 11
1820
+ index_unsq = GraphBuilder (ctx ).make_unsqueeze ({'data' : onehot_indice , 'axes' : [- 1 ]})
1821
+ depth_sq = GraphBuilder (ctx ).make_squeeze ({'data' : depth_node , 'axes' : [0 ]})
1822
+ zero_const = ctx .make_const (utils .make_name ("const_zero" ), np .array (0 , np .int64 )).output [0 ]
1823
+ one_const = ctx .make_const (utils .make_name ("const_one" ), np .array (1 , np .int64 )).output [0 ]
1824
+ dp_range = ctx .make_node ("Range" , [zero_const , depth_sq , one_const ]).output [0 ]
1825
+ label_node = ctx .make_node ("Equal" , [index_unsq , dp_range ])
1817
1826
# the above logic leaves label_node as int64 (OneHot path) or bool (Equal path)
1818
1827
# make sure label has same dtype as logit
1819
- if logit_dtype != TensorProto . INT64 :
1828
+ if logit_dtype != ctx . get_dtype ( label_node . output [ 0 ]) :
1820
1829
label_node = ctx .make_node ("Cast" , label_node .output , attr = {"to" : logit_dtype }, dtypes = [logit_dtype ])
1821
1830
1822
1831
_make_sparse_softmax_cross_entropy_with_logits (ctx , label_node , logit_node , node )
0 commit comments