
Commit ae7dc15

add mc2 & finetune fused (#8139)
1 parent 2273ee7 commit ae7dc15

File tree

.copyright.hook
llm/argument.py
llm/finetune_generation.py
paddlenlp/transformers/llama/modeling.py
paddlenlp/transformers/mc2_seqence_parallel_linear.py

5 files changed: +267 -11 lines changed

.copyright.hook

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ RE_SHEBANG = re.compile(r"^[ \t\v]*#[ \t]?\!")
 def _check_copyright(path):
     head=[]
     try:
-        with open(path) as f:
+        with open(path, encoding="utf-8") as f:
             head = [next(f) for x in range(4)]
     except StopIteration:
         pass

llm/argument.py

Lines changed: 62 additions & 0 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass, field
+from typing import List, Optional

 from paddlenlp.trainer import TrainingArguments
 from paddlenlp.trainer.trainer_utils import IntervalStrategy
@@ -48,6 +49,9 @@ class DataArgument:
     dataset_name_or_path: str = field(default=None, metadata={"help": "Name or path for dataset"})
     task_name: str = field(default=None, metadata={"help": "Additional name to select a more specific task."})
     zero_padding: bool = field(default=False, metadata={"help": "Whether to use Zero Padding data stream"})
+    pad_to_multiple_of: int = field(
+        default=None, metadata={"help": "If set will pad the sequence to a multiple of the provided value."}
+    )
     src_length: int = field(default=1024, metadata={"help": "The maximum length of source(context) tokens."})
     max_length: int = field(
         default=2048,
@@ -102,6 +106,64 @@ class ModelArgument:
         default=None, metadata={"help": "Build-in pretrained model name or the path to local model."}
     )
     use_flash_attention: bool = field(default=False, metadata={"help": "Whether to use flash attention"})
+    tokenizer_name_or_path: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+    )
+    use_fused_rms_norm: bool = field(
+        default=False,
+        metadata={"help": "llama or other model, use_fused_rms_norm"},
+    )
+    fuse_attention_qkv: bool = field(
+        default=False,
+        metadata={"help": "whether to fuse attention qkv"},
+    )
+    fuse_attention_ffn: bool = field(
+        default=False,
+        metadata={"help": "whether to fuse first up and gate proj in mlp block"},
+    )
+    recompute_granularity: str = field(
+        default="full",
+        metadata={"help": "Choose among ['full', 'core_attn', 'full_attn']"},
+    )
+    virtual_pp_degree: int = field(
+        default=1,
+        metadata={"help": "virtual_pp_degree"},
+    )
+    hidden_dropout_prob: float = field(default=0.1, metadata={"help": "The hidden dropout prob."})
+    attention_probs_dropout_prob: float = field(default=0.1, metadata={"help": "The attention hidden dropout prob."})
+
+    continue_training: bool = field(
+        default=False,
+        metadata={
+            "help": "Pre-training from existing paddlenlp model weights. Default False and model will train from scratch. If set True, the model_name_or_path argument must exist in the paddlenlp models."
+        },
+    )
+    sequence_parallel: bool = field(
+        default=False,
+        metadata={"help": "whether to use sequence parallel"},
+    )
+    fuse_sequence_parallel_allreduce: bool = field(
+        default=False,
+        metadata={"help": "whether to use fuse sequence parallel allreduce"},
+    )
+    use_fused_rope: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Enable rope fusion or not."},
+    )
+    no_recompute_layers: Optional[List[int]] = field(
+        default=None,
+        metadata={"help": "Specify the full transformer layers that should not be recomputed."},
+    )
+    pp_recompute_interval: int = field(
+        default=1,
+        metadata={
+            "help": "The interval for the number of layers at which recomputation occurs. A value of 0 indicates no recomputation. Default is 0."
+        },
+    )
+    recompute_use_reentrant: bool = field(
+        default=False,
+        metadata={"help": "recompute_use_reentrant"},
+    )
     weight_quantize_algo: str = field(
         default=None,
         metadata={
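
The new ModelArgument and DataArgument fields map one-to-one onto command-line flags of the finetuning entry point. As a hedged sketch (not part of this commit), the snippet below exercises the new fields by constructing the dataclasses directly; it assumes it is run from the llm/ directory so that argument.py is importable, that the remaining fields also carry defaults as the hunks above suggest, and every value shown is illustrative.

# Hypothetical usage sketch: all fields referenced here are defined in the
# diff above and have defaults, so keyword-only construction is enough.
from argument import DataArgument, ModelArgument  # assumes cwd == llm/

model_args = ModelArgument(
    use_fused_rms_norm=True,     # fused RMSNorm kernel
    fuse_attention_qkv=True,     # single fused QKV projection
    fuse_attention_ffn=True,     # fused up/gate projection in the MLP block
    use_fused_rope=True,         # fused rotary position embedding
    sequence_parallel=False,
    no_recompute_layers=[0, 1],  # do not recompute the first two layers
)
data_args = DataArgument(pad_to_multiple_of=8)
print(model_args.fuse_attention_qkv, data_args.pad_to_multiple_of)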

llm/finetune_generation.py

Lines changed: 25 additions & 0 deletions
@@ -154,6 +154,30 @@ def main():
     if hasattr(model_config, "use_flash_attention"):
         model_config.use_flash_attention = model_args.use_flash_attention

+    model_config.use_fused_rms_norm = model_args.use_fused_rms_norm
+    model_config.fuse_attention_qkv = model_args.fuse_attention_qkv
+    model_config.fuse_attention_ffn = model_args.fuse_attention_ffn
+    model_config.recompute_granularity = model_args.recompute_granularity
+    model_config.virtual_pp_degree = model_args.virtual_pp_degree
+    model_config.sequence_parallel = model_args.sequence_parallel
+    model_config.fuse_sequence_parallel_allreduce = model_args.fuse_sequence_parallel_allreduce
+    model_config.use_fused_rope = model_args.use_fused_rope
+
+    model_config.no_recompute_layers = model_args.no_recompute_layers
+    model_config.pp_recompute_interval = model_args.pp_recompute_interval
+    model_config.recompute_use_reentrant = model_args.recompute_use_reentrant
+    model_config.use_recompute = training_args.recompute
+
+    model_config.tensor_parallel_degree = training_args.tensor_parallel_degree
+    model_config.tensor_parallel_rank = training_args.tensor_parallel_rank
+
+    # Config for model using dropout, such as GPT.
+    model_config.hidden_dropout_prob = model_args.hidden_dropout_prob
+    model_config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob
+
+    model_config.sep_parallel_degree = training_args.sep_parallel_degree
+    model_config.tensor_parallel_output = True
+    model_config.seq_length = data_args.max_length
     if not training_args.autotuner_benchmark:
         model = AutoModelForCausalLM.from_pretrained(
             model_args.model_name_or_path,
@@ -494,6 +518,7 @@ def compute_metrics_do_generation(eval_preds):
             padding=padding,
             max_label_length=max_length,
             return_tensors="np",
+            pad_to_multiple_of=data_args.pad_to_multiple_of,
         ),
         do_generation=data_args.eval_with_do_generation,
         callbacks=[InTokensIterDatasetCallback()] if isinstance(train_ds, InTokensIterableDataset) else None,
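
The block added to main() simply mirrors the parsed arguments onto the model config before the weights are loaded, so the fused kernels are wired in at construction time. A condensed, hedged sketch of that pattern follows; it is not part of the commit, the model identifier is a placeholder, and it assumes the standard PaddleNLP auto classes the script already uses.

# Illustrative only: copy the fusion switches onto the config, then build the
# model from that config.
from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM

model_config = AutoConfig.from_pretrained("facebook/llama-7b")  # placeholder name
model_config.use_fused_rms_norm = True   # <- model_args.use_fused_rms_norm
model_config.fuse_attention_qkv = True   # <- model_args.fuse_attention_qkv
model_config.fuse_attention_ffn = True   # <- model_args.fuse_attention_ffn
model_config.use_fused_rope = True       # <- model_args.use_fused_rope
model_config.seq_length = 2048           # <- data_args.max_length
model = AutoModelForCausalLM.from_pretrained("facebook/llama-7b", config=model_config)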

paddlenlp/transformers/llama/modeling.py

Lines changed: 37 additions & 10 deletions
@@ -16,6 +16,7 @@
 from __future__ import annotations

 import math
+import os
 import warnings
 from functools import partial
 from typing import Optional, Tuple
@@ -75,8 +76,6 @@ def swiglu(x, y=None):

 try:
     if get_env_device() == "npu":
-        import os
-
         from paddle.base import core

         for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
@@ -94,6 +93,13 @@ def swiglu(x, y=None):
 ]


+def is_mc2_valid():
+    current_device = get_env_device()
+    if current_device == "npu":
+        return True
+    return False
+
+
 def _get_interleave(n):
     def _get_interleave_power_of_2(n):
         start = 2 ** (-(2 ** -(math.log2(n) - 3)))
@@ -565,8 +571,17 @@ def __init__(self, config):
         self.fuse_attention_ffn = config.fuse_attention_ffn

         if config.sequence_parallel:
-            ColumnParallelLinear = ColumnSequenceParallelLinear
-            RowParallelLinear = RowSequenceParallelLinear
+            if is_mc2_valid() and int(os.getenv("FLAGS_NPU_MC2", 0)):
+                from paddlenlp.transformers.mc2_seqence_parallel_linear import (
+                    MC2ColumnSeqParallelLinear,
+                    MC2RowSeqParallelLinear,
+                )
+
+                ColumnParallelLinear = MC2ColumnSeqParallelLinear
+                RowParallelLinear = MC2RowSeqParallelLinear
+            else:
+                ColumnParallelLinear = ColumnSequenceParallelLinear
+                RowParallelLinear = RowSequenceParallelLinear
         else:
             ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
             RowParallelLinear = fleet.meta_parallel.RowParallelLinear
@@ -670,7 +685,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         )

         self.use_fused_rope = config.use_fused_rope
-        if self.use_fused_rope:
+        if self.use_fused_rope and get_env_device() != "npu":
             if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
                 warnings.warn(
                     "Enable fuse rope in the config, but fuse rope is not available. "
@@ -679,8 +694,17 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
                 self.use_fused_rope = False

         if config.sequence_parallel:
-            ColumnParallelLinear = ColumnSequenceParallelLinear
-            RowParallelLinear = RowSequenceParallelLinear
+            if is_mc2_valid() and int(os.getenv("FLAGS_NPU_MC2", 0)):
+                from paddlenlp.transformers.mc2_seqence_parallel_linear import (
+                    MC2ColumnSeqParallelLinear,
+                    MC2RowSeqParallelLinear,
+                )
+
+                ColumnParallelLinear = MC2ColumnSeqParallelLinear
+                RowParallelLinear = MC2RowSeqParallelLinear
+            else:
+                ColumnParallelLinear = ColumnSequenceParallelLinear
+                RowParallelLinear = RowSequenceParallelLinear
         else:
             ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
             RowParallelLinear = fleet.meta_parallel.RowParallelLinear
@@ -1526,9 +1550,12 @@ def forward(
                 attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
             )  # [bs, 1, seq_len, seq_len]
         if self.config.use_flash_attention:
-            is_casual = is_casual_mask(attention_mask)
-            if is_casual and alibi is None:
-                attention_mask = None
+            if get_env_device() != "npu":
+                is_casual = is_casual_mask(attention_mask)
+                if is_casual and alibi is None:
+                    attention_mask = None
+            else:
+                attention_mask = attention_mask.astype("bool")
         hidden_states = inputs_embeds

         # decoder layers
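
With these changes the MC2 path is opt-in: the MC2 linear layers are only selected when the process runs on an NPU device and the FLAGS_NPU_MC2 environment variable is set to a non-zero value before the attention and MLP layers are constructed. A minimal sketch of that gate, assuming get_env_device is importable from paddlenlp.utils.tools as modeling.py does:

# Sketch of the selection condition introduced above (not part of the commit).
import os

from paddlenlp.utils.tools import get_env_device

def use_mc2_sequence_parallel() -> bool:
    on_npu = get_env_device() == "npu"                    # mirrors is_mc2_valid()
    opted_in = bool(int(os.getenv("FLAGS_NPU_MC2", 0)))   # opt-in environment switch
    return on_npu and opted_in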

paddlenlp/transformers/mc2_seqence_parallel_linear.py

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import paddle
+
+try:
+    import paddle_custom_device
+except ImportError:
+    raise ImportError("Current device does not support MC2!")
+
+from paddle import distributed as dist
+from paddle.autograd import PyLayer
+from paddle.distributed.fleet.utils.sequence_parallel_utils import (
+    ColumnSequenceParallelLinear,
+    RowSequenceParallelLinear,
+)
+
+__all_gather_recomputation__ = False
+if int(os.getenv("MC2_Recompute", 0)):
+    __all_gather_recomputation__ = True
+
+
+class MC2Column(PyLayer):
+    @staticmethod
+    def forward(ctx, input_, weight, group):
+        ctx.save_for_backward(input_, weight)
+
+        rank = dist.get_rank()
+        hcomm_info = group.process_group.get_comm_name(rank)
+
+        world_size = group.nranks
+        output, gather_out = paddle_custom_device.npu.fused_allgather_mm(
+            input_,
+            weight,
+            bias=None,
+            hcom=hcomm_info,
+            world_size=world_size,
+            gather_index=0,
+            gather_output=(not __all_gather_recomputation__),
+            comm_turn=0,
+        )
+
+        ctx.all_gather_output = gather_out
+        ctx.world_size = world_size
+        ctx.group = group
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_, weight = ctx.saved_tensor()
+
+        if __all_gather_recomputation__:
+            dim_size = input_.shape
+            dim_size[0] = dim_size[0] * ctx.world_size
+            all_gather_output = paddle.empty(dim_size, dtype=input_.dtype)
+            all_gather_output.stop_gradient = True
+            all_gather_work = dist.stream.all_gather(all_gather_output, input_, group=ctx.group, sync_op=False)
+        else:
+            all_gather_output = ctx.all_gather_output
+
+        grad_input = paddle.matmul(grad_output, weight, transpose_y=True)
+        sub_grad_input = paddle.empty(input_.shape, dtype=input_.dtype)
+        reduce_scatter_work = dist.stream.reduce_scatter(sub_grad_input, grad_input, group=ctx.group, sync_op=False)
+
+        if __all_gather_recomputation__:
+            all_gather_work.wait()
+
+        grad_weight = paddle.matmul(all_gather_output, grad_output, transpose_x=True)
+        reduce_scatter_work.wait()
+
+        return sub_grad_input, grad_weight
+
+
+class MC2Row(PyLayer):
+    @staticmethod
+    def forward(ctx, input_, weight, group):
+        ctx.save_for_backward(input_, weight)
+
+        rank = dist.get_rank()
+        hcomm_info = group.process_group.get_comm_name(rank)
+        world_size = group.nranks
+
+        output = paddle_custom_device.npu.fused_mm_reduce_scatter(
+            input_,
+            weight,
+            bias=None,
+            hcom=hcomm_info,
+            world_size=world_size,
+            reduce_op="sum",
+            comm_turn=0,
+        )
+
+        ctx.hcomm_info = hcomm_info
+        ctx.world_size = world_size
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_, weight = ctx.saved_tensor()
+        hcomm_info = ctx.hcomm_info
+        world_size = ctx.world_size
+
+        grad_input, all_gather_grad_output = paddle_custom_device.npu.fused_allgather_mm(
+            grad_output,
+            weight.t(),
+            bias=None,
+            hcom=hcomm_info,
+            world_size=world_size,
+            gather_index=0,
+            gather_output=True,
+            comm_turn=0,
+        )
+        grad_weight = paddle.matmul(input_, all_gather_grad_output, transpose_x=True)
+
+        return grad_input, grad_weight
+
+
+class MC2ColumnSeqParallelLinear(ColumnSequenceParallelLinear):
+    def forward(self, x):
+        output = MC2Column.apply(x, self.weight, self.model_parallel_group)
+        output = output + self.bias if self.bias is not None else output
+        return output
+
+
+class MC2RowSeqParallelLinear(RowSequenceParallelLinear):
+    def forward(self, x):
+        output = MC2Row.apply(x, self.weight, self.model_parallel_group)
+        output = output + self.bias if self.bias is not None else output
+        return output
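
For readers without NPU hardware, the sketch below spells out the reference (non-fused) semantics that the two MC2 kernels implement: fused_allgather_mm computes matmul(allgather(x), W) and fused_mm_reduce_scatter computes reduce_scatter(matmul(x, W)), with the sequence dimension sharded along axis 0 as in the backward pass above. It uses plain Paddle collectives, assumes paddle.distributed has been initialized with a model-parallel group, and is not part of the commit.

# Reference semantics only; the MC2 kernels fuse the collective and the matmul
# into a single NPU op, while this sketch runs them separately for clarity.
import paddle
import paddle.distributed as dist


def column_parallel_reference(x_shard, weight, group):
    # Equivalent to fused_allgather_mm(x_shard, weight): gather the sequence
    # shards from every rank, then apply the column-parallel projection.
    gathered = []
    dist.all_gather(gathered, x_shard, group=group)
    return paddle.matmul(paddle.concat(gathered, axis=0), weight)


def row_parallel_reference(x, weight, group):
    # Equivalent to fused_mm_reduce_scatter(x, weight): apply the row-parallel
    # projection, then reduce-scatter the partial sums across ranks.
    partial = paddle.matmul(x, weight)
    out_shape = list(partial.shape)
    out_shape[0] //= group.nranks
    out = paddle.empty(out_shape, dtype=partial.dtype)
    dist.stream.reduce_scatter(out, partial, group=group)
    return out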
