Skip to content

Commit 4878744

Browse files
tastelikefeettastelikefeet
authored andcommitted
fix (#4565)
Co-authored-by: tastelikefeet <[email protected]>
1 parent b35ba1d commit 4878744

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

swift/llm/template/base.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,7 @@ def split_multi_medias(_inputs):
370370
positive_encoded = self._encode_truncated(positive)
371371
for key in positive_encoded:
372372
_encoded[f'positive_{key}'] = positive_encoded[key]
373+
_encoded[f'negative_{key}'] = []
373374
labels.append(float(inputs.label) if inputs.label is not None else 1.0)
374375

375376
rejected_len = len(inputs.rejected_response) if inputs.rejected_response else 0
@@ -381,7 +382,7 @@ def split_multi_medias(_inputs):
381382
split_multi_medias(negative)
382383
negative_encoded = self._encode_truncated(negative)
383384
for key in negative_encoded:
384-
_encoded[f'negative{i}_{key}'] = negative_encoded[key]
385+
_encoded[f'negative_{key}'].append(negative_encoded[key])
385386
labels.append(0.0)
386387

387388
_encoded['labels'] = labels
@@ -1314,10 +1315,18 @@ def _embedding_data_collator(self,
13141315
new_batch = []
13151316
for b in batch:
13161317
keys = [key for key in b.keys() if 'negative' in key]
1317-
max_neg = max([int(re.findall(r'negative(-?\d+)', key)[0]) for key in keys]) if keys else None
1318+
max_neg = None
1319+
for key in keys:
1320+
value_list = b[key]
1321+
suffix = key[len('negative_'):]
1322+
max_neg = len(value_list)
1323+
for i, value in enumerate(value_list):
1324+
b[f'negative{i}_{suffix}'] = value
1325+
b.pop(key)
1326+
13181327
indexes = ['anchor_', 'positive_']
13191328
if max_neg is not None:
1320-
for i in range(0, max_neg + 1):
1329+
for i in range(0, max_neg):
13211330
indexes.append(f'negative{i}_')
13221331
for prefix in indexes:
13231332
new_batch += self._fetch_inputs_startswith([b], prefix)

swift/trainers/sequence_parallel/ulysses.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,27 @@ def prepare_trainer(self, trainer):
851851
trainer._get_per_token_logps = MethodType(_get_per_token_logps, trainer)
852852
trainer.split_by_mini_batches = MethodType(split_by_mini_batches, trainer)
853853

854+
class DataloaderWrap:
855+
856+
def __init__(self, dataloader):
857+
self.dataloader = dataloader
858+
859+
def __getattr__(self, item):
860+
return getattr(self.dataloader, item)
861+
862+
def __len__(wrapped):
863+
return len(wrapped.dataloader) * self.sp_world_size
864+
865+
def __iter__(self):
866+
yield from self.dataloader
867+
868+
def get_train_dataloader(trainer):
869+
dataloader = trainer.get_origin_train_dataloader()
870+
return DataloaderWrap(dataloader)
871+
872+
trainer.get_origin_train_dataloader = trainer.get_train_dataloader
873+
trainer.get_train_dataloader = MethodType(get_train_dataloader, trainer)
874+
854875
from swift.plugin import metric
855876
from swift.trainers import mixin
856877
compute_acc_origin = metric.compute_acc

0 commit comments

Comments
 (0)