@@ -26,19 +26,24 @@ static inline phi::DataType GetPromoteType(
                                kSlotSmallVectorSize>& amp_tensors_vector,
     const phi::DataType& amp_dtype) {
   auto dst_type = amp_dtype;
+  // only consider the dtype of input(X).
+  if (op_name == "batch_norm" || op_name == "layer_norm" ||
+      op_name == "sync_batch_norm" ||
+      op_name == "moving_average_abs_max_scale") {
+    if (amp_tensors_vector[0][0].dtype() == phi::DataType::FLOAT32) {
+      dst_type = phi::DataType::FLOAT32;
+    }
+    return dst_type;
+  }
+
   if (egr::Controller::Instance().GetCurrentTracer()->GetAmpDtype() ==
       "float16") {
-    if (op_name == "batch_norm" || op_name == "layer_norm" ||
-        op_name == "sync_batch_norm") {
-      if (amp_tensors_vector[0][0].dtype() == phi::DataType::FLOAT32) {
-        dst_type = phi::DataType::FLOAT32;
-      }
-    } else if (op_name == "fused_attention") {
+    if (op_name == "fused_attention") {
       for (size_t i = 0; i < amp_tensors_vector.size(); i++) {
         if (i != 3 || i != 4 || i != 9 || i != 10) {
           if (amp_tensors_vector[i][0].dtype() == phi::DataType::FLOAT32) {
             dst_type = phi::DataType::FLOAT32;
-            break;
+            return dst_type;
           }
         }
       }
@@ -47,37 +52,22 @@ static inline phi::DataType GetPromoteType(
         if (i != 7 || i != 8 || i != 9 || i != 10) {
           if (amp_tensors_vector[i][0].dtype() == phi::DataType::FLOAT32) {
             dst_type = phi::DataType::FLOAT32;
-            break;
-          }
-        }
-      }
-    } else {
-      for (const auto& tensors : amp_tensors_vector) {
-        for (const auto& tensor : tensors) {
-          if (tensor.dtype() == phi::DataType::FLOAT32) {
-            dst_type = tensor.dtype();
-            break;
+            return dst_type;
           }
         }
       }
     }
-  } else {
-    for (const auto& tensors : amp_tensors_vector) {
-      for (const auto& tensor : tensors) {
-        if (tensor.dtype() == phi::DataType::FLOAT32) {
-          dst_type = tensor.dtype();
-          break;
-        }
-      }
-    }
   }
-  // NOTE(juncai): moving_average_abs_max_scale only consider the dtype of
-  // input(X)
-  if (op_name == "moving_average_abs_max_scale") {
-    if (amp_tensors_vector[0][0].dtype() == phi::DataType::FLOAT16) {
-      dst_type = phi::DataType::FLOAT16;
+
+  for (const auto& tensors : amp_tensors_vector) {
+    for (const auto& tensor : tensors) {
+      if (tensor.dtype() == phi::DataType::FLOAT32) {
+        dst_type = tensor.dtype();
+        break;
+      }
     }
   }
+
   return dst_type;
 }
 
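After this change, the promotion rule is: for batch_norm, layer_norm, sync_batch_norm, and moving_average_abs_max_scale, only the dtype of input(X), i.e. amp_tensors_vector[0][0], decides the result, with an early return; every other op is promoted to FLOAT32 as soon as any FLOAT32 input is seen. (The i != 3 || i != 4 || ... index checks in the fused-op branches are carried over unchanged from the old code; as written, the || chain is always true, so those branches effectively scan every input slot.) Below is a minimal, self-contained sketch of that rule, omitting the fused_attention/fused_feedforward branches; SimpleDType and GetPromoteTypeSketch are illustrative stand-ins, not Paddle API.

#include <cassert>
#include <string>
#include <vector>

enum class SimpleDType { FLOAT16, FLOAT32 };

// Sketch of the promotion rule: `inputs` holds one dtype list per input slot,
// mirroring amp_tensors_vector; `amp_dtype` is the low-precision AMP dtype.
SimpleDType GetPromoteTypeSketch(
    const std::string& op_name,
    const std::vector<std::vector<SimpleDType>>& inputs,
    SimpleDType amp_dtype) {
  SimpleDType dst_type = amp_dtype;
  // Norm-like ops and moving_average_abs_max_scale: only input(X),
  // the first tensor of the first slot, decides; return early.
  if (op_name == "batch_norm" || op_name == "layer_norm" ||
      op_name == "sync_batch_norm" ||
      op_name == "moving_average_abs_max_scale") {
    if (inputs[0][0] == SimpleDType::FLOAT32) {
      dst_type = SimpleDType::FLOAT32;
    }
    return dst_type;
  }
  // All other ops: any FLOAT32 input promotes the whole op to FLOAT32.
  for (const auto& slot : inputs) {
    for (const auto& dtype : slot) {
      if (dtype == SimpleDType::FLOAT32) {
        return SimpleDType::FLOAT32;
      }
    }
  }
  return dst_type;
}

int main() {
  // layer_norm: only input(X) is consulted, so a FLOAT32 scale/bias slot
  // does not promote the op.
  assert(GetPromoteTypeSketch("layer_norm",
                              {{SimpleDType::FLOAT16}, {SimpleDType::FLOAT32}},
                              SimpleDType::FLOAT16) == SimpleDType::FLOAT16);
  // Any other op: one FLOAT32 input promotes the op to FLOAT32.
  assert(GetPromoteTypeSketch("matmul",
                              {{SimpleDType::FLOAT16}, {SimpleDType::FLOAT32}},
                              SimpleDType::FLOAT16) == SimpleDType::FLOAT32);
  return 0;
}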