
Commit ff3e583

Merge branch 'main' into release/3.9

2 parents: 10bbebe + 23cb839

4 files changed: +11 additions, −14 deletions

examples/train/embedding/train_emb.sh

Lines changed: 0 additions & 1 deletion

```diff
@@ -28,6 +28,5 @@ swift sft \
     --gradient_accumulation_steps 4 \
     --learning_rate 6e-6 \
     --loss_type infonce \
-    --label_names labels \
     --dataloader_drop_last true \
     --deepspeed zero2
```
swift/llm/dataset/preprocessor/core.py

Lines changed: 8 additions & 10 deletions

```diff
@@ -314,7 +314,6 @@ def __call__(
         dataset = sample_dataset(dataset, self.dataset_sample, True, self.random_state)

         map_kwargs = {'batched': True, 'batch_size': batch_size}
-        cache_file_name = None
         if isinstance(dataset, HfDataset):
             if not load_from_cache_file and is_dist() and not is_master():
                 load_from_cache_file = True
@@ -326,29 +325,28 @@ def __call__(
         dataset = RowPreprocessor.get_features_dataset(dataset)
         if 'solution' in dataset.features:
             with safe_ddp_context(None, True):
-                if not dataset.cache_files:
-                    cache_file_name = os.path.join(get_cache_dir(), 'datasets', 'map_cache',
-                                                   f'{dataset._fingerprint}.arrow')
-                dataset = dataset.map(
-                    lambda x: {'__#solution': x['solution']}, **map_kwargs, cache_file_name=cache_file_name)
+                if isinstance(dataset, HfDataset) and not dataset.cache_files:
+                    map_kwargs['cache_file_name'] = os.path.join(get_cache_dir(), 'datasets', 'map_cache',
+                                                                 f'{dataset._fingerprint}.arrow')
+                dataset = dataset.map(lambda x: {'__#solution': x['solution']}, **map_kwargs)
+                map_kwargs.pop('cache_file_name', None)
         dataset = self._rename_columns(dataset)
         dataset = self.prepare_dataset(dataset)
         dataset = self._cast_pil_image(dataset)

         ignore_max_length_error = True if isinstance(dataset, HfDataset) and num_proc > 1 else False
         with self._patch_arrow_writer(), safe_ddp_context(None, True):
             try:
-                if not dataset.cache_files:
-                    cache_file_name = os.path.join(get_cache_dir(), 'datasets', 'map_cache',
-                                                   f'{dataset._fingerprint}.arrow')
+                if isinstance(dataset, HfDataset) and not dataset.cache_files:
+                    map_kwargs['cache_file_name'] = os.path.join(get_cache_dir(), 'datasets', 'map_cache',
+                                                                 f'{dataset._fingerprint}.arrow')
                 dataset_mapped = dataset.map(
                     self.batched_preprocess,
                     fn_kwargs={
                         'strict': strict,
                         'ignore_max_length_error': ignore_max_length_error
                     },
                     remove_columns=list(dataset.features.keys()),
-                    cache_file_name=cache_file_name,
                     **map_kwargs)
             except NotImplementedError:
                 pass
```
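The rewrite threads the explicit cache path through `map_kwargs` instead of a standalone `cache_file_name` variable, pops it after the first `.map()` so the second `.map()` can compute its own fingerprint-based path, and adds an `isinstance(dataset, HfDataset)` guard so non-arrow datasets (which have no `cache_files`) skip the cache logic. A minimal self-contained sketch of the same `datasets` API usage; the `/tmp` cache root and the toy mapping function are illustrative stand-ins, not the repo's code:

```python
import os
from datasets import Dataset

dataset = Dataset.from_dict({'solution': ['42', '7']})  # in-memory, so cache_files is empty

map_kwargs = {'batched': True, 'batch_size': 2}
if isinstance(dataset, Dataset) and not dataset.cache_files:
    # Write the mapped result to an explicit, fingerprint-derived arrow file.
    # (The repo derives the root from get_cache_dir(); /tmp stands in here.)
    cache_dir = os.path.join('/tmp', 'datasets', 'map_cache')
    os.makedirs(cache_dir, exist_ok=True)
    map_kwargs['cache_file_name'] = os.path.join(cache_dir, f'{dataset._fingerprint}.arrow')

dataset = dataset.map(lambda x: {'__#solution': x['solution']}, **map_kwargs)
# Drop the path so a later .map() reusing map_kwargs gets a fresh cache file.
map_kwargs.pop('cache_file_name', None)
print(dataset.cache_files)  # now backed by the explicit arrow file
```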

swift/llm/template/base.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -553,7 +553,7 @@ def packing_row(self, row: List[Dict[str, Any]]) -> Dict[str, Any]:
             length.append(r['length'])
         for key in keys:
             if key in {'input_ids', 'labels', 'loss_scale'}:
-                packed[key] = sum((x[key] for x in row), start=[])
+                packed[key] = sum((x.get(key) or [] for x in row), start=[])
             elif key == 'length':
                 packed[key] = sum((x[key] for x in row))
             elif key == 'channel':
```
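This one-line change makes packing tolerant of rows where `labels` or `loss_scale` is missing or `None`: `x[key]` raises `KeyError` on an absent key (and `sum()` fails on a `None` value), while `x.get(key) or []` contributes an empty list in both cases. A standalone sketch of the difference, with made-up rows:

```python
rows = [
    {'input_ids': [1, 2], 'labels': [1, 2]},
    {'input_ids': [3, 4], 'labels': None},  # key present but None
    {'input_ids': [5]},                     # key missing entirely
]

# Old behavior: KeyError on the third row (or TypeError concatenating None).
# packed = sum((x['labels'] for x in rows), start=[])

# New behavior: missing/None values simply contribute nothing.
packed = sum((x.get('labels') or [] for x in rows), start=[])
print(packed)  # [1, 2]
```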

swift/megatron/model/gpt_model.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -250,9 +250,9 @@ def forward(
             logits, _ = self.output_layer(
                 hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output)
         else:
-            logits = self.output_layer(hidden_states)[0]
             if args.sequence_parallel and args.tensor_model_parallel_size > 1:
-                logits = gather_from_sequence_parallel_region(logits)
+                hidden_states = gather_from_sequence_parallel_region(hidden_states)
+            logits = self.output_layer(hidden_states)[0]
         if has_config_logger_enabled(self.config):
             payload = OrderedDict({
                 'input_ids': input_ids,
```
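In this branch the sequence-parallel all-gather now runs on `hidden_states` before the output projection rather than on `logits` after it, so the projection sees the full sequence; the reordering itself is what the diff shows, and the commit does not spell out the motivation further. A toy NumPy illustration of that data flow, with stand-ins for the Megatron pieces:

```python
import numpy as np

def gather_from_sequence_parallel_region(shards):
    # Stand-in for Megatron's all-gather across tensor-parallel ranks:
    # here we simply concatenate the per-rank sequence shards.
    return np.concatenate(shards, axis=0)

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 10))                          # output-layer weight: hidden -> vocab
shards = [rng.normal(size=(2, 4)) for _ in range(2)]  # 2 ranks, 2 tokens each

# Fixed ordering: gather the full sequence first, then project to logits.
hidden_states = gather_from_sequence_parallel_region(shards)
logits = hidden_states @ W
print(logits.shape)  # (4, 10): logits for all tokens of the sequence
```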
