import paddle.nn as nn
import paddle.nn.functional as F
from paddle import Tensor
+
+try:
+    from paddle.amp.auto_cast import amp_state
+except ImportError:
+    from paddle.fluid.dygraph.amp.auto_cast import amp_state
from paddle.distributed.fleet.utils import recompute

from ...utils.converter import StateDictNameMapping
@@ -90,7 +95,7 @@ def forward(self, hidden_states):
        hidden_states = hidden_states * paddle.rsqrt(variance + self.variance_epsilon)

        # convert into float16 if necessary
-        if self.weight.dtype == paddle.float16:
+        if amp_state() or self.weight.dtype == paddle.float16:
            hidden_states = hidden_states.astype(paddle.float16)
        return self.weight * hidden_states

@@ -502,7 +507,7 @@ def forward(
        attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

        # clamp inf values to enable fp16 training
-        if hidden_states.dtype == paddle.float16 and paddle.isinf(hidden_states).any():
+        if (amp_state() or hidden_states.dtype == paddle.float16) and paddle.isinf(hidden_states).any():
            # TODO finfo
            clamp_value = finfo(hidden_states.dtype).max - 1000
            hidden_states = paddle.clip(hidden_states, min=-clamp_value, max=clamp_value)
@@ -529,7 +534,7 @@ def forward(
            hidden_states = cross_attention_outputs[0]

            # clamp inf values to enable fp16 training
-            if hidden_states.dtype == paddle.float16 and paddle.isinf(hidden_states).any():
+            if (amp_state() or hidden_states.dtype == paddle.float16) and paddle.isinf(hidden_states).any():
                clamp_value = finfo(hidden_states.dtype).max - 1000
                hidden_states = paddle.clip(hidden_states, min=-clamp_value, max=clamp_value)

@@ -544,7 +549,7 @@ def forward(
        hidden_states = self.layer[-1](hidden_states)

        # clamp inf values to enable fp16 training
-        if hidden_states.dtype == paddle.float16 and paddle.isinf(hidden_states).any():
+        if (amp_state() or hidden_states.dtype == paddle.float16) and paddle.isinf(hidden_states).any():
            clamp_value = finfo(hidden_states.dtype).max - 1000
            hidden_states = paddle.clip(hidden_states, min=-clamp_value, max=clamp_value)

@@ -1115,7 +1120,7 @@ def invert_attention_mask(self, encoder_attention_mask):
            encoder_extended_attention_mask = encoder_attention_mask.unsqueeze([1, 2])
        encoder_extended_attention_mask = encoder_extended_attention_mask.astype(self.dtype)  # fp16 compatibility

-        if self.dtype == paddle.float16:
+        if amp_state() or self.dtype == paddle.float16:
            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
        elif self.dtype == paddle.float32:
            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
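For reference (not part of this commit), a minimal standalone sketch of the clamping pattern the diff builds on: float16 activations that have overflowed to inf are clipped back into the representable range so they do not poison later layers. The helper name clamp_fp16_inf is hypothetical, and numpy.finfo stands in here for the model's own finfo helper.

import numpy as np
import paddle

def clamp_fp16_inf(hidden_states):
    # Hypothetical helper illustrating the pattern; only float16 tensors can
    # overflow to inf here, so float32 inputs pass through unchanged.
    if hidden_states.dtype == paddle.float16 and paddle.isinf(hidden_states).any():
        # float16 max is 65504; back off by 1000, as the model code does.
        clamp_value = float(np.finfo(np.float16).max) - 1000
        hidden_states = paddle.clip(hidden_states, min=-clamp_value, max=clamp_value)
    return hidden_states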