Skip to content

Commit 630cf94

Browse files
committed
[train] fix packing/padding_free & predict_with_generate (#4942)
1 parent 1347cdd commit 630cf94

File tree

3 files changed: 14 additions and 5 deletions.

swift/llm/train/sft.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -227,7 +227,7 @@ def _encode_dataset(self, train_dataset, val_dataset):
227227
num_proc=args.dataset_num_proc,
228228
strict=args.strict,
229229
load_from_cache_file=args.load_from_cache_file)
230-
if val_dataset is not None:
230+
if val_dataset is not None and not predict_with_generate:
231231
val_dataset = packing_dataset_cls(
232232
self.template,
233233
val_dataset,

swift/trainers/trainers.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -216,9 +216,15 @@ def _predict_data_collator(batch):
216216
def _patch_predict_with_generate(self):
217217
origin_data_collator = self.data_collator
218218
self.data_collator = self._predict_data_collator
219+
_packing = self.template._packing
220+
padding_free = self.template.padding_free
221+
self.template._packing = False
222+
self.template.padding_free = False
219223
try:
220224
yield
221225
finally:
226+
self.template._packing = _packing
227+
self.template.padding_free = padding_free
222228
self.data_collator = origin_data_collator
223229

224230
def evaluate(self, *args, **kwargs):

tests/train/test_sft.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@ def test_mllm_mp():
5050
from swift.llm import sft_main, TrainArguments, infer_main, InferArguments
5151
result = sft_main(
5252
TrainArguments(
53-
model='bytedance-research/Valley-Eagle-7B',
53+
model='Qwen/Qwen2.5-VL-7B-Instruct',
5454
dataset=['modelscope/coco_2014_caption:validation#20'],
5555
# dataset=['modelscope/coco_2014_caption:validation#20', 'AI-ModelScope/alpaca-gpt4-data-en#20'],
5656
split_dataset_ratio=0.01,
@@ -270,10 +270,13 @@ def test_predict_with_generate():
270270
sft_main(
271271
TrainArguments(
272272
model='Qwen/Qwen2-7B-Instruct',
273-
dataset=['AI-ModelScope/alpaca-gpt4-data-en#40'],
274-
split_dataset_ratio=0.01,
273+
dataset=['AI-ModelScope/alpaca-gpt4-data-en#400'],
275274
predict_with_generate=True,
276-
split_dataset_ratio=0.5,
275+
# padding_free=True,
276+
max_length=512,
277+
packing=True,
278+
attn_impl='flash_attn',
279+
split_dataset_ratio=0.01,
277280
**kwargs))
278281

279282

0 commit comments

Comments
 (0)