@@ -37,6 +37,7 @@
     mark_as_sequence_parallel_parameter,
 )
 from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from paddle.utils import try_import

 from ...utils.converter import StateDictNameMapping
 from ...utils.log import logger
@@ -59,6 +60,9 @@
 except:
     FusedDropoutAdd = None

+OriginLayerNorm = paddle.nn.LayerNorm
+
+
 __all__ = [
     "GPTModel",
     "GPTPretrainedModel",
@@ -70,6 +74,7 @@
     "GPTForCausalLM",
     "GPTEmbeddings",
     "GPTDecoderLayer",
+    "GPTLayerNorm",
 ]

@@ -119,6 +124,11 @@ def seed_guard_context(name=None):
         return contextlib.nullcontext()


+def fast_layer_norm(input, weight, bias, eps):
+    fast_ln_lib = try_import("fast_ln")
+    return fast_ln_lib.fast_ln(input, weight, bias, eps)[0]
+
+
 def _make_causal_mask(input_ids_shape, past_key_values_length):
     """
     Make causal mask used for self-attention
@@ -149,6 +159,11 @@ def _expand_2d_mask(mask, dtype, tgt_length):
     return expanded_mask


+def _check_normalized_shape(normalized_shape):
+    if isinstance(normalized_shape, (list, tuple)):
+        assert len(normalized_shape) == 1
+
+
 class MultiHeadAttention(nn.Layer):
     """
     Attention mapps queries and a set of key-value pairs to outputs, and
@@ -196,39 +211,39 @@ def __init__(
                     3 * config.hidden_size,
                     has_bias=True,
                     gather_output=False,
-                    fuse_matmul_bias=config.fused_linear,
+                    fuse_matmul_bias=config.use_fused_linear,
                 )
             else:
                 self.q_proj = ColumnParallelLinear(
                     config.hidden_size,
                     config.hidden_size,
                     has_bias=True,
                     gather_output=False,
-                    fuse_matmul_bias=config.fused_linear,
+                    fuse_matmul_bias=config.use_fused_linear,
                 )

                 self.k_proj = ColumnParallelLinear(
                     config.hidden_size,
                     config.hidden_size,
                     has_bias=True,
                     gather_output=False,
-                    fuse_matmul_bias=config.fused_linear,
+                    fuse_matmul_bias=config.use_fused_linear,
                 )

                 self.v_proj = ColumnParallelLinear(
                     config.hidden_size,
                     config.hidden_size,
                     has_bias=True,
                     gather_output=False,
-                    fuse_matmul_bias=config.fused_linear,
+                    fuse_matmul_bias=config.use_fused_linear,
                 )

             self.out_proj = RowParallelLinear(
                 config.hidden_size,
                 config.hidden_size,
                 has_bias=True,
                 input_is_parallel=True,
-                fuse_matmul_bias=config.fused_linear,
+                fuse_matmul_bias=config.use_fused_linear,
             )
         else:
             if self.config.fuse_attention_qkv:
@@ -421,7 +436,7 @@ def __init__(self, config, decoder_layers, norm=None, hidden_size=None):

         self.config = config
         self.layers = decoder_layers
-        self.norm = nn.LayerNorm(config.hidden_size, epsilon=1e-5)
+        self.norm = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5)

         if config.sequence_parallel:
             mark_as_sequence_parallel_parameter(self.norm.weight)
@@ -566,21 +581,23 @@ def __init__(self, config: GPTConfig):
                 config.intermediate_size,
                 gather_output=False,
                 has_bias=True,
-                fuse_matmul_bias=self.config.fused_linear,
+                fuse_matmul_bias=self.config.use_fused_linear,
             )
+
             self.linear2 = RowParallelLinear(
                 config.intermediate_size,
                 config.hidden_size,
                 input_is_parallel=True,
                 has_bias=True,
-                fuse_matmul_bias=self.config.fused_linear,
+                fuse_matmul_bias=self.config.use_fused_linear,
             )
         else:
             self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size, bias_attr=True)
             self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias_attr=True)

-        self.norm1 = nn.LayerNorm(config.hidden_size, epsilon=1e-5)
-        self.norm2 = nn.LayerNorm(config.hidden_size, epsilon=1e-5)
+        self.norm1 = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5)
+        self.norm2 = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5)
+
         if config.sequence_parallel:
             mark_as_sequence_parallel_parameter(self.norm1.weight)
             mark_as_sequence_parallel_parameter(self.norm1.bias)
@@ -741,6 +758,21 @@ def forward(self, input_ids, position_ids=None, inputs_embeddings=None):
         return embeddings


+class GPTLayerNorm(OriginLayerNorm):
+    def __init__(self, config, normalized_shape, epsilon=1e-05, weight_attr=None, bias_attr=None, name=None):
+        super().__init__(
+            normalized_shape=normalized_shape, epsilon=epsilon, weight_attr=weight_attr, bias_attr=bias_attr
+        )
+
+        self.config = config
+        _check_normalized_shape(self._normalized_shape)
+
+    def forward(self, input):
+        if self.config.use_fast_layer_norm:
+            return fast_layer_norm(input, self.weight, self.bias, self._epsilon)
+        return super().forward(input)
+
+
 class GPTPretrainedModel(PretrainedModel):
     """
     An abstract class for pretrained GPT models. It provides GPT related
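
For reference, a minimal standalone sketch of the dispatch behaviour that the new GPTLayerNorm introduces: when fast layer norm is enabled it calls the optional fast_ln extension (loaded through paddle.utils.try_import, the same way fast_layer_norm does in the diff above), otherwise it falls back to the stock Paddle layer norm. The helper name layer_norm_dispatch and the ImportError fallback are illustrative assumptions, not part of this change; the diff itself does not guard the import.

# Illustrative sketch only -- layer_norm_dispatch is a hypothetical helper, not code from this change.
import paddle
import paddle.nn.functional as F


def layer_norm_dispatch(x, weight, bias, epsilon=1e-5, use_fast_layer_norm=False):
    if use_fast_layer_norm:
        try:
            # Optional fused kernel; the diff loads it the same way via try_import("fast_ln").
            fast_ln = paddle.utils.try_import("fast_ln")
            return fast_ln.fast_ln(x, weight, bias, epsilon)[0]
        except ImportError:
            pass  # assumption: fall back when the extension is not installed
    # Reference path, equivalent to paddle.nn.LayerNorm over the last dimension.
    return F.layer_norm(x, x.shape[-1:], weight=weight, bias=bias, epsilon=epsilon)


x = paddle.randn([2, 8, 64])
w = paddle.ones([64])
b = paddle.zeros([64])
out = layer_norm_dispatch(x, w, b)  # same result as paddle.nn.LayerNorm(64)(x)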