44from typing import Any , Dict , List , Optional , Tuple , Union
55
66import torch
7- import torch .distributed as dist
87from peft import PeftModel
98from torch import Tensor , nn
109from torch .nn import CrossEntropyLoss
1514from transformers .modeling_utils import unwrap_model
1615from transformers .models .auto .modeling_auto import \
1716 MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
18- from transformers .trainer_utils import seed_worker
19- from transformers .utils import is_peft_available , is_torch_xla_available
17+ from transformers .utils import is_peft_available
2018
2119from swift .torchacc_utils import (ta_eval_dataloader , ta_test_dataloader ,
2220 ta_train_dataloader )
3028except ImportError :
3129 from transformers .deepspeed import is_deepspeed_zero3_enabled
3230
33- if is_torch_xla_available ():
34- import torch_xla .core .xla_model as xm
35-
36- SUPPORT_XTUNER = False
37-
38- try :
39- from xtuner .parallel .sequence import (init_sequence_parallel ,
40- SequenceParallelSampler ,
41- reduce_sequence_parallel_loss ,
42- get_sequence_parallel_world_size ,
43- get_sequence_parallel_group )
44- from mmengine .device .utils import get_max_cuda_memory
45- SUPPORT_XTUNER = True
46- except ImportError :
47- pass
48-
4931
5032class Trainer (PushToMsHubMixin , SwiftMixin , HfTrainer ):
5133 pass
5234
5335
5436class Seq2SeqTrainer (PushToMsHubMixin , SwiftMixin , HfSeq2SeqTrainer ):
5537
56- def __init__ (self , sequence_parallel_size = 1 , * args , ** kwargs ):
38+ def __init__ (self , * args , ** kwargs ):
5739 super ().__init__ (* args , ** kwargs )
5840 # performance
5941 self .perf : Dict [str , Any ] = {
@@ -67,9 +49,6 @@ def __init__(self, sequence_parallel_size=1, *args, **kwargs):
6749 self .model , 'get_trainable_parameters' ) else None ,
6850 }
6951 self ._acc = torch .tensor (0. ).to (self .args .device )
70- if SUPPORT_XTUNER :
71- self .sequence_parallel_size = sequence_parallel_size
72- init_sequence_parallel (sequence_parallel_size )
7352
7453 def train (self , * args , ** kwargs ) -> torch .Tensor :
7554 res = super ().train (* args , ** kwargs )
@@ -226,7 +205,6 @@ def compute_scaled_loss(self, labels: torch.Tensor,
226205 return loss .mean ()
227206
228207 def compute_loss (self , model , inputs , return_outputs = None ):
229- assert 'labels' in inputs
230208 if not hasattr (self , '_custom_metrics' ):
231209 self ._custom_metrics = {}
232210
@@ -262,17 +240,9 @@ def compute_loss(self, model, inputs, return_outputs=None):
262240 else :
263241 loss = outputs ['loss' ] if isinstance (outputs , dict ) else outputs [0 ]
264242
243+ preds = outputs .logits .argmax (dim = 2 )[..., :- 1 ]
265244 if labels is None :
266245 labels = inputs ['labels' ]
267-
268- if SUPPORT_XTUNER :
269- # reduce loss for logging correctly
270- num_tokens = (labels != - 100 ).sum ()
271- loss = reduce_sequence_parallel_loss (loss , num_tokens ,
272- get_sequence_parallel_group ())
273-
274- preds = outputs .logits .argmax (dim = 2 )[..., :- 1 ]
275-
276246 labels = labels [..., 1 :]
277247 masks = labels != - 100
278248 acc_strategy = getattr (self .args , 'acc_strategy' , 'token' )
@@ -296,90 +266,10 @@ def compute_loss(self, model, inputs, return_outputs=None):
296266 'acc' ] + acc / self .args .gradient_accumulation_steps
297267 return (loss , outputs ) if return_outputs else loss
298268
299- # Support logging cuda memory usage
300- # hacky: Override Trainer's private method
301- def _maybe_log_save_evaluate (self , tr_loss , grad_norm , model , trial , epoch ,
302- ignore_keys_for_eval ):
303- if self .control .should_log and self .state .global_step > self ._globalstep_last_logged :
304- if is_torch_xla_available ():
305- xm .mark_step ()
306-
307- logs : Dict [str , float ] = {}
308-
309- # all_gather + mean() to get average loss over all processes
310- tr_loss_scalar = self ._nested_gather (tr_loss ).mean ().item ()
311-
312- # reset tr_loss to zero
313- tr_loss -= tr_loss
314-
315- logs ['loss' ] = round (
316- tr_loss_scalar /
317- (self .state .global_step - self ._globalstep_last_logged ), 4 )
318- if grad_norm is not None :
319- logs ['grad_norm' ] = grad_norm .detach ().item () if isinstance (
320- grad_norm , torch .Tensor ) else grad_norm
321- logs ['learning_rate' ] = self ._get_learning_rate ()
322- logs ['memory' ] = get_max_cuda_memory ()
323-
324- self ._total_loss_scalar += tr_loss_scalar
325- self ._globalstep_last_logged = self .state .global_step
326- self .store_flos ()
327-
328- self .log (logs )
329-
330- metrics = None
331- if self .control .should_evaluate :
332- metrics = self .evaluate (ignore_keys = ignore_keys_for_eval )
333- self ._report_to_hp_search (trial , self .state .global_step , metrics )
334-
335- # Run delayed LR scheduler now that metrics are populated
336- if isinstance (self .lr_scheduler ,
337- torch .optim .lr_scheduler .ReduceLROnPlateau ):
338- metric_to_check = self .args .metric_for_best_model
339- if not metric_to_check .startswith ('eval_' ):
340- metric_to_check = f'eval_{ metric_to_check } '
341- self .lr_scheduler .step (metrics [metric_to_check ])
342-
343- if self .control .should_save :
344- self ._save_checkpoint (model , trial , metrics = metrics )
345- self .control = self .callback_handler .on_save (
346- self .args , self .state , self .control )
347-
348269 def get_train_dataloader (self ):
349270
350271 if not use_torchacc ():
351- # modified from HFTrainer.get_train_dataloader
352- # RandomSampler -> SequenceParallelSampler
353- if trainer .is_datasets_available ():
354- import datasets
355- if self .train_dataset is None :
356- raise ValueError ('Trainer: training requires a train_dataset.' )
357-
358- train_dataset = self .train_dataset
359- data_collator = self .data_collator
360- if trainer .is_datasets_available () and isinstance (
361- train_dataset , datasets .Dataset ):
362- train_dataset = self ._remove_unused_columns (
363- train_dataset , description = 'training' )
364- else :
365- data_collator = self ._get_collator_with_removed_columns (
366- data_collator , description = 'training' )
367-
368- dataloader_params = {
369- 'batch_size' : self ._train_batch_size ,
370- 'collate_fn' : data_collator ,
371- 'num_workers' : self .args .dataloader_num_workers ,
372- 'pin_memory' : self .args .dataloader_pin_memory ,
373- 'persistent_workers' : self .args .dataloader_persistent_workers ,
374- }
375-
376- if not isinstance (train_dataset , torch .utils .data .IterableDataset ):
377- dataloader_params ['sampler' ] = SequenceParallelSampler (
378- train_dataset , seed = 1024 )
379- dataloader_params ['drop_last' ] = self .args .dataloader_drop_last
380- dataloader_params ['worker_init_fn' ] = seed_worker
381-
382- return DataLoader (train_dataset , ** dataloader_params )
272+ return super ().get_train_dataloader ()
383273
384274 else :
385275 if trainer .is_datasets_available ():
0 commit comments