Skip to content

Commit 84fcbc9

Browse files
committed
refactor code
1 parent 056e3d8 commit 84fcbc9

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

tf2onnx/function/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,17 @@
1111
from tf2onnx.function.matrixbandpart import matrixbandpart_op
1212
from tf2onnx.function.range import range_op7
1313
from tf2onnx.function.select import select_op8
14-
from tf2onnx.function.sparse_softmax_cross_entropy_with_logits import softmax_cross_entropy_with_logits_op
15-
from tf2onnx.function.sparse_softmax_cross_entropy_with_logits import sparse_softmax_cross_entropy_with_logits_op
16-
from tf2onnx.function.sparse_softmax_cross_entropy_with_logits import sparse_softmax_cross_entropy_with_logits_op9
14+
from tf2onnx.function.softmax_cross_entropy_with_logits import softmax_cross_entropy_with_logits_op7
15+
from tf2onnx.function.softmax_cross_entropy_with_logits import sparse_softmax_cross_entropy_with_logits_op7
16+
from tf2onnx.function.softmax_cross_entropy_with_logits import sparse_softmax_cross_entropy_with_logits_op9
1717

1818
__all__ = [
1919
"gathernd_op",
2020
"lstm_block_cell_op",
2121
"matrixbandpart_op",
2222
"range_op7",
2323
"select_op8",
24-
"softmax_cross_entropy_with_logits_op",
25-
"sparse_softmax_cross_entropy_with_logits_op",
24+
"softmax_cross_entropy_with_logits_op7",
25+
"sparse_softmax_cross_entropy_with_logits_op7",
2626
"sparse_softmax_cross_entropy_with_logits_op9",
2727
]

tf2onnx/function/sparse_softmax_cross_entropy_with_logits.py renamed to tf2onnx/function/softmax_cross_entropy_with_logits.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# pylint: disable=unused-argument,missing-docstring
1414

1515

16-
def softmax_cross_entropy_with_logits_computation(ctx, label, logit, tf_ori_node):
16+
def _make_softmax_cross_entropy_with_logits(ctx, label, logit, tf_ori_node):
1717
label_dtype = ctx.get_dtype(label.output[0])
1818
logit_dtype = ctx.get_dtype(logit.output[0])
1919
utils.make_sure(label_dtype == logit_dtype, "the following logic only works on same dtype of label and logit")
@@ -32,18 +32,18 @@ def softmax_cross_entropy_with_logits_computation(ctx, label, logit, tf_ori_node
3232
outputs=[tf_ori_node.output[0]], shapes=[shapes[0]], dtypes=[dtypes[0]])
3333

3434

35-
def softmax_cross_entropy_with_logits_op(ctx, node, name, args):
35+
def softmax_cross_entropy_with_logits_op7(ctx, node, name, args):
3636
logits = node.inputs[0]
3737
logit_dtype = ctx.get_dtype(logits.output[0])
3838
labels = node.inputs[1]
3939
label_dtype = ctx.get_dtype(labels.output[0])
4040
if label_dtype != logit_dtype:
4141
labels = ctx.make_node("Cast", labels.output, attr={"to": logit_dtype}, dtypes=[logit_dtype])
4242

43-
softmax_cross_entropy_with_logits_computation(ctx, labels, logits, node)
43+
_make_softmax_cross_entropy_with_logits(ctx, labels, logits, node)
4444

4545

46-
def sparse_softmax_cross_entropy_with_logits_op(ctx, node, name, args):
46+
def sparse_softmax_cross_entropy_with_logits_op7(ctx, node, name, args):
4747
# make subgraph to implement one_hot, idea comes from onehot_op
4848
indices_name = node.input[1]
4949
indices_shape = ctx.get_shape(indices_name)
@@ -150,4 +150,4 @@ def sparse_softmax_cross_entropy_with_logits_op9(ctx, node, name, args):
150150
if logit_dtype != TensorProto.INT64:
151151
label_node = ctx.make_node("Cast", label_node.output, attr={"to": logit_dtype}, dtypes=[logit_dtype])
152152

153-
softmax_cross_entropy_with_logits_computation(ctx, label_node, logit_node, node)
153+
_make_softmax_cross_entropy_with_logits(ctx, label_node, logit_node, node)

tf2onnx/tfonnx.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1800,7 +1800,6 @@ def where_op(ctx, node, name, args):
18001800
"ExpandDims": (expanddims_op7, []),
18011801
"OneHot": (onehot_op, []),
18021802
"Reshape": (reshape_op5, []),
1803-
"SparseSoftmaxCrossEntropyWithLogits": (sparse_softmax_cross_entropy_with_logits_op, [])
18041803
}
18051804

18061805
_OPSET_6 = {
@@ -1841,7 +1840,8 @@ def where_op(ctx, node, name, args):
18411840
"ResizeNearestNeighbor": (upsample_op7, ["Upsample", "nearest"]),
18421841
"Sin": (direct_op, []),
18431842
"Sub": (broadcast_op7, []),
1844-
"SoftmaxCrossEntropyWithLogits": (softmax_cross_entropy_with_logits_op, []),
1843+
"SoftmaxCrossEntropyWithLogits": (softmax_cross_entropy_with_logits_op7, []),
1844+
"SparseSoftmaxCrossEntropyWithLogits": (sparse_softmax_cross_entropy_with_logits_op7, []),
18451845
"Tan": (direct_op, []),
18461846
"Tile": (tile_op7, []),
18471847
"TruncateDiv": (broadcast_op7, ["Div"]),

0 commit comments

Comments
 (0)