Commit 9c05b38

cherry pick from pr #2584 (#2592)

This cherry-pick makes max_seq_len optional in the fine-tuning and DPO collators: when packing is disabled, the example scripts pass None and each collator falls back to the longest sequence in the batch.

1 parent e479070 commit 9c05b38
File tree

4 files changed, +7 -3 lines changed

examples/alignment/dpo/run_dpo.py

Lines changed: 2 additions & 1 deletion

@@ -294,6 +294,7 @@ def main():
         eval_dataset = None
     logger.info("Creating dataset successfully ...")
 
+    max_seq_len = data_args.max_seq_len if data_args.packing else None
     trainer = DPOTrainer(
         model=model,
         ref_model=ref_model,
@@ -305,7 +306,7 @@ def main():
         data_collator=partial(
             collate_fn,
             tokenizer=tokenizer,
-            max_seq_len=data_args.max_seq_len,
+            max_seq_len=max_seq_len,
             use_sparse_head_and_loss_fn=model_args.use_sparse_head_and_loss_fn,
             use_fused_head_and_loss_fn=model_args.use_fused_head_and_loss_fn,
         ),
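The caller-side pattern is simple: a fixed length is only meaningful when sequences are packed to a uniform size, so the script now passes None whenever packing is off and lets the collator size each batch dynamically. A minimal sketch of the idea, using hypothetical stand-in values for the parsed arguments:

```python
# Minimal sketch of the caller-side change (hypothetical stand-in values;
# the real script reads these fields from its parsed data_args).
packing = False        # stands in for data_args.packing
configured_len = 4096  # stands in for data_args.max_seq_len

# Pin the sequence length only when packing; otherwise defer to the collator.
max_seq_len = configured_len if packing else None
print(max_seq_len)  # None -> collate_fn derives a per-batch length instead
```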

examples/run_finetune.py

Lines changed: 2 additions & 1 deletion

@@ -262,11 +262,12 @@ def neft_post_hook(module, input, output):
     else:
         metrics = compute_metrics
 
+    max_seq_len = training_args.max_seq_len + model_config.num_nextn_predict_layers if data_args.packing else None
     data_collator = partial(
         collate_fn,
         tokenizer=tokenizer,
         model_args=model_args,
-        max_seq_len=training_args.max_seq_len + model_config.num_nextn_predict_layers,
+        max_seq_len=max_seq_len,
     )
 
     if training_args.max_steps == -1:
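The fine-tuning script follows the same pattern, except that the packed length also reserves model_config.num_nextn_predict_layers extra positions, presumably for the model's next-n token prediction heads. A sketch with stand-in values:

```python
# Sketch of the fine-tuning variant (hypothetical stand-in values; the real
# script reads training_args, data_args, and model_config).
packing = True
base_len = 4096               # training_args.max_seq_len
num_nextn_predict_layers = 1  # model_config.num_nextn_predict_layers

# Packed batches reserve one extra position per next-n prediction layer;
# unpacked batches again defer sizing to the collator.
max_seq_len = base_len + num_nextn_predict_layers if packing else None
print(max_seq_len)  # 4097 when packing, None otherwise
```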

paddleformers/datasets/dpo.py

Lines changed: 1 addition & 1 deletion

@@ -153,7 +153,7 @@ def collate_fn(
         - attn_mask_startend_row_indices (int32, optional): Sparse attention row indices [batch_size, max_seq_len]
     """
     if max_seq_len is None:
-        raise ValueError("max_seq_len is None.")
+        max_seq_len = max(len(item.input_ids) for sequence in batch for item in sequence)
 
     input_dict = {
         "input_ids": [],

paddleformers/datasets/finetuning.py

Lines changed: 2 additions & 0 deletions

@@ -130,6 +130,8 @@ def collate_fn(batch: List[List[Sequence]], tokenizer, model_args, max_seq_len:
     else:
         input_keys.append("attention_mask")
     return_list = []
+    if max_seq_len is None:
+        max_seq_len = max(len(item.token_ids) for sequence in batch for item in sequence)
     for batch_sequence in batch:
         original_token_ids = [seq.token_ids for seq in batch_sequence]
         token_ids = [sum(original_token_ids, [])]
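The fine-tuning collator gets the same guard, keyed on token_ids instead of input_ids. One practical effect: with a per-batch length, short batches no longer pad out to the global maximum, as this small stand-in example shows (the padding step below is hypothetical, for illustration only):

```python
# Stand-in batch for the fine-tuning collator (the real items are Sequence
# objects exposing token_ids).
batch = [[{"token_ids": [1, 2, 3, 4]}], [{"token_ids": [5, 6]}]]

# Same fallback as in the diff above, on token_ids.
max_seq_len = max(len(item["token_ids"]) for sequence in batch for item in sequence)

# Hypothetical padding step illustrating how the derived length is used:
# concatenate each packed sequence, then pad it to max_seq_len.
pad_id = 0
padded = [
    seq_ids + [pad_id] * (max_seq_len - len(seq_ids))
    for sequence in batch
    for seq_ids in [sum((item["token_ids"] for item in sequence), [])]
]
print(padded)  # [[1, 2, 3, 4], [5, 6, 0, 0]]
```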
