@@ -14,6 +14,7 @@
 import json
 import os
 import sys
+import inspect
 from functools import partial
 
 import paddle
@@ -51,13 +52,17 @@
     AutoTokenizer,
     Llama3Tokenizer,
     LlamaTokenizer,
+    LlamaForCausalLM,
+    LlamaForCausalLMPipe,
 )
 from paddlenlp.transformers.configuration_utils import LlmMetaConfig
 from paddlenlp.utils.log import logger
 
 # Fine-tune Environment Variables to support sharding stage1 overlap optimization.
 os.environ["USE_CASUAL_MASK"] = "False"
 
+flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe]
+
 
 def main():
     # Arguments
@@ -77,6 +82,7 @@ def main():
         raise ValueError(
             "--do_train, --do_ptq, --do_gptq and --do_qat cannot work at the same time. Please choose only one at a time"
         )
+
 
     # Setup GPU & distributed training
     paddle.set_device(training_args.device)
@@ -160,6 +166,16 @@ def main():
         # NOTE(gongenlei): new add autotuner_benchmark
         model = model_class.from_config(model_config, dtype=dtype)
 
+    if model_args.flash_mask and (not data_args.zero_padding or not model.config.use_flash_attention):
+        logger.warning(
+            "`flash_mask` requires zero padding and flash attention; enabling both."
+        )
+        data_args.zero_padding = True
+        model.config.use_flash_attention = True
+
+    if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
+        raise NotImplementedError(f"{model.__class__} does not support flash mask.")
+
     if training_args.do_train and model_args.neftune:
         # Inspired by https://github.com/neelsjain/NEFTune
         if hasattr(model, "get_input_embeddings"):
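
For context on what the `flash_mask` option changes downstream, the sketch below is illustrative only (hypothetical function names, NumPy instead of Paddle, not the PaddleNLP implementation). With the zero-padding data stream, several examples are packed into one sequence, and a dense attention mask over that packed sequence costs O(L^2) memory. A flash-mask style attention path can instead work from a compact per-token boundary vector, which is why the flag above is coupled to zero padding and flash attention.

import numpy as np


def packed_doc_end_indices(doc_lens):
    """For every token in a packed sequence, the exclusive end index of its own
    document, i.e. the last column it may attend to. O(L) integers."""
    ends, offset = [], 0
    for n in doc_lens:
        ends.extend([offset + n] * n)
        offset += n
    return np.asarray(ends, dtype=np.int32)


def dense_block_causal_mask(doc_lens):
    """The equivalent dense mask: block-diagonal AND causal. O(L^2) booleans."""
    ends = packed_doc_end_indices(doc_lens)
    starts = ends - np.repeat(np.asarray(doc_lens, dtype=np.int32), doc_lens)
    rows = np.arange(ends.size)[:, None]
    cols = np.arange(ends.size)[None, :]
    return (cols <= rows) & (cols >= starts[:, None]) & (cols < ends[:, None])


doc_lens = [3, 2, 4]                            # three examples packed into one sequence
print(packed_doc_end_indices(doc_lens))         # [3 3 3 5 5 9 9 9 9] -- O(L) metadata
print(dense_block_causal_mask(doc_lens).shape)  # (9, 9) -- O(L^2) mask it replaces
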
@@ -329,12 +345,12 @@ def neft_post_hook(module, input, output):
             "Zero Padding data stream is only implemented for LLaMA, Bloom, ChatGLM and QWen so far."
         )
     train_ds = (
-        train_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding))
+        train_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask))
         if train_ds is not None
         else None
     )
     ptq_ds = (
-        ptq_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding))
+        ptq_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask))
         if ptq_ds is not None
         else None
     )
@@ -345,7 +361,7 @@ def neft_post_hook(module, input, output):
         )
         eval_zero_padding = False
     dev_ds = (
-        dev_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation, zero_padding=eval_zero_padding))
+        dev_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation, zero_padding=eval_zero_padding, flash_mask=model_args.flash_mask))
         if dev_ds is not None
         else None
     )
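
The dataset mapping changes above only thread the new keyword through `functools.partial`; the behavioral change lives inside the preprocessing function being wrapped. As a rough, hypothetical illustration of that contract (a toy `trans_func`, not the one used by this script): when `flash_mask` is set, the mapped example can drop the dense `attention_mask` field, because masking is later reconstructed from packing boundaries.

from functools import partial


def toy_trans_func(example, is_test=False, zero_padding=False, flash_mask=False):
    # Stand-in for real tokenization: one "token id" per whitespace-separated word.
    ids = list(range(len(example["src"].split())))
    features = {"input_ids": ids, "labels": None if is_test else ids}
    if not flash_mask:
        # Dense lower-triangular (causal) mask, only built when flash mask is off.
        features["attention_mask"] = [[int(j <= i) for j in range(len(ids))] for i in range(len(ids))]
    return features


examples = [{"src": "hello world"}, {"src": "a b c"}]
trans = partial(toy_trans_func, is_test=False, zero_padding=True, flash_mask=True)
print([sorted(trans(ex).keys()) for ex in examples])  # no "attention_mask" key emitted
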
@@ -498,6 +514,7 @@ def compute_metrics_do_generation(eval_preds):
             padding=padding,
             max_label_length=max_length,
             return_tensors="np",
+            return_attention_mask=not model_args.flash_mask,
             pad_to_multiple_of=data_args.pad_to_multiple_of,
         ),
         do_generation=data_args.eval_with_do_generation,
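
The final change asks the collator to skip building the dense attention mask whenever flash mask is enabled, since the flash-mask attention path would not consume it anyway. A toy sketch of that switch follows (a hypothetical collate function, not PaddleNLP's DataCollatorForSeq2Seq):

def toy_collate(batch, pad_id=0, return_attention_mask=True):
    # Pad every example to the longest one in the batch.
    max_len = max(len(ex["input_ids"]) for ex in batch)
    out = {
        "input_ids": [ex["input_ids"] + [pad_id] * (max_len - len(ex["input_ids"])) for ex in batch]
    }
    if return_attention_mask:
        # 1 for real tokens, 0 for padding; skipped entirely when flash mask is on.
        out["attention_mask"] = [
            [1] * len(ex["input_ids"]) + [0] * (max_len - len(ex["input_ids"])) for ex in batch
        ]
    return out


batch = [{"input_ids": [5, 6, 7]}, {"input_ids": [8]}]
print(sorted(toy_collate(batch, return_attention_mask=False).keys()))  # ['input_ids'] only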