
Commit 99eb861

Fix T5 to_static + amp by amp_state (#4783)
1 parent 48d87a7 commit 99eb861

1 file changed: +10 -5 lines changed

1 file changed

+10
-5
lines changed

paddlenlp/transformers/t5/modeling.py

Lines changed: 10 additions & 5 deletions
@@ -23,6 +23,11 @@
 import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle import Tensor
+
+try:
+    from paddle.amp.auto_cast import amp_state
+except ImportError:
+    from paddle.fluid.dygraph.amp.auto_cast import amp_state
 from paddle.distributed.fleet.utils import recompute
 
 from ...utils.converter import StateDictNameMapping
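
The new import is what the rest of the diff builds on: `amp_state()` is the internal flag that `paddle.amp.auto_cast` sets, so it is truthy inside an auto_cast context even while the parameters are still float32 (the situation hit when combining `paddle.jit.to_static` with AMP). A minimal sketch of the pattern; the predicate name `_fp16_path_enabled` is illustrative, not part of the commit:

```python
import paddle

# Version-tolerant import, as added by this commit: newer Paddle exposes
# amp_state under paddle.amp.auto_cast, older releases under paddle.fluid.
try:
    from paddle.amp.auto_cast import amp_state
except ImportError:
    from paddle.fluid.dygraph.amp.auto_cast import amp_state


def _fp16_path_enabled(weight):
    # Illustrative predicate: take the fp16 branch either because AMP
    # auto_cast is currently active, or because the weights themselves
    # have already been cast to float16.
    return bool(amp_state()) or weight.dtype == paddle.float16
```
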
@@ -90,7 +95,7 @@ def forward(self, hidden_states):
         hidden_states = hidden_states * paddle.rsqrt(variance + self.variance_epsilon)
 
         # convert into float16 if necessary
-        if self.weight.dtype == paddle.float16:
+        if amp_state() or self.weight.dtype == paddle.float16:
             hidden_states = hidden_states.astype(paddle.float16)
         return self.weight * hidden_states
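
For context, here is a standalone sketch of the layer-norm forward this hunk patches. The variance line is assumed from the standard T5 RMS-norm formulation and is not part of this diff; only the `amp_state()` gate comes from the commit:

```python
import paddle
import paddle.nn as nn
from paddle.amp.auto_cast import amp_state  # fallback import omitted, see the first hunk


class RMSNormSketch(nn.Layer):
    """Illustrative RMS-norm layer mirroring the hunk above."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = self.create_parameter(
            shape=[hidden_size], default_initializer=nn.initializer.Constant(1.0)
        )
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # variance computed in float32 for stability (assumed, standard T5 practice)
        variance = hidden_states.astype(paddle.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * paddle.rsqrt(variance + self.variance_epsilon)

        # convert into float16 if necessary: with the fix, this branch also runs
        # when AMP auto_cast is active but the weights are still float32
        if amp_state() or self.weight.dtype == paddle.float16:
            hidden_states = hidden_states.astype(paddle.float16)
        return self.weight * hidden_states
```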

@@ -502,7 +507,7 @@ def forward(
         attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights
 
         # clamp inf values to enable fp16 training
-        if hidden_states.dtype == paddle.float16 and paddle.isinf(hidden_states).any():
+        if (amp_state() or hidden_states.dtype == paddle.float16) and paddle.isinf(hidden_states).any():
             # TODO finfo
             clamp_value = finfo(hidden_states.dtype).max - 1000
             hidden_states = paddle.clip(hidden_states, min=-clamp_value, max=clamp_value)
@@ -529,7 +534,7 @@ def forward(
             hidden_states = cross_attention_outputs[0]
 
             # clamp inf values to enable fp16 training
-            if hidden_states.dtype == paddle.float16 and paddle.isinf(hidden_states).any():
+            if (amp_state() or hidden_states.dtype == paddle.float16) and paddle.isinf(hidden_states).any():
                 clamp_value = finfo(hidden_states.dtype).max - 1000
                 hidden_states = paddle.clip(hidden_states, min=-clamp_value, max=clamp_value)

@@ -544,7 +549,7 @@ def forward(
         hidden_states = self.layer[-1](hidden_states)
 
         # clamp inf values to enable fp16 training
-        if hidden_states.dtype == paddle.float16 and paddle.isinf(hidden_states).any():
+        if (amp_state() or hidden_states.dtype == paddle.float16) and paddle.isinf(hidden_states).any():
             clamp_value = finfo(hidden_states.dtype).max - 1000
             hidden_states = paddle.clip(hidden_states, min=-clamp_value, max=clamp_value)
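
The same guard-and-clamp pattern appears in the three hunks above; pulled out as a helper it reads as follows. The module's own `finfo` shim is not shown here, so `numpy.finfo` stands in for it, which is an assumption rather than what the file does:

```python
import numpy as np
import paddle
from paddle.amp.auto_cast import amp_state  # fallback import omitted, see the first hunk


def clamp_inf_for_fp16(hidden_states):
    # Clamp inf values so fp16 training/inference does not overflow; with the fix
    # the clamp also runs when AMP is active and the tensor is nominally float32.
    if (amp_state() or hidden_states.dtype == paddle.float16) and paddle.isinf(hidden_states).any():
        clamp_value = float(np.finfo(np.float16).max) - 1000
        hidden_states = paddle.clip(hidden_states, min=-clamp_value, max=clamp_value)
    return hidden_states
```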

@@ -1115,7 +1120,7 @@ def invert_attention_mask(self, encoder_attention_mask):
         encoder_extended_attention_mask = encoder_attention_mask.unsqueeze([1, 2])
         encoder_extended_attention_mask = encoder_extended_attention_mask.astype(self.dtype)  # fp16 compatibility
 
-        if self.dtype == paddle.float16:
+        if amp_state() or self.dtype == paddle.float16:
             encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
         elif self.dtype == paddle.float32:
             encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
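
Taken together, the change targets the scenario in the commit title: a model converted with `paddle.jit.to_static` and run under `paddle.amp.auto_cast`, where the parameters stay float32 and the old dtype checks never fired. A rough sketch of that setup; `build_t5_model`, `input_ids`, and `decoder_input_ids` are hypothetical placeholders, not code from this repository:

```python
import paddle
from paddle.amp.auto_cast import amp_state  # fallback import omitted, see the first hunk

model = build_t5_model()                    # hypothetical: any T5 model instance
model.eval()
static_model = paddle.jit.to_static(model)  # static-graph conversion keeps fp32 weights

with paddle.amp.auto_cast(level="O1"):
    # Inside this context amp_state() is truthy even though the parameters are
    # still float32, so the fp16 branches patched above are now taken.
    outputs = static_model(input_ids, decoder_input_ids)  # hypothetical inputs
```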
