
Commit c86e447

wtmlong and gongel authored
Sft flash mask (#8664)
* support sft flash mask
* dpo support
* update
* remove dense mask when using flash mask
* bugfix
* support pp
* bugfix
* bugfix
* bugfix

Co-authored-by: gongel <[email protected]>
1 parent 3ebe938 commit c86e447

13 files changed (+122 lines, -33 lines)

llm/alignment/dpo/dpo_argument.py

Lines changed: 2 additions & 2 deletions
@@ -87,8 +87,8 @@ class DPOModelArgument:
             "help": "The granularity of recompute training can be selected as `full` or `full_attn` or `core_attn`."
         },
     )
-    use_attn_mask_start_row_indices: bool = field(
-        default=False, metadata={"help": "Whether to use attn_mask_start_row_indices in flash attention."}
+    flash_mask: bool = field(
+        default=False, metadata={"help": "Whether to use flash mask in flash attention."}
     )
     virtual_pp_degree: int = field(
         default=1,
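The renamed flag keeps the `field(default=False, metadata={"help": ...})` declaration style used throughout these argument dataclasses (the same pattern is added to `llm/utils/argument.py` further down). A minimal, self-contained illustration with plain dataclasses — a sketch of how such a flag behaves, not PaddleNLP's actual parser wiring:

from dataclasses import dataclass, field

@dataclass
class ToyModelArgument:
    # Same declaration style as the flash_mask field added above.
    flash_mask: bool = field(
        default=False, metadata={"help": "Whether to use flash mask in flash attention."}
    )

args = ToyModelArgument()                             # defaults to False, so existing configs are unaffected
print(args.flash_mask)                                # False
print(ToyModelArgument(flash_mask=True).flash_mask)   # True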

llm/alignment/dpo/run_dpo.py

Lines changed: 16 additions & 0 deletions
@@ -17,6 +17,7 @@
 import os
 import sys
 import time
+import inspect
 from functools import partial
 
 import paddle
@@ -36,8 +37,14 @@
     preference_collate_fn,
     preprocess_preference_data,
 )
+from paddlenlp.transformers import (
+    LlamaForCausalLM,
+    LlamaForCausalLMPipe,
+)
 from paddlenlp.utils.log import logger
 
+flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe]
+
 
 def main():
     """main"""
@@ -124,6 +131,15 @@ def main():
         ref_model = AutoModelForCausalLM.from_config(ref_config)
         model.set_state_dict(ref_model.state_dict())
 
+    if model_args.flash_mask and not model.config.use_flash_attention:
+        logger.warning(
+            "`flash_mask` must use with zero padding and flash attention."
+        )
+        model.config.use_flash_attention = True
+
+    if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
+        raise NotImplementedError(f"{model.__class__} not support flash mask.")
+
     if model_args.tokenizer_name_or_path is not None:
         tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
     else:
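The two guards added here are (1) a soft check that flips `use_flash_attention` on with a warning, and (2) a hard allow-list check against `flash_mask_support_list`. A toy, self-contained sketch of the same gating logic (the classes below are placeholders, not PaddleNLP's):

import logging

logger = logging.getLogger(__name__)

class ToyConfig:
    def __init__(self):
        self.use_flash_attention = False

class ToySupportedModel:
    def __init__(self):
        self.config = ToyConfig()

flash_mask_support_list = [ToySupportedModel]

def apply_flash_mask_gates(model, flash_mask):
    # Soft gate: flash mask needs flash attention, so warn and enable it.
    if flash_mask and not model.config.use_flash_attention:
        logger.warning("`flash_mask` must use with zero padding and flash attention.")
        model.config.use_flash_attention = True
    # Hard gate: only architectures on the allow-list implement the flash-mask path.
    if flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
        raise NotImplementedError(f"{model.__class__} not support flash mask.")

model = ToySupportedModel()
apply_flash_mask_gates(model, flash_mask=True)
print(model.config.use_flash_attention)  # True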

llm/config/llama/dpo_argument.json

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@
     "sharding_parallel_degree": 1,
     "sharding": "stage1",
     "use_flash_attention": true,
-    "use_attn_mask_start_row_indices":false,
+    "flash_mask":true,
     "recompute": false,
     "recompute_granularity": "full",
     "dpo_beta": 0.1,

llm/config/qwen/dpo_argument.json

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@
     "sharding_parallel_degree": 1,
     "sharding": "stage1",
     "use_flash_attention": true,
-    "use_attn_mask_start_row_indices":false,
+    "flash_mask":false,
     "recompute": false,
     "recompute_granularity": "full",
     "dpo_beta": 0.1,

llm/run_finetune.py

Lines changed: 20 additions & 3 deletions
@@ -14,6 +14,7 @@
 import json
 import os
 import sys
+import inspect
 from functools import partial
 
 import paddle
@@ -51,13 +52,17 @@
     AutoTokenizer,
     Llama3Tokenizer,
     LlamaTokenizer,
+    LlamaForCausalLM,
+    LlamaForCausalLMPipe,
 )
 from paddlenlp.transformers.configuration_utils import LlmMetaConfig
 from paddlenlp.utils.log import logger
 
 # Fine-tune Environment Variables to support sharding stage1 overlap optimization.
 os.environ["USE_CASUAL_MASK"] = "False"
 
+flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe]
+
 
 def main():
     # Arguments
@@ -77,6 +82,7 @@ def main():
         raise ValueError(
             "--do_train, --do_ptq, --do_gptq and --do_qat cannot work at the same time. Please choose only one at a time"
         )
+
 
     # Setup GPU & distributed training
     paddle.set_device(training_args.device)
@@ -160,6 +166,16 @@ def main():
         # NOTE(gongenlei): new add autotuner_benchmark
         model = model_class.from_config(model_config, dtype=dtype)
 
+    if model_args.flash_mask and (not data_args.zero_padding or not model.config.use_flash_attention):
+        logger.warning(
+            "`flash_mask` must use with zero padding and flash attention."
+        )
+        data_args.zero_padding = True
+        model.config.use_flash_attention = True
+
+    if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
+        raise NotImplementedError(f"{model.__class__} not support flash mask.")
+
     if training_args.do_train and model_args.neftune:
         # Inspired by https://github.com/neelsjain/NEFTune
         if hasattr(model, "get_input_embeddings"):
@@ -329,12 +345,12 @@ def neft_post_hook(module, input, output):
                 "Zero Padding data stream is only implemented for LLaMA, Bloom, ChatGLM and QWen so far."
             )
     train_ds = (
-        train_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding))
+        train_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask))
         if train_ds is not None
         else None
     )
     ptq_ds = (
-        ptq_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding))
+        ptq_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask))
         if ptq_ds is not None
         else None
     )
@@ -345,7 +361,7 @@ def neft_post_hook(module, input, output):
     )
     eval_zero_padding = False
     dev_ds = (
-        dev_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation, zero_padding=eval_zero_padding))
+        dev_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation, zero_padding=eval_zero_padding, flash_mask=model_args.flash_mask))
         if dev_ds is not None
         else None
     )
@@ -498,6 +514,7 @@ def compute_metrics_do_generation(eval_preds):
             padding=padding,
             max_label_length=max_length,
             return_tensors="np",
+            return_attention_mask=not model_args.flash_mask,
             pad_to_multiple_of=data_args.pad_to_multiple_of,
         ),
         do_generation=data_args.eval_with_do_generation,
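Two details worth noting in the mapping and collator changes above: the new flag is threaded into the conversion helpers via `functools.partial`, and the Seq2Seq collator gets `return_attention_mask=not model_args.flash_mask`, which lines up with the commit note about removing the dense mask when flash mask is used. A toy sketch of the partial threading (the helper below is a stand-in, not the real `trans_func`):

from functools import partial

def toy_trans_func(example, is_test=False, zero_padding=False, flash_mask=False):
    # Stand-in for the convert_example_* helpers; it just echoes the flags it received.
    return {"is_test": is_test, "zero_padding": zero_padding, "flash_mask": flash_mask}

# run_finetune.py builds its map function the same way: dataset.map(partial(trans_func, ...))
map_fn = partial(toy_trans_func, is_test=False, zero_padding=True, flash_mask=True)
print(map_fn({"src": "hello", "tgt": "world"}))
# {'is_test': False, 'zero_padding': True, 'flash_mask': True}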

llm/utils/argument.py

Lines changed: 3 additions & 0 deletions
@@ -209,6 +209,9 @@ class ModelArgument:
     aistudio_token: str = field(default=None, metadata={"help": "The token of aistudio"})
     neftune: bool = field(default=False, metadata={"help": "Whether to apply NEFT"})
     neftune_noise_alpha: float = field(default=5.0, metadata={"help": "NEFT noise alpha"})
+    flash_mask: bool = field(
+        default=False, metadata={"help": "Whether to use flash_mask in flash attention."}
+    )
 
 
 @dataclass

llm/utils/data.py

Lines changed: 17 additions & 5 deletions
@@ -173,11 +173,12 @@ def tokenize_rounds_example(tokenizer, example, data_args, **kwargs):
     return tokenized_source, labels
 
 
-def convert_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False):
+def convert_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False, flash_mask=False):
     if tokenizer.chat_template is not None:
-        return convert_rounds_example_common(example, tokenizer, data_args, is_test, zero_padding)
+        return convert_rounds_example_common(example, tokenizer, data_args, is_test, zero_padding, flash_mask)
 
     tokenized_source, tokenized_target_input_ids = tokenize_example(tokenizer, example, data_args)
+
     if is_test:
         return {
             **tokenized_source,
@@ -194,12 +195,17 @@ def convert_example_common(example, tokenizer, data_args, is_test=True, zero_pad
     if "position_ids" in tokenized_source:
         features["position_ids"] = list(range(seq_length))
     if zero_padding:
-        features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)
+        if flash_mask:
+            features["attn_mask_startend_row_indices"] = (
+                [seq_length] * seq_length
+            )
+        else:
+            features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)
 
     return features
 
 
-def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False):
+def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False, flash_mask=False):
     """convert multi-rounds conversation example
 
     Args:
@@ -227,7 +233,13 @@ def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, z
     seq_length = len(input_ids)
     features = {"input_ids": input_ids, "labels": labels}
     if zero_padding:
-        features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)
+        if flash_mask:
+            features["attn_mask_startend_row_indices"] = (
+                [seq_length] * seq_length
+            )
+        else:
+            features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)
+
 
     if "position_ids" in rounds_inputs:
         rounds_inputs["position_ids"] = rounds_inputs["position_ids"][:-1]
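With zero padding enabled, the dense path materializes a `seq_length x seq_length` boolean causal mask (`np.tri`), while the flash-mask path stores one integer per token: `[seq_length] * seq_length`. My reading of that representation (it is not spelled out in this diff) is that each entry marks the row at which attention for that key column stops, with ordinary causality enforced by the kernel itself; under that assumption the two forms describe the same single-sequence causal mask, at O(n) instead of O(n^2) memory:

import numpy as np

def dense_from_row_indices(startend_row_indices):
    # Rebuild a dense mask from the compact form, under the assumption stated above:
    # column `col` is visible to query rows in [col, stop_row).
    seq_len = len(startend_row_indices)
    mask = np.zeros((seq_len, seq_len), dtype=bool)
    for col, stop_row in enumerate(startend_row_indices):
        for row in range(seq_len):
            mask[row, col] = (row >= col) and (row < stop_row)
    return mask

seq_length = 4
indices = [seq_length] * seq_length  # what convert_example_common now emits for one sequence
assert (dense_from_row_indices(indices) == np.tri(seq_length, seq_length, dtype=bool)).all()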

paddlenlp/data/data_collator.py

Lines changed: 28 additions & 0 deletions
@@ -370,6 +370,11 @@ def __call__(self, features, return_tensors=None):
         if return_tensors is None:
             return_tensors = self.return_tensors
         labels = [feature["labels"] for feature in batch] if "labels" in batch[0].keys() else None
+        use_attn_mask_startend_row_indices = (
+            [feature["attn_mask_startend_row_indices"] for feature in batch]
+            if "attn_mask_startend_row_indices" in batch[0].keys()
+            else None
+        )
         # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
         # same length to return tensors.
         if labels is not None:
@@ -396,6 +401,29 @@
                     feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64)
                 else:
                     feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64)
+        if use_attn_mask_startend_row_indices is not None:
+            if self.max_length is not None:
+                max_length = self.max_length
+            else:
+                max_length = max(len(l) for l in use_attn_mask_startend_row_indices)
+            if self.pad_to_multiple_of is not None:
+                max_length = (
+                    (max_length + self.pad_to_multiple_of - 1) // self.pad_to_multiple_of * self.pad_to_multiple_of
+                )
+
+            for feature in batch:
+                pad_len = max_length - len(feature["attn_mask_startend_row_indices"])
+                remainder = np.zeros([1, pad_len], dtype=np.int32)
+                feature["attn_mask_startend_row_indices"] = (
+                    np.concatenate(
+                        [remainder, np.array([feature["attn_mask_startend_row_indices"]], dtype=np.int32) + pad_len],
+                        axis=-1,
+                    )
+                    if padding_side == "left"
+                    else np.concatenate(
+                        [np.array([feature["attn_mask_startend_row_indices"]], dtype=np.int32), remainder], axis=-1
+                    )
+                )
 
         batch = self.tokenizer.pad(
             batch,
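A small worked example of the padding branch above: pad positions get a row index of 0 (so, on my reading of the representation, no query row ever attends to them), and under left padding the real indices are shifted by `pad_len` because the real tokens move that many positions to the right. Values are illustrative:

import numpy as np

row_indices = [3, 3, 3]  # one packed example of length 3
max_length = 5
pad_len = max_length - len(row_indices)
remainder = np.zeros([1, pad_len], dtype=np.int32)

# Same two concatenations as the collator, for right and left padding.
right_padded = np.concatenate([np.array([row_indices], dtype=np.int32), remainder], axis=-1)
left_padded = np.concatenate([remainder, np.array([row_indices], dtype=np.int32) + pad_len], axis=-1)

print(right_padded)  # [[3 3 3 0 0]] -> pad columns stop at row 0, i.e. masked everywhere
print(left_padded)   # [[0 0 5 5 5]] -> real columns shift along with the tokens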

paddlenlp/datasets/zero_padding_dataset.py

Lines changed: 5 additions & 5 deletions
@@ -28,14 +28,14 @@ class ZeroPadding:
         "chosen_labels",
         "rejected_labels",
         "response_indexs",
-        "attn_mask_start_row_indices",
+        "attn_mask_startend_row_indices",
     ]
 
     @classmethod
     def _pad_batch_records(cls, batch_records):
         # Only consider supported input keys
         input_keys = [key for key in batch_records[0].keys() if key in cls.supported_input_keys]
-        if "attn_mask_start_row_indices" not in input_keys and "attention_mask" not in input_keys:
+        if "attn_mask_startend_row_indices" not in input_keys and "attention_mask" not in input_keys:
             input_keys.append("attention_mask")
         batched_features = {key: [] for key in input_keys}
         sequence_sum = 0
@@ -57,9 +57,9 @@ def _pad_batch_records(cls, batch_records):
 
             seq_length = len(record["input_ids"])
             # If attention_mask is not given, assume it's causal mask
-            if "attn_mask_start_row_indices" in record:
-                attn_mask_start_row_indices = [i + sequence_sum for i in record["attn_mask_start_row_indices"]]
-                batched_features["attn_mask_start_row_indices"].extend(attn_mask_start_row_indices)
+            if "attn_mask_startend_row_indices" in record:
+                attn_mask_startend_row_indices = [i + sequence_sum for i in record["attn_mask_startend_row_indices"]]
+                batched_features["attn_mask_startend_row_indices"].extend(attn_mask_startend_row_indices)
             else:
                 attention_mask = record.get("attention_mask", np.tril(np.ones([seq_length, seq_length], dtype=bool)))
                 batched_features["attention_mask"].append(attention_mask)
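The `sequence_sum` offset is what makes the per-record indices composable when several examples are packed into one row: after packing, the columns of each segment point at that segment's own end row rather than the end of the whole packed sequence. A rough illustration for two records of lengths 3 and 2, mirroring the bookkeeping above:

# Per-record value is [seq_length] * seq_length, as produced in llm/utils/data.py.
records = [
    {"attn_mask_startend_row_indices": [3, 3, 3]},
    {"attn_mask_startend_row_indices": [2, 2]},
]

packed = []
sequence_sum = 0
for record in records:
    packed.extend(i + sequence_sum for i in record["attn_mask_startend_row_indices"])
    sequence_sum += len(record["attn_mask_startend_row_indices"])

print(packed)  # [3, 3, 3, 5, 5] -> each segment's columns stop attending at that segment's end row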

paddlenlp/transformers/llama/fusion_ops.py

Lines changed: 2 additions & 0 deletions
@@ -211,6 +211,8 @@ def fusion_flash_attention(
     else:
         if attn_mask_startend_row_indices is not None:
             assert alibi is None, "flash_attention_with_sparse_mask not support alibi"
+            if len(attn_mask_startend_row_indices.shape) == 2:
+                attn_mask_startend_row_indices = paddle.unsqueeze(attn_mask_startend_row_indices, axis=1)
             attn_output = F.flash_attention_with_sparse_mask(
                 query_states,
                 key_states,
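The added guard only adjusts shape: indices arriving from the data pipeline as `[batch, seq_len]` gain a singleton head axis before the sparse-mask flash attention call (the exact layout the kernel expects is Paddle's contract and is not spelled out in this diff). A shape-only sketch:

import paddle

# Indices as produced by the collator for a batch of one packed sequence of length 5.
attn_mask_startend_row_indices = paddle.to_tensor([[3, 3, 3, 5, 5]], dtype="int32")  # shape [1, 5]
if len(attn_mask_startend_row_indices.shape) == 2:
    # Insert a singleton head axis, as the guard above does.
    attn_mask_startend_row_indices = paddle.unsqueeze(attn_mask_startend_row_indices, axis=1)
print(attn_mask_startend_row_indices.shape)  # [1, 1, 5]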
