Skip to content

Commit eb0b03a

Browse files
committed
[bugfix] fix streaming & compatibility with transformers 4.54 (#5381)
1 parent 430071c commit eb0b03a

File tree

3 files changed

+14
-6
lines changed

3 files changed

+14
-6
lines changed

swift/llm/dataset/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,14 +274,14 @@ def __iter__(self):
274274

275275
class EncodePreprocessor(RowPreprocessor):
276276

277-
def __init__(self, template: 'Template'):
277+
def __init__(self, template: 'Template', pre_tokenize: bool = False):
278278
super().__init__()
279279
self.template = template
280-
self.is_multimodal = template.model_meta.is_multimodal
280+
self.pre_tokenize = pre_tokenize
281281

282282
def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
283283
encoded = self.template.encode(row, return_length=True)
284-
if self.is_multimodal:
284+
if self.pre_tokenize:
285285
row['length'] = encoded['length']
286286
encoded = row
287287
return encoded

swift/llm/train/sft.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,14 @@ def _prepare_dataset(self):
127127
if i == 1 and predict_with_generate:
128128
# val_dataset
129129
continue
130-
if (args.model_meta.is_multimodal or args.lazy_tokenize) and not args.streaming:
130+
if args.streaming:
131+
preprocessor = EncodePreprocessor(template=template)
132+
dataset = preprocessor(
133+
dataset,
134+
num_proc=args.dataset_num_proc,
135+
load_from_cache_file=args.load_from_cache_file,
136+
strict=args.strict)
137+
elif (args.model_meta.is_multimodal or args.lazy_tokenize):
131138
dataset = LazyLLMDataset(dataset, template.encode, strict=args.strict, random_state=args.data_seed)
132139
if args.packing:
133140
packing_dataset_cls = IterablePackingDataset if args.streaming else PackingDataset
@@ -299,7 +306,7 @@ def _encode_dataset(self, train_dataset, val_dataset):
299306
# val_dataset
300307
continue
301308
if not args.lazy_tokenize and not args.streaming:
302-
preprocessor = EncodePreprocessor(template=template)
309+
preprocessor = EncodePreprocessor(template=template, pre_tokenize=args.model_meta.is_multimodal)
303310
batch_size = 100 if args.model_meta.is_multimodal else 1000
304311
dataset = preprocessor(
305312
dataset,

swift/trainers/trainers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,8 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
399399
from swift.trainers.sequence_parallel import sequence_parallel
400400
loss = sequence_parallel.reduce_outputs(loss, labels)
401401

402-
if getattr(self.args, 'average_tokens_across_devices', False) and self.model_accepts_loss_kwargs:
402+
if getattr(self.args, 'average_tokens_across_devices',
403+
False) and self.model_accepts_loss_kwargs and num_items_in_batch is not None:
403404
loss *= self.accelerator.num_processes
404405

405406
if (outputs.logits is not None and labels is not None and not return_outputs

0 commit comments

Comments (0)