@@ -34,7 +34,7 @@
 from swift.llm import BatchSamplerShard, DataLoaderDispatcher, DataLoaderShard, Template
 from swift.plugin import MeanMetric, compute_acc, extra_tuners
 from swift.tuners import SwiftModel
-from swift.utils import get_logger, is_mp_ddp, ms_logger_context, seed_worker, use_torchacc
+from swift.utils import get_logger, is_mp, is_mp_ddp, ms_logger_context, seed_worker, use_torchacc
 from swift.utils.torchacc_utils import ta_trim_graph
 from ..utils.torch_utils import get_device_count
 from .arguments import TrainingArguments
@@ -484,6 +484,36 @@ def _evalscope_eval(self):
         self.model.train()
         return eval_dict
 
+    def get_logits_to_keep(self, labels):
+        if labels.shape[0] == 1 and not is_mp():
+            # device_map may encounter device mismatch issues.
+            loss_mask = (labels != -100)[0]
+            labels = labels[:, loss_mask]
+            labels = nn.functional.pad(labels, (1, 0), value=-100)
+            logits_to_keep = nn.functional.pad(loss_mask[1:], (0, 1), value=True)
+        else:
+            logits_to_keep = labels.shape[-1] - ((labels != -100).int().argmax(-1).min().item()) + 1
+            assert logits_to_keep > 0
+            labels = labels[:, -logits_to_keep:]
+        return labels, logits_to_keep
+
+    def get_cu_seqlens(self, position_ids, logits_to_keep) -> torch.Tensor:
+        assert position_ids.shape[0] == 1
+        position_ids = position_ids[0]
+        indices = torch.arange(position_ids.shape[0], device=position_ids.device)
+        cu_seqlens = torch.concat([
+            indices[position_ids == 0],
+            torch.tensor(position_ids.shape, device=position_ids.device),
+        ])
+        res_cu_seqlens = cu_seqlens.clone()
+        if isinstance(logits_to_keep, torch.Tensor):
+            for i in range(cu_seqlens.shape[0] - 1):
+                start, end = cu_seqlens[i], cu_seqlens[i + 1]
+                res_cu_seqlens[i + 1:] -= (~logits_to_keep[start:end]).sum()
+        elif isinstance(logits_to_keep, int):
+            res_cu_seqlens[1:] -= position_ids.shape[0] + 1 - logits_to_keep
+        return res_cu_seqlens
+
     def get_batch_samples(self, *args, **kwargs):
         res = super().get_batch_samples(*args, **kwargs)
         from swift.trainers.sequence_parallel import sequence_parallel
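The two new helpers are easier to read with concrete inputs. `get_logits_to_keep` trims `labels` to the suffix whose logits are actually needed for the loss: with batch size 1 and no model parallelism it returns a per-position boolean mask (avoiding integer indexing, which can hit device-mismatch issues under `device_map`); otherwise it returns a single integer counting the trailing positions to keep. A minimal standalone sketch of the integer branch, with made-up label values (not the trainer code itself):

import torch

# Minimal sketch of the integer branch (illustrative labels, not the trainer).
# -100 is the usual Hugging Face ignore index for unsupervised positions.
labels = torch.tensor([
    [-100, -100, -100, 5, 6, 7],
    [-100, -100, 8, 9, -100, -100],
])
# argmax over the bool mask finds the first supervised position per row;
# min takes the earliest one across the batch.
first = (labels != -100).int().argmax(-1).min().item()  # -> 2
# Keep one extra position: the logit at t - 1 predicts the token at t.
logits_to_keep = labels.shape[-1] - first + 1  # -> 5
labels = labels[:, -logits_to_keep:]  # only the trailing 5 positions remain

The `+ 1` keeps one position before the earliest supervised token, since in a causal LM the logit at position t - 1 is what predicts the label at position t.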
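`get_cu_seqlens` recovers flash-attention-style cumulative sequence boundaries from packed `position_ids`, where each packed sequence restarts its positions at 0, and then shifts every boundary by the number of positions that `logits_to_keep` drops before it. An illustrative sketch of the boundary computation, with assumed example inputs (again not the trainer code):

import torch

# Derive cu_seqlens from packed position_ids: each packed sequence restarts
# its positions at 0, and the total length closes the final sequence.
position_ids = torch.tensor([0, 1, 2, 3, 0, 1, 0, 1, 2])  # three packed sequences
indices = torch.arange(position_ids.shape[0])
cu_seqlens = torch.concat([
    indices[position_ids == 0],        # start offset of each packed sequence
    torch.tensor(position_ids.shape),  # total length closes the last sequence
])
print(cu_seqlens)  # tensor([0, 4, 6, 9]) -> sequence lengths 4, 2, 3

When `logits_to_keep` is a boolean mask, the method subtracts each sequence's count of dropped positions from all later boundaries; when it is an int, all boundaries after the first shrink by a single global offset derived from the dropped prefix.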