
Commit 2900f78

[LLM] Optimize llm/GPT3 performance (#8172)
1 parent: a6d3a28 · commit: 2900f78

File tree: 4 files changed (+66, -18 lines)


llm/run_pretrain.py

Lines changed: 16 additions & 0 deletions
@@ -75,6 +75,7 @@ class PreTrainingArguments(TrainingArguments):
             "help": "Enable fused linear grad add strategy, which will reduce elementwise add for grad accumulation in the backward of nn.Linear ."
         },
     )
+
     # NOTE(gongenlei): new add autotuner_benchmark
     autotuner_benchmark: bool = field(
         default=False,
@@ -154,6 +155,18 @@ class ModelArguments:
         default=False,
         metadata={"help": "llama or other model, use_fused_rms_norm"},
     )
+    use_fast_layer_norm: bool = field(
+        default=False,
+        metadata={"help": "GPT3 model, use fast layernorm"},
+    )
+    use_fused_linear: bool = field(
+        default=False,
+        metadata={"help": "GPT3 model, use fused linear layer"},
+    )
+    use_fused_dropout_add: bool = field(
+        default=False,
+        metadata={"help": "GPT3 model, use fused `dropout + residual add` op"},
+    )
     fuse_attention_qkv: bool = field(
         default=False,
         metadata={"help": "whether to fuse attention qkv"},
@@ -440,6 +453,9 @@ def main():
 
     config.use_flash_attention = model_args.use_flash_attention
     config.use_fused_rms_norm = model_args.use_fused_rms_norm
+    config.use_fast_layer_norm = model_args.use_fast_layer_norm
+    config.use_fused_linear = model_args.use_fused_linear
+    config.use_fused_dropout_add = model_args.use_fused_dropout_add
     config.fuse_attention_qkv = model_args.fuse_attention_qkv
     config.fuse_attention_ffn = model_args.fuse_attention_ffn
     config.recompute_granularity = model_args.recompute_granularity
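
Not part of the commit, but for orientation: a minimal, self-contained sketch of the three new switches landing on the model config. run_pretrain.py itself loads a pretrained config and assigns the attributes as shown in the last hunk above; the direct GPTConfig construction below is only an illustration, and the values are arbitrary.

from paddlenlp.transformers import GPTConfig

# Mirrors the wiring added in run_pretrain.py: ModelArguments -> GPTConfig.
config = GPTConfig(
    use_fast_layer_norm=True,    # route LayerNorm through the external fast_ln kernel
    use_fused_linear=True,       # pass fuse_matmul_bias=True to the parallel Linear layers
    use_fused_dropout_add=True,  # use the fused dropout + residual-add op in the decoder
)
print(config.use_fast_layer_norm, config.use_fused_linear, config.use_fused_dropout_add)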

paddlenlp/transformers/gpt/configuration.py

Lines changed: 4 additions & 2 deletions
@@ -257,7 +257,8 @@ def __init__(
         ignore_index: int = 0,
         use_flash_attention: bool = False,
         use_fused_dropout_add: bool = False,
-        fused_linear: bool = False,
+        use_fast_layer_norm: bool = False,
+        use_fused_linear: bool = False,
         fuse_attention_qkv: bool = False,
         fuse_attention_ffn: bool = False,
         fused_softmax_with_triangular: bool = False,
@@ -298,7 +299,8 @@ def __init__(
         self.tensor_parallel_output = tensor_parallel_output
         self.output_attentions = output_attentions
         self.ignore_index = ignore_index
-        self.fused_linear = fused_linear
+        self.use_fast_layer_norm = use_fast_layer_norm
+        self.use_fused_linear = use_fused_linear
         self.use_fused_dropout_add = use_fused_dropout_add
         self.fused_softmax_with_triangular = fused_softmax_with_triangular
         self.virtual_pp_degree = virtual_pp_degree
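
The hunks above both add use_fast_layer_norm and rename fused_linear to use_fused_linear. A quick sketch of the resulting defaults (attribute names come from the diff; constructing a bare GPTConfig is just for illustration):

from paddlenlp.transformers import GPTConfig

cfg = GPTConfig()
assert cfg.use_fast_layer_norm is False     # new flag, off by default
assert cfg.use_fused_linear is False        # formerly cfg.fused_linear
assert cfg.use_fused_dropout_add is False   # pre-existing flag, unchanged default

Code that previously passed fused_linear=... to GPTConfig will no longer affect the attribute the model now reads (config.use_fused_linear).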

paddlenlp/transformers/gpt/modeling.py

Lines changed: 42 additions & 10 deletions
@@ -37,6 +37,7 @@
     mark_as_sequence_parallel_parameter,
 )
 from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from paddle.utils import try_import
 
 from ...utils.converter import StateDictNameMapping
 from ...utils.log import logger
@@ -59,6 +60,9 @@
 except:
     FusedDropoutAdd = None
 
+OriginLayerNorm = paddle.nn.LayerNorm
+
+
 __all__ = [
     "GPTModel",
     "GPTPretrainedModel",
@@ -70,6 +74,7 @@
     "GPTForCausalLM",
     "GPTEmbeddings",
     "GPTDecoderLayer",
+    "GPTLayerNorm",
 ]
 
 
@@ -119,6 +124,11 @@ def seed_guard_context(name=None):
     return contextlib.nullcontext()
 
 
+def fast_layer_norm(input, weight, bias, eps):
+    fast_ln_lib = try_import("fast_ln")
+    return fast_ln_lib.fast_ln(input, weight, bias, eps)[0]
+
+
 def _make_causal_mask(input_ids_shape, past_key_values_length):
     """
     Make causal mask used for self-attention
@@ -149,6 +159,11 @@ def _expand_2d_mask(mask, dtype, tgt_length):
     return expanded_mask
 
 
+def _check_normalized_shape(normalized_shape):
+    if isinstance(normalized_shape, (list, tuple)):
+        assert len(normalized_shape) == 1
+
+
 class MultiHeadAttention(nn.Layer):
     """
     Attention mapps queries and a set of key-value pairs to outputs, and
@@ -196,39 +211,39 @@ def __init__(
                     3 * config.hidden_size,
                     has_bias=True,
                     gather_output=False,
-                    fuse_matmul_bias=config.fused_linear,
+                    fuse_matmul_bias=config.use_fused_linear,
                 )
             else:
                 self.q_proj = ColumnParallelLinear(
                     config.hidden_size,
                     config.hidden_size,
                     has_bias=True,
                     gather_output=False,
-                    fuse_matmul_bias=config.fused_linear,
+                    fuse_matmul_bias=config.use_fused_linear,
                 )
 
                 self.k_proj = ColumnParallelLinear(
                     config.hidden_size,
                     config.hidden_size,
                     has_bias=True,
                     gather_output=False,
-                    fuse_matmul_bias=config.fused_linear,
+                    fuse_matmul_bias=config.use_fused_linear,
                 )
 
                 self.v_proj = ColumnParallelLinear(
                     config.hidden_size,
                     config.hidden_size,
                     has_bias=True,
                     gather_output=False,
-                    fuse_matmul_bias=config.fused_linear,
+                    fuse_matmul_bias=config.use_fused_linear,
                 )
 
                 self.out_proj = RowParallelLinear(
                     config.hidden_size,
                     config.hidden_size,
                     has_bias=True,
                     input_is_parallel=True,
-                    fuse_matmul_bias=config.fused_linear,
+                    fuse_matmul_bias=config.use_fused_linear,
                 )
         else:
             if self.config.fuse_attention_qkv:
@@ -421,7 +436,7 @@ def __init__(self, config, decoder_layers, norm=None, hidden_size=None):
 
         self.config = config
         self.layers = decoder_layers
-        self.norm = nn.LayerNorm(config.hidden_size, epsilon=1e-5)
+        self.norm = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5)
 
         if config.sequence_parallel:
             mark_as_sequence_parallel_parameter(self.norm.weight)
@@ -566,21 +581,23 @@ def __init__(self, config: GPTConfig):
                 config.intermediate_size,
                 gather_output=False,
                 has_bias=True,
-                fuse_matmul_bias=self.config.fused_linear,
+                fuse_matmul_bias=self.config.use_fused_linear,
             )
+
             self.linear2 = RowParallelLinear(
                 config.intermediate_size,
                 config.hidden_size,
                 input_is_parallel=True,
                 has_bias=True,
-                fuse_matmul_bias=self.config.fused_linear,
+                fuse_matmul_bias=self.config.use_fused_linear,
             )
         else:
             self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size, bias_attr=True)
             self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias_attr=True)
 
-        self.norm1 = nn.LayerNorm(config.hidden_size, epsilon=1e-5)
-        self.norm2 = nn.LayerNorm(config.hidden_size, epsilon=1e-5)
+        self.norm1 = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5)
+        self.norm2 = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5)
+
         if config.sequence_parallel:
             mark_as_sequence_parallel_parameter(self.norm1.weight)
             mark_as_sequence_parallel_parameter(self.norm1.bias)
@@ -741,6 +758,21 @@ def forward(self, input_ids, position_ids=None, inputs_embeddings=None):
         return embeddings
 
 
+class GPTLayerNorm(OriginLayerNorm):
+    def __init__(self, config, normalized_shape, epsilon=1e-05, weight_attr=None, bias_attr=None, name=None):
+        super().__init__(
+            normalized_shape=normalized_shape, epsilon=epsilon, weight_attr=weight_attr, bias_attr=bias_attr
+        )
+
+        self.config = config
+        _check_normalized_shape(self._normalized_shape)
+
+    def forward(self, input):
+        if self.config.use_fast_layer_norm:
+            return fast_layer_norm(input, self.weight, self.bias, self._epsilon)
+        return super().forward(input)
+
+
 class GPTPretrainedModel(PretrainedModel):
     """
     An abstract class for pretrained GPT models. It provides GPT related
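
Again not part of the commit: a short sketch of how the new GPTLayerNorm dispatches. With use_fast_layer_norm=False it behaves exactly like paddle.nn.LayerNorm; with True it calls the external fast_ln extension via try_import, which must be installed or try_import will raise. Shapes and sizes below are arbitrary.

import paddle
from paddlenlp.transformers import GPTConfig
from paddlenlp.transformers.gpt.modeling import GPTLayerNorm

config = GPTConfig(hidden_size=1024, use_fast_layer_norm=False)
norm = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5)

x = paddle.randn([2, 8, config.hidden_size])
y = norm(x)       # use_fast_layer_norm=False -> stock LayerNorm forward
print(y.shape)    # [2, 8, 1024]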

paddlenlp/transformers/gpt/modeling_pp.py

Lines changed: 4 additions & 6 deletions
@@ -13,7 +13,6 @@
 # limitations under the License.
 import paddle
 import paddle.distributed.fleet as fleet
-import paddle.nn as nn
 from paddle.distributed.fleet.meta_parallel import (
     LayerDesc,
     PipelineLayer,
@@ -30,6 +29,7 @@
     GPTConfig,
     GPTDecoderLayer,
     GPTEmbeddings,
+    GPTLayerNorm,
     GPTLMHead,
     GPTPretrainedModel,
     GPTPretrainingCriterion,
@@ -103,15 +103,13 @@ def forward(self, args):
         embeddings = super().forward(input_ids=input_ids, position_ids=position_ids)
 
         batch_size, seq_length = input_ids.shape
-        causal_mask = self.bias[:, :, 0:seq_length, :seq_length]
         if attention_mask is not None:
             if attention_mask.dtype != paddle.int64:
                 attention_mask = paddle.cast(attention_mask, dtype=paddle.int64)
             if len(attention_mask.shape) == 2:
                 attention_mask = attention_mask[:, None, None, :]
+            causal_mask = self.bias[:, :, 0:seq_length, :seq_length]
             attention_mask = (1.0 - (attention_mask & causal_mask)) * -1e4
-        else:
-            attention_mask = (1.0 - causal_mask) * -1e4
 
         return return_args(embeddings, attention_mask, position_ids)
 
@@ -127,9 +125,9 @@ def forward(self, args):
         return return_args(hidden_states, attention_mask, position_ids)
 
 
-class LayerNormPipe(nn.LayerNorm):
+class LayerNormPipe(GPTLayerNorm):
     def __init__(self, config):
-        super(LayerNormPipe, self).__init__(config.hidden_size, epsilon=1e-05)
+        super(LayerNormPipe, self).__init__(config, config.hidden_size, epsilon=1e-05)
         if config.sequence_parallel:
             mark_as_sequence_parallel_parameter(self.weight)
             mark_as_sequence_parallel_parameter(self.bias)
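
For reference, a toy reproduction of the mask arithmetic that the refactored code path keeps (the expression is the one in the hunk above; the shapes, values, and standalone setup are illustrative):

import paddle

seq_length = 4
# 2-D padding mask expanded to [batch, 1, 1, seq] and cast to int64, as in the diff
attention_mask = paddle.to_tensor([[1, 1, 1, 0]], dtype="int64")[:, None, None, :]
# lower-triangular causal mask, [1, 1, seq, seq] (self.bias in the model)
causal_mask = paddle.tril(paddle.ones([1, 1, seq_length, seq_length], dtype="int64"))

# 0 where a position may attend, -1e4 where it is masked out
attention_mask = (1.0 - (attention_mask & causal_mask)) * -1e4
print(attention_mask.shape)  # [1, 1, 4, 4]

Note that after this change the causal mask is only combined when an attention_mask is passed in; when it is None, the embeddings pipe now forwards None instead of building a purely causal mask.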
