Skip to content

Commit b9a5ed1

Browse files
authored
Add SoftmaxCrossEntropyLoss to mixed-precision-transformer. (microsoft#3760)
1 parent 9f72752 commit b9a5ed1

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

orttraining/orttraining/core/graph/mixed_precision_transformer.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ namespace training {
3333
// continue to use 32-bit precision. Others will use reduced precision.
3434
static const std::unordered_set<std::string> FP32_Nodes = {
3535
"SparseSoftmaxCrossEntropy",
36-
"SparseSoftmaxCrossEntropyGrad"};
36+
"SparseSoftmaxCrossEntropyGrad",
37+
"SoftmaxCrossEntropyLoss",
38+
"SoftmaxCrossEntropyLossGrad"};
3739

3840
bool IsFP32Node(const Node* node) {
3941
return FP32_Nodes.find(node->OpType()) != FP32_Nodes.cend();
@@ -54,6 +56,8 @@ static const std::unordered_map<std::string, std::vector<int>> stage2_fp32_node_
5456
{"DropoutGrad", {2}},
5557
{"SparseSoftmaxCrossEntropy", {0, 2}},
5658
{"SparseSoftmaxCrossEntropyGrad", {0, 1, 3}},
59+
{"SoftmaxCrossEntropyLoss", {0, 2}},
60+
{"SoftmaxCrossEntropyLossGrad", {0, 1, 3}},
5761
};
5862

5963
bool IsFP32(const std::unordered_map<std::string, std::vector<int>>& map, std::string opname, int argnum) {

0 commit comments

Comments
 (0)