Commit 0247075

Temporarily revert code, waiting for refactor (#816)
Parent: fbf37a4

5 files changed: 4 additions, 207 deletions

swift/llm/sft.py (0 additions, 32 deletions)

@@ -6,7 +6,6 @@
 import json
 import numpy as np
 import torch
-import torch.distributed as dist
 from modelscope import BitsAndBytesConfig, GenerationConfig
 from transformers import IntervalStrategy
 from transformers.integrations import is_deepspeed_zero3_enabled
@@ -27,17 +26,6 @@
                     print_example, set_generation_config, sort_by_max_length,
                     stat_dataset)

-SUPPORT_XTUNER = False
-
-try:
-    from xtuner.parallel.sequence import *
-    # datasets is required in Xtuner
-    from datasets import Dataset
-    from xtuner.dataset.huggingface import pack_dataset
-    SUPPORT_XTUNER = True
-except ImportError:
-    pass
-
 logger = get_logger()


@@ -208,25 +196,6 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
         dataset_info['train_dataset'] = stat_dataset(train_dataset)
         if val_dataset is not None:
             dataset_info['val_dataset'] = stat_dataset(val_dataset)
-        if args.pack_to_max_length:
-            assert SUPPORT_XTUNER, \
-                ('Please install XTuner first to pack dataset to `max_length`.'
-                 '`pip install -U \'xtuner[deepspeed]\'`')
-            if dist.get_rank() == 0:
-                ds = [i[0] for i in train_dataset.data]
-                train_dataset = Dataset.from_list(ds)
-                train_dataset = pack_dataset(
-                    train_dataset,
-                    max_length=args.max_length,
-                    use_varlen_attn=False,
-                    shuffle_before_pack=True,
-                    map_num_proc=16)
-                objects = [train_dataset]
-                train_dataset.save_to_disk('alpaca_pack')
-            else:
-                objects = [None]
-            dist.broadcast_object_list(objects, src=0)
-            train_dataset = objects[0]
     else:
         dataset_info = None
     td0, tkwargs0 = template.encode(train_dataset[0])
@@ -267,7 +236,6 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
         trainer_kwargs['check_model'] = False

     trainer = Seq2SeqTrainer(
-        sequence_parallel_size=args.sequence_parallel_size,
         model=model,
         args=training_args,
         data_collator=data_collator,
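Note: the largest hunk above reverts XTuner-based dataset packing. Its core distributed pattern, running the expensive pack once on rank 0 and then sharing the result with every other rank, may be worth keeping in mind for the refactor. A minimal, XTuner-free sketch of that pattern, where `build_fn` is a hypothetical stand-in for the `pack_dataset(...)` call:

```python
import torch.distributed as dist

def build_on_rank0_and_broadcast(build_fn):
    # Run an expensive transform once on rank 0 (e.g. dataset packing),
    # then share the picklable result with every rank in the default group.
    if dist.get_rank() == 0:
        objects = [build_fn()]
    else:
        objects = [None]
    # broadcast_object_list pickles on src and unpickles on the receivers
    dist.broadcast_object_list(objects, src=0)
    return objects[0]
```

As in the reverted hunk, every rank must reach this call collectively, otherwise the broadcast deadlocks.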

swift/llm/tuner.py (0 additions, 12 deletions)

@@ -19,15 +19,6 @@
 from .utils import (SftArguments, find_all_linears, find_embedding, find_ln,
                     is_adapter)

-SUPPORT_XTUNER = False
-
-try:
-    from xtuner.model.modules.dispatch import dispatch_modules
-    from xtuner.parallel.sequence import *
-    SUPPORT_XTUNER = True
-except ImportError:
-    pass
-
 logger = get_logger()


@@ -208,9 +199,6 @@ def prepare_model(model, args: SftArguments):
         model.load_state_dict(state_dict, False)
         # release memory
         del state_dict
-        if SUPPORT_XTUNER:
-            dispatch_modules(model)
-            logger.info('Dispatch modules for sequence parallel.')
     else:
         raise ValueError(f'args.sft_type: {args.sft_type}')
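The deleted lines follow the same optional-dependency pattern used throughout this commit: attempt the import, record availability in a module-level flag, and gate the feature on that flag so environments without xtuner are unaffected. A condensed sketch of the pattern, with `maybe_dispatch` as a hypothetical wrapper around the reverted call site:

```python
# Guarded optional import: sequence-parallel module dispatch is enabled
# only when xtuner is installed; otherwise the call site is a no-op.
SUPPORT_XTUNER = False
try:
    from xtuner.model.modules.dispatch import dispatch_modules
    SUPPORT_XTUNER = True
except ImportError:
    pass

def maybe_dispatch(model):
    if SUPPORT_XTUNER:
        # per the deleted log line, prepares modules for sequence parallel
        dispatch_modules(model)
    return model
```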

swift/llm/utils/argument.py (0 additions, 4 deletions)

@@ -491,10 +491,6 @@ class SftArguments(ArgumentsBase):
     # fsdp config file
     fsdp_config: Optional[str] = None

-    # xtuner config
-    sequence_parallel_size: int = 1
-    pack_to_max_length: bool = False
-
     def handle_dataset_mixture(self, train_dataset: HfDataset) -> None:
         if train_dataset is None:
             return train_dataset
return train_dataset

swift/llm/utils/template.py (0 additions, 45 deletions)

@@ -16,17 +16,6 @@
 from swift.torchacc_utils import pad_and_split_batch
 from swift.utils import get_dist_setting, use_torchacc

-SUPPORT_XTUNER = False
-
-try:
-    from xtuner.parallel.sequence import (pad_for_sequence_parallel,
-                                          split_for_sequence_parallel,
-                                          get_sequence_parallel_group,
-                                          get_sequence_parallel_world_size)
-    SUPPORT_XTUNER = True
-except ImportError:
-    pass
-
 DEFAULT_SYSTEM = 'You are a helpful assistant.'
 History = List[Union[Tuple[str, str], List[str]]]

@@ -432,31 +421,6 @@ def _concat_tokenizer_kwargs(
         assert len(old_tokenizer_kwargs) == 0
         return curr_tokenizer_kwargs

-    def _pad_and_split_for_sequence_parallel(self, tokenizer, input_ids,
-                                             labels, position_ids,
-                                             attention_mask, loss_scale):
-        input_ids = pad_for_sequence_parallel(
-            input_ids, padding_value=tokenizer.pad_token_id, dim=-1)
-        labels = pad_for_sequence_parallel(labels, padding_value=-100, dim=-1)
-        position_ids = pad_for_sequence_parallel(
-            position_ids, padding_value=0, dim=-1)
-        attention_mask = pad_for_sequence_parallel(
-            attention_mask, padding_value=0, dim=-1)
-
-        sp_group = get_sequence_parallel_group()
-        input_ids = split_for_sequence_parallel(
-            input_ids, dim=1, sp_group=sp_group)
-        labels = split_for_sequence_parallel(labels, dim=1, sp_group=sp_group)
-        position_ids = split_for_sequence_parallel(
-            position_ids, dim=1, sp_group=sp_group)
-        if loss_scale is not None:
-            loss_scale = pad_for_sequence_parallel(
-                loss_scale, padding_value=0., dim=-1)
-            loss_scale = split_for_sequence_parallel(
-                loss_scale, dim=1, sp_group=sp_group)
-
-        return input_ids, labels, position_ids, attention_mask, loss_scale
-
     def data_collator(self,
                       batch: List[Dict[str, Any]],
                       padding_to: Optional[int] = None) -> Dict[str, Any]:
@@ -506,19 +470,10 @@ def data_collator(self,
                 padding_to, input_ids, attention_mask, labels, loss_scale,
                 self.max_length, self.tokenizer, rank, world_size)

-        bs, seq_len = input_ids.shape
-        position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)
-
-        if get_sequence_parallel_world_size() > 1:
-            input_ids, labels, position_ids, attention_mask, loss_scale = \
-                self._pad_and_split_for_sequence_parallel(
-                    tokenizer, input_ids, labels, position_ids, attention_mask, loss_scale)
-
         res = {
             'input_ids': input_ids,
             'attention_mask': attention_mask,
             'labels': labels,
-            'position_ids': position_ids,
         }
         if loss_scale is not None:
             res['loss_scale'] = loss_scale
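The deleted `_pad_and_split_for_sequence_parallel` pads each per-batch tensor on the sequence dimension until its length divides evenly by the sequence-parallel world size, then hands each rank one contiguous slice. An xtuner-free sketch of that core operation for a single `(batch, seq_len)` tensor; the function name and signature are illustrative:

```python
import torch
import torch.nn.functional as F

def pad_and_split(x: torch.Tensor, sp_rank: int, sp_world_size: int,
                  padding_value: int = 0) -> torch.Tensor:
    # Right-pad the sequence dim to a multiple of sp_world_size, then
    # keep only this rank's contiguous chunk.
    pad = (-x.shape[-1]) % sp_world_size
    if pad:
        x = F.pad(x, (0, pad), value=padding_value)
    chunk = x.shape[-1] // sp_world_size
    return x[..., sp_rank * chunk:(sp_rank + 1) * chunk]
```

The padding values in the reverted code are the important detail: labels pad with -100 so padded positions drop out of the loss, and attention_mask pads with 0 so they are never attended to.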

swift/trainers/trainers.py (4 additions, 114 deletions)

@@ -4,7 +4,6 @@
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
-import torch.distributed as dist
 from peft import PeftModel
 from torch import Tensor, nn
 from torch.nn import CrossEntropyLoss
@@ -15,8 +14,7 @@
 from transformers.modeling_utils import unwrap_model
 from transformers.models.auto.modeling_auto import \
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
-from transformers.trainer_utils import seed_worker
-from transformers.utils import is_peft_available, is_torch_xla_available
+from transformers.utils import is_peft_available

 from swift.torchacc_utils import (ta_eval_dataloader, ta_test_dataloader,
                                   ta_train_dataloader)
@@ -30,30 +28,14 @@
 except ImportError:
     from transformers.deepspeed import is_deepspeed_zero3_enabled

-if is_torch_xla_available():
-    import torch_xla.core.xla_model as xm
-
-SUPPORT_XTUNER = False
-
-try:
-    from xtuner.parallel.sequence import (init_sequence_parallel,
-                                          SequenceParallelSampler,
-                                          reduce_sequence_parallel_loss,
-                                          get_sequence_parallel_world_size,
-                                          get_sequence_parallel_group)
-    from mmengine.device.utils import get_max_cuda_memory
-    SUPPORT_XTUNER = True
-except ImportError:
-    pass
-

 class Trainer(PushToMsHubMixin, SwiftMixin, HfTrainer):
     pass


 class Seq2SeqTrainer(PushToMsHubMixin, SwiftMixin, HfSeq2SeqTrainer):

-    def __init__(self, sequence_parallel_size=1, *args, **kwargs):
+    def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         # performance
         self.perf: Dict[str, Any] = {
@@ -67,9 +49,6 @@ def __init__(self, sequence_parallel_size=1, *args, **kwargs):
             self.model, 'get_trainable_parameters') else None,
         }
         self._acc = torch.tensor(0.).to(self.args.device)
-        if SUPPORT_XTUNER:
-            self.sequence_parallel_size = sequence_parallel_size
-            init_sequence_parallel(sequence_parallel_size)

     def train(self, *args, **kwargs) -> torch.Tensor:
         res = super().train(*args, **kwargs)
@@ -226,7 +205,6 @@ def compute_scaled_loss(self, labels: torch.Tensor,
         return loss.mean()

     def compute_loss(self, model, inputs, return_outputs=None):
-        assert 'labels' in inputs
         if not hasattr(self, '_custom_metrics'):
             self._custom_metrics = {}

@@ -262,17 +240,9 @@ def compute_loss(self, model, inputs, return_outputs=None):
         else:
             loss = outputs['loss'] if isinstance(outputs, dict) else outputs[0]

+        preds = outputs.logits.argmax(dim=2)[..., :-1]
         if labels is None:
             labels = inputs['labels']
-
-        if SUPPORT_XTUNER:
-            # reduce loss for logging correctly
-            num_tokens = (labels != -100).sum()
-            loss = reduce_sequence_parallel_loss(loss, num_tokens,
-                                                 get_sequence_parallel_group())
-
-        preds = outputs.logits.argmax(dim=2)[..., :-1]
-
         labels = labels[..., 1:]
         masks = labels != -100
         acc_strategy = getattr(self.args, 'acc_strategy', 'token')
@@ -296,90 +266,10 @@
                 'acc'] + acc / self.args.gradient_accumulation_steps
         return (loss, outputs) if return_outputs else loss

-    # Support logging cuda memory usage
-    # hacky: Override Trainer's private method
-    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch,
-                                 ignore_keys_for_eval):
-        if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
-            if is_torch_xla_available():
-                xm.mark_step()
-
-            logs: Dict[str, float] = {}
-
-            # all_gather + mean() to get average loss over all processes
-            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
-
-            # reset tr_loss to zero
-            tr_loss -= tr_loss
-
-            logs['loss'] = round(
-                tr_loss_scalar /
-                (self.state.global_step - self._globalstep_last_logged), 4)
-            if grad_norm is not None:
-                logs['grad_norm'] = grad_norm.detach().item() if isinstance(
-                    grad_norm, torch.Tensor) else grad_norm
-            logs['learning_rate'] = self._get_learning_rate()
-            logs['memory'] = get_max_cuda_memory()
-
-            self._total_loss_scalar += tr_loss_scalar
-            self._globalstep_last_logged = self.state.global_step
-            self.store_flos()
-
-            self.log(logs)
-
-        metrics = None
-        if self.control.should_evaluate:
-            metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
-            self._report_to_hp_search(trial, self.state.global_step, metrics)
-
-            # Run delayed LR scheduler now that metrics are populated
-            if isinstance(self.lr_scheduler,
-                          torch.optim.lr_scheduler.ReduceLROnPlateau):
-                metric_to_check = self.args.metric_for_best_model
-                if not metric_to_check.startswith('eval_'):
-                    metric_to_check = f'eval_{metric_to_check}'
-                self.lr_scheduler.step(metrics[metric_to_check])
-
-        if self.control.should_save:
-            self._save_checkpoint(model, trial, metrics=metrics)
-            self.control = self.callback_handler.on_save(
-                self.args, self.state, self.control)
-
     def get_train_dataloader(self):

         if not use_torchacc():
-            # modified from HFTrainer.get_train_dataloader
-            # RandomSampler -> SequenceParallelSampler
-            if trainer.is_datasets_available():
-                import datasets
-            if self.train_dataset is None:
-                raise ValueError('Trainer: training requires a train_dataset.')
-
-            train_dataset = self.train_dataset
-            data_collator = self.data_collator
-            if trainer.is_datasets_available() and isinstance(
-                    train_dataset, datasets.Dataset):
-                train_dataset = self._remove_unused_columns(
-                    train_dataset, description='training')
-            else:
-                data_collator = self._get_collator_with_removed_columns(
-                    data_collator, description='training')
-
-            dataloader_params = {
-                'batch_size': self._train_batch_size,
-                'collate_fn': data_collator,
-                'num_workers': self.args.dataloader_num_workers,
-                'pin_memory': self.args.dataloader_pin_memory,
-                'persistent_workers': self.args.dataloader_persistent_workers,
-            }
-
-            if not isinstance(train_dataset, torch.utils.data.IterableDataset):
-                dataloader_params['sampler'] = SequenceParallelSampler(
-                    train_dataset, seed=1024)
-                dataloader_params['drop_last'] = self.args.dataloader_drop_last
-                dataloader_params['worker_init_fn'] = seed_worker
-
-            return DataLoader(train_dataset, **dataloader_params)
+            return super().get_train_dataloader()

         else:
             if trainer.is_datasets_available():
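Under sequence parallelism, each rank's `loss` is a mean over only its shard of the sequence, so per-rank means must be re-weighted by supervised-token counts before logging; that is what the removed `reduce_sequence_parallel_loss` call provided. A sketch of the token-weighted reduction it plausibly performs (this is an assumption about its behavior, not xtuner's actual implementation):

```python
import torch
import torch.distributed as dist

def token_weighted_loss(loss: torch.Tensor, num_tokens: torch.Tensor,
                        group=None) -> torch.Tensor:
    # loss: this rank's mean loss over its sequence shard (0-dim tensor)
    # num_tokens: supervised tokens on this shard, e.g. (labels != -100).sum()
    totals = torch.stack([loss.float() * num_tokens, num_tokens.float()])
    # sum loss mass and token counts across the sequence-parallel group
    dist.all_reduce(totals, op=dist.ReduceOp.SUM, group=group)
    return totals[0] / totals[1].clamp(min=1)
```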
