13
13
# pylint: disable=unused-argument,missing-docstring
14
14
15
15
16
- def softmax_cross_entropy_with_logits_computation (ctx , label , logit , tf_ori_node ):
16
+ def _make_softmax_cross_entropy_with_logits (ctx , label , logit , tf_ori_node ):
17
17
label_dtype = ctx .get_dtype (label .output [0 ])
18
18
logit_dtype = ctx .get_dtype (logit .output [0 ])
19
19
utils .make_sure (label_dtype == logit_dtype , "the following logic only works on same dtype of label and logit" )
@@ -32,18 +32,18 @@ def softmax_cross_entropy_with_logits_computation(ctx, label, logit, tf_ori_node
32
32
outputs = [tf_ori_node .output [0 ]], shapes = [shapes [0 ]], dtypes = [dtypes [0 ]])
33
33
34
34
35
- def softmax_cross_entropy_with_logits_op (ctx , node , name , args ):
35
+ def softmax_cross_entropy_with_logits_op7 (ctx , node , name , args ):
36
36
logits = node .inputs [0 ]
37
37
logit_dtype = ctx .get_dtype (logits .output [0 ])
38
38
labels = node .inputs [1 ]
39
39
label_dtype = ctx .get_dtype (labels .output [0 ])
40
40
if label_dtype != logit_dtype :
41
41
labels = ctx .make_node ("Cast" , labels .output , attr = {"to" : logit_dtype }, dtypes = [logit_dtype ])
42
42
43
- softmax_cross_entropy_with_logits_computation (ctx , labels , logits , node )
43
+ _make_softmax_cross_entropy_with_logits (ctx , labels , logits , node )
44
44
45
45
46
- def sparse_softmax_cross_entropy_with_logits_op (ctx , node , name , args ):
46
+ def sparse_softmax_cross_entropy_with_logits_op7 (ctx , node , name , args ):
47
47
# make subgraph to implement one_hot, idea comes from onehot_op
48
48
indices_name = node .input [1 ]
49
49
indices_shape = ctx .get_shape (indices_name )
@@ -150,4 +150,4 @@ def sparse_softmax_cross_entropy_with_logits_op9(ctx, node, name, args):
150
150
if logit_dtype != TensorProto .INT64 :
151
151
label_node = ctx .make_node ("Cast" , label_node .output , attr = {"to" : logit_dtype }, dtypes = [logit_dtype ])
152
152
153
- softmax_cross_entropy_with_logits_computation (ctx , label_node , logit_node , node )
153
+ _make_softmax_cross_entropy_with_logits (ctx , label_node , logit_node , node )
0 commit comments