@@ -17,18 +17,11 @@
 from functools import partial
 
 import paddle
-from utils.data import get_convert_example
 
-from paddleformers.data import DataCollatorForSeq2Seq
-from paddleformers.datasets import (
-    ZeroPaddingIterableDataset,
-    ZeroPaddingMapDataset,
-    load_dataset,
-)
+from paddleformers.datasets.finetuning import collate_fn
+from paddleformers.datasets.finetuning import create_dataset as create_dataset_sft
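+# The alias marks this as the SFT dataset factory; presumably it also avoids confusion
+# with the local `create_dataset` helper that this change removes further down.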
 from paddleformers.peft import LoRAConfig, LoRAModel
-from paddleformers.peft.reft import ReftDataCollator
 from paddleformers.trainer import PdArgumentParser, get_last_checkpoint, set_seed
-from paddleformers.trainer.trainer_callback import TrainerState
 from paddleformers.transformers import (
     AutoConfig,
     AutoModelForCausalLM,
@@ -49,12 +42,7 @@
 )
 from paddleformers.transformers.configuration_utils import LlmMetaConfig
 from paddleformers.trl import DataConfig, ModelConfig, SFTConfig, SFTTrainer
-from paddleformers.trl.llm_utils import (
-    ZeroPaddingIterDatasetCallback,
-    compute_metrics,
-    get_lora_target_modules,
-    init_chat_template,
-)
+from paddleformers.trl.llm_utils import compute_metrics, get_lora_target_modules
 from paddleformers.utils.log import logger
 
 # Fine-tune Environment Variables to support sharding stage1 overlap optimization.
@@ -152,6 +140,7 @@ def main():
     model_config.fuse_attention_ffn = model_args.fuse_attention_ffn
 
     model_config.seq_length = data_args.max_length
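+    # Assumption, not stated in this diff: `num_nextn_predict_layers` is the number of
+    # DeepSeek-V3-style multi-token-prediction (MTP) heads, with 0 meaning no MTP.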
+    model_config.num_nextn_predict_layers = model_args.num_nextn_predict_layers
     logger.info(f"Final model config: {model_config}")
 
     logger.info("Creating model")
@@ -201,10 +190,10 @@ def neft_post_hook(module, input, output):
 
     # Load tokenizer & dataset
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, download_hub=model_args.download_hub)
-    tokenizer.chat_template = None
+    # tokenizer.chat_template = None
 
     # init chat_template for tokenizer
-    init_chat_template(tokenizer, model_args.model_name_or_path, data_args.chat_template)
+    # init_chat_template(tokenizer, model_args.model_name_or_path, data_args.chat_template)
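+    # The tokenizer's built-in chat template is now kept as-is; presumably the new
+    # dataset pipeline below renders prompts itself, so the manual override and the
+    # `init_chat_template` call are no longer needed.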
 
     # if using chat_template, data_args.eval_with_do_generation must be false
     if tokenizer.chat_template is not None:
@@ -213,106 +202,57 @@ def neft_post_hook(module, input, output):
     if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, Llama3Tokenizer):
         tokenizer.pad_token_id = tokenizer.eos_token_id
 
-    train_ds, dev_ds, test_ds = create_dataset(data_args, training_args)
-
-    if training_args.resume_from_checkpoint is not None and data_args.lazy:
-        logger.info(
-            f"Loading from '{training_args.resume_from_checkpoint}' with `lazy=True`, manually skipping dataset and setting `ignore_data_skip` to True."
-        )
-        training_args.ignore_data_skip = True
-        state = TrainerState.load_from_json(os.path.join(training_args.resume_from_checkpoint, "trainer_state.json"))
-        if state.trial_params is not None and "zero_padding_global_step" in state.trial_params:
-            consumed_samples = state.trial_params["zero_padding_global_step"]
-        else:
-            consumed_samples = (
-                state.global_step
-                * training_args.per_device_train_batch_size
-                * training_args.gradient_accumulation_steps
-                * training_args.dataset_world_size
-            )
-        logger.info(
-            f"Skipping the first {consumed_samples} samples to warmup the dataset from checkpoint '{training_args.resume_from_checkpoint}'."
-        )
-        train_ds = train_ds.skip(consumed_samples)
-
-    if training_args.pipeline_parallel_degree > 1:
-        from utils.data import convert_example_common
-
-        trans_func = partial(convert_example_common, tokenizer=tokenizer, data_args=data_args)
-    else:
-        trans_func = partial(get_convert_example(model), tokenizer=tokenizer, data_args=data_args)
-
-    eval_zero_padding = data_args.zero_padding
-    if data_args.zero_padding and data_args.eval_with_do_generation:
-        logger.warning(
-            "`zero_padding` conflicts with `eval_with_do_generation`. Setting zero_padding to False for the eval_dataset."
-        )
-        eval_zero_padding = False
-
-    logger.info("Trans the dataset text into token ids, please wait for a moment.")
-    train_ds, dev_ds, test_ds = trans_dataset_to_ids(
-        train_ds, dev_ds, test_ds, model_args, data_args, trans_func, eval_zero_padding
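+    # Shared kwargs for both dataset splits; key semantics are inferred from their
+    # names, not from documentation. `num_replicas=1` and `rank=0` suggest data-parallel
+    # sharding is handled by the trainer rather than inside the dataset.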
+    dataset_config = {
+        "tokenizer": tokenizer,
+        "max_seq_len": training_args.max_seq_length,
+        "random_seed": training_args.seed,
+        "num_replicas": 1,
+        "rank": 0,
+        "num_samples_each_epoch": 6000000,
+        "random_shuffle": data_args.random_shuffle,
+        "greedy_intokens": data_args.greedy_intokens,
+        "packing": data_args.packing,
+        "mix_strategy": data_args.mix_strategy,
+    }
+
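+    # Each split can mix several sources: `task_group` lists the dataset paths and
+    # `task_group_prob` presumably gives their sampling weights; `is_valid=True`
+    # below marks the evaluation split.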
+    train_dataset = create_dataset_sft(
+        task_group=data_args.train_dataset_path,
+        task_group_prob=data_args.train_dataset_prob,
+        sub_dataset_type=data_args.train_dataset_type,
+        **dataset_config,
+    )
+    eval_dataset = create_dataset_sft(
+        task_group=data_args.eval_dataset_path,
+        task_group_prob=data_args.eval_dataset_prob,
+        sub_dataset_type=data_args.eval_dataset_type,
+        is_valid=True,
+        **dataset_config,
     )
-
-    if data_args.zero_padding:
-        if data_args.lazy:
-            intoken_dataset = ZeroPaddingIterableDataset
-        else:
-            intoken_dataset = ZeroPaddingMapDataset
-        logger.info("Creating Zero Padding Data Stream. This may take a few minutes.")
-        if train_ds is not None:
-            train_ds = intoken_dataset(
-                train_ds,
-                tokenizer=tokenizer,
-                max_length=data_args.max_length,
-                greedy_zero_padding=data_args.greedy_zero_padding,
-            )
-        if eval_zero_padding and dev_ds is not None:
-            dev_ds = intoken_dataset(dev_ds, tokenizer=tokenizer, max_length=data_args.max_length)
-        if eval_zero_padding and test_ds is not None:
-            test_ds = intoken_dataset(test_ds, tokenizer=tokenizer, max_length=data_args.max_length)
 
     model = create_peft_model(model_args, training_args, dtype, model)
 
     # Create trainer
 
-    if (
-        training_args.pipeline_parallel_degree > 1
-        or training_args.sequence_parallel
-        or training_args.autotuner_benchmark
-        or data_args.zero_padding
-        or data_args.pad_to_max_length
-    ):
-        max_length = data_args.max_length
-        padding = "max_length"
-    else:
-        max_length = None
-        padding = True
-
     if training_args.pipeline_parallel_degree > 1:
         metrics = None
     else:
         metrics = compute_metrics
 
-    data_collator_fn = DataCollatorForSeq2Seq(
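+    # `collate_fn` (imported from paddleformers.datasets.finetuning) replaces the
+    # generic DataCollatorForSeq2Seq; it receives `model_args`, presumably so batching
+    # can respect model options such as flash_mask.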
+    data_collator = partial(
+        collate_fn,
         tokenizer=tokenizer,
-        max_length=max_length,
-        padding=padding,
-        max_label_length=max_length,
-        return_tensors="np",
-        return_attention_mask=not model_args.flash_mask,
-        pad_to_multiple_of=data_args.pad_to_multiple_of,
+        model_args=model_args,
+        max_seq_len=training_args.max_seq_length + model_config.num_nextn_predict_layers,
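+        # Widened by one token per MTP head so every extra prediction head still finds
+        # a shifted label in range; this is an inference from the diff, not documented.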
     )
     trainer = SFTTrainer(
         model=model,
         args=training_args,
-        train_dataset=train_ds,
-        eval_dataset=dev_ds,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
         tokenizer=tokenizer,
         compute_metrics=metrics,
-        data_collator=data_collator_fn if not model_args.reft else ReftDataCollator(data_collator=data_collator_fn),
+        data_collator=data_collator,
         do_generation=data_args.eval_with_do_generation,
-        callbacks=[ZeroPaddingIterDatasetCallback()] if isinstance(train_ds, ZeroPaddingIterableDataset) else None,
         data_args=data_args,
     )
     trainable_parameters = [
@@ -344,16 +284,6 @@ def neft_post_hook(module, input, output):
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
 
-    # Evaluation test set
-    if training_args.do_predict:
-        eval_result = trainer.predict(test_ds).metrics
-        trainer.log_metrics("test", eval_result)
-    # Evaluation dev set
-    if training_args.do_eval:
-        logger.info("*** Evaluate result after train ***")
-        eval_result = trainer.evaluate(dev_ds)
-        trainer.log_metrics("eval", eval_result)
-
 
 def create_peft_model(model_args, training_args, dtype, model):
     if model_args.lora:
@@ -387,98 +317,5 @@ def create_peft_model(model_args, training_args, dtype, model):
     return model
 
 
-def trans_dataset_to_ids(train_ds, dev_ds, test_ds, model_args, data_args, trans_func, eval_zero_padding):
-    if train_ds is not None:
-        train_ds = train_ds.map(
-            partial(
-                trans_func,
-                is_test=False,
-                zero_padding=data_args.zero_padding,
-                flash_mask=model_args.flash_mask,
-            )
-        )
-    if dev_ds is not None:
-        dev_ds = dev_ds.map(
-            partial(
-                trans_func,
-                is_test=data_args.eval_with_do_generation,
-                zero_padding=eval_zero_padding,
-                flash_mask=model_args.flash_mask,
-            )
-        )
-    if test_ds is not None:
-        test_ds = test_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation))
-
-    return train_ds, dev_ds, test_ds
-
-
-def create_dataset(data_args, training_args):
-    if data_args.dataset_name_or_path is None:
-        raise ValueError(f"Please specific dataset name or path (got {data_args.dataset_name_or_path})")
-
-    train_ds = None
-    dev_ds = None
-    test_ds = None
-    if os.path.exists(os.path.join(data_args.dataset_name_or_path, "train.json")) or os.path.exists(
-        os.path.join(data_args.dataset_name_or_path, "dev.json")
-    ):
-        logger.info("load train")
-        if training_args.do_train:
-            train_ds = load_dataset(
-                "json",
-                data_files=os.path.join(data_args.dataset_name_or_path, "train.json"),
-                lazy=data_args.lazy,
-            )[0]
-        logger.info("load eval")
-        if training_args.do_eval:
-            dev_ds = load_dataset(
-                "json",
-                data_files=os.path.join(data_args.dataset_name_or_path, "dev.json"),
-                lazy=data_args.lazy,
-            )[0]
-        logger.info("load test")
-        if training_args.do_predict:
-            test_ds = load_dataset(
-                "json",
-                data_files=os.path.join(data_args.dataset_name_or_path, "test.json"),
-                lazy=data_args.lazy,
-            )[0]
-
-    elif os.path.exists(os.path.join(data_args.dataset_name_or_path, "train")) or os.path.exists(
-        os.path.join(data_args.dataset_name_or_path, "dev")
-    ):
-        import glob
-
-        if training_args.do_train:
-            train_ds = load_dataset(
-                "json",
-                data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "train", "*.json")),
-                lazy=data_args.lazy,
-            )[0]
-        if training_args.do_eval:
-            dev_ds = load_dataset(
-                "json",
-                data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "dev", "*.json")),
-                lazy=data_args.lazy,
-            )[0]
-        if training_args.do_predict:
-            test_ds = load_dataset(
-                "json",
-                data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "test", "*.json")),
-                lazy=data_args.lazy,
-            )[0]
-    else:
-        if training_args.do_train:
-            train_ds = load_dataset(data_args.dataset_name_or_path, splits=["train"])[0]
-
-        if training_args.do_eval:
-            dev_ds = load_dataset(data_args.dataset_name_or_path, splits=["dev"])[0]
-
-        if training_args.do_predict:
-            test_ds = load_dataset(data_args.dataset_name_or_path, splits=["test"])[0]
-
-    return train_ds, dev_ds, test_ds
-
-
 if __name__ == "__main__":
     main()