
Commit e93e48e

[AMP] fix bf16 amp training error (#54571) (#54643)
Fix bf16 AMP training error; cherry-pick of #54571.
1 parent 6b778b9 commit e93e48e

File tree: 3 files changed, +34 −41 lines

paddle/fluid/eager/amp_auto_cast.h
paddle/fluid/eager/amp_utils.h
paddle/fluid/eager/eager_amp_auto_cast.h

paddle/fluid/eager/amp_auto_cast.h

Lines changed: 7 additions & 5 deletions
@@ -69,15 +69,16 @@ inline paddle::Tensor AmpAutoCast(const std::string& input_name,
   VLOG(6) << "AMP AmpAutoCasts:"
           << " input(" << input_name << ") dst_dtype("
           << phi::DataTypeToString(dst_dtype) << ").";
+
+  if ((op_name == "batch_norm" || op_name == "layer_norm" ||
+       op_name == "sync_batch_norm") &&
+      input_name != "X") {
+    return input;
+  }
   if (dst_dtype == phi::DataType::FLOAT16) {
     if (op_name == "run_program") {
       return input;
     }
-    if ((op_name == "batch_norm" || op_name == "layer_norm" ||
-         op_name == "sync_batch_norm") &&
-        input_name != "X") {
-      return input;
-    }
     if ((op_name == "fused_attention" || op_name == "fused_feedforward")) {
       if (input_name == "LnScale" || input_name == "LnBias" ||
           input_name == "Ln2Scale" || input_name == "Ln2Bias" ||
@@ -86,6 +87,7 @@ inline paddle::Tensor AmpAutoCast(const std::string& input_name,
       }
     }
   }
+
   if (NeedCast(input, dst_dtype)) {
     paddle::framework::AttributeMap cast_attrs = {
         {"in_dtype", paddle::framework::TransToProtoVarType(input.dtype())},

paddle/fluid/eager/amp_utils.h

Lines changed: 21 additions & 31 deletions
@@ -26,19 +26,24 @@ static inline phi::DataType GetPromoteType(
                                kSlotSmallVectorSize>& amp_tensors_vector,
     const phi::DataType& amp_dtype) {
   auto dst_type = amp_dtype;
+  // only consider the dtype of input(X).
+  if (op_name == "batch_norm" || op_name == "layer_norm" ||
+      op_name == "sync_batch_norm" ||
+      op_name == "moving_average_abs_max_scale") {
+    if (amp_tensors_vector[0][0].dtype() == phi::DataType::FLOAT32) {
+      dst_type = phi::DataType::FLOAT32;
+    }
+    return dst_type;
+  }
+
   if (egr::Controller::Instance().GetCurrentTracer()->GetAmpDtype() ==
       "float16") {
-    if (op_name == "batch_norm" || op_name == "layer_norm" ||
-        op_name == "sync_batch_norm") {
-      if (amp_tensors_vector[0][0].dtype() == phi::DataType::FLOAT32) {
-        dst_type = phi::DataType::FLOAT32;
-      }
-    } else if (op_name == "fused_attention") {
+    if (op_name == "fused_attention") {
       for (size_t i = 0; i < amp_tensors_vector.size(); i++) {
         if (i != 3 || i != 4 || i != 9 || i != 10) {
           if (amp_tensors_vector[i][0].dtype() == phi::DataType::FLOAT32) {
             dst_type = phi::DataType::FLOAT32;
-            break;
+            return dst_type;
           }
         }
       }
@@ -47,37 +52,22 @@ static inline phi::DataType GetPromoteType(
         if (i != 7 || i != 8 || i != 9 || i != 10) {
           if (amp_tensors_vector[i][0].dtype() == phi::DataType::FLOAT32) {
             dst_type = phi::DataType::FLOAT32;
-            break;
-          }
-        }
-      }
-    } else {
-      for (const auto& tensors : amp_tensors_vector) {
-        for (const auto& tensor : tensors) {
-          if (tensor.dtype() == phi::DataType::FLOAT32) {
-            dst_type = tensor.dtype();
-            break;
+            return dst_type;
           }
         }
       }
     }
-  } else {
-    for (const auto& tensors : amp_tensors_vector) {
-      for (const auto& tensor : tensors) {
-        if (tensor.dtype() == phi::DataType::FLOAT32) {
-          dst_type = tensor.dtype();
-          break;
-        }
-      }
-    }
   }
-  // NOTE(juncai): moving_average_abs_max_scale only consider the dtype of
-  // input(X)
-  if (op_name == "moving_average_abs_max_scale") {
-    if (amp_tensors_vector[0][0].dtype() == phi::DataType::FLOAT16) {
-      dst_type = phi::DataType::FLOAT16;
+
+  for (const auto& tensors : amp_tensors_vector) {
+    for (const auto& tensor : tensors) {
+      if (tensor.dtype() == phi::DataType::FLOAT32) {
+        dst_type = tensor.dtype();
+        break;
+      }
     }
   }
+
   return dst_type;
 }
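As a rough illustration of the consolidated promotion logic, here is a minimal sketch (the DType enum and the function below are hypothetical, and the fused_attention/fused_feedforward special cases are omitted; this is not Paddle's actual code): the norm ops and moving_average_abs_max_scale now decide the destination dtype from input(X) alone, for both float16 and bfloat16 AMP, while all other ops promote to FLOAT32 as soon as any input tensor is FLOAT32.

// promote_type_sketch.cc -- illustrative only, not Paddle source.
#include <string>
#include <vector>

enum class DType { FLOAT32, FLOAT16, BFLOAT16 };

// Hypothetical simplification of GetPromoteType after this change
// (fused_attention / fused_feedforward handling omitted).
DType GetPromoteTypeSketch(const std::string& op_name,
                           const std::vector<std::vector<DType>>& amp_tensors_vector,
                           DType amp_dtype) {
  // Norm-style ops only consider the dtype of input(X), i.e. the first tensor.
  if (op_name == "batch_norm" || op_name == "layer_norm" ||
      op_name == "sync_batch_norm" ||
      op_name == "moving_average_abs_max_scale") {
    return amp_tensors_vector[0][0] == DType::FLOAT32 ? DType::FLOAT32 : amp_dtype;
  }
  // All other ops: any FLOAT32 input promotes the op to FLOAT32.
  for (const auto& slot : amp_tensors_vector) {
    for (DType t : slot) {
      if (t == DType::FLOAT32) return DType::FLOAT32;
    }
  }
  return amp_dtype;
}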

paddle/fluid/eager/eager_amp_auto_cast.h

Lines changed: 6 additions & 5 deletions
@@ -89,15 +89,16 @@ inline paddle::Tensor EagerAmpAutoCast(const std::string& input_name,
   VLOG(6) << "AMP AmpAutoCasts:"
           << " input(" << egr::EagerUtils::TensorStr(input) << " to dst_dtype("
           << phi::DataTypeToString(dst_dtype) << ").";
+  if ((op_name == "batch_norm" || op_name == "layer_norm" ||
+       op_name == "sync_batch_norm") &&
+      input_name != "x") {
+    return input;
+  }
+
   if (dst_dtype == phi::DataType::FLOAT16) {
     if (op_name == "run_program") {
       return input;
     }
-    if ((op_name == "batch_norm" || op_name == "layer_norm" ||
-         op_name == "sync_batch_norm") &&
-        input_name != "x") {
-      return input;
-    }
     if ((op_name == "fused_attention" || op_name == "fused_feedforward")) {
       if (input_name == "LnScale" || input_name == "LnBias" ||
           input_name == "Ln2Scale" || input_name == "Ln2Bias" ||
