diff --git a/modelopt/torch/distill/distillation_model.py b/modelopt/torch/distill/distillation_model.py index 339ff2c3e..930b68560 100644 --- a/modelopt/torch/distill/distillation_model.py +++ b/modelopt/torch/distill/distillation_model.py @@ -239,7 +239,7 @@ def compute_kd_loss( student_loss: torch.Tensor | None = None, loss_reduction_fn: Callable | None = None, skip_balancer: bool = False, - labels: torch.Tensor | None = None, + **loss_fn_kwargs, ) -> torch.Tensor | dict[str, torch.Tensor]: """Compute total loss for distillation backpropagation. @@ -248,8 +248,8 @@ def compute_kd_loss( loss_reduction_fn: Callable to be called on each loss tensor prior to balancing. Useful for loss-masking situations where the callable changes arguments each iteration. skip_balancer: Whether or not to use loss balancer to reduce the loss dict into a scalar. - labels: Labels to be passed to the loss function, if needed. This is necessary for losses that - require labels, such as MFTLoss. + **loss_fn_kwargs: Additional keyword arguments to be passed to the loss function, if needed. + This facilitates losses that require extras, such as labels for ``mtd.MFTLoss``. Returns: If reduce is True, the scalar total loss weighted between ``student_loss`` and the distillation losses. @@ -268,8 +268,7 @@ def compute_kd_loss( student_layer._intermediate_output = None teacher_layer._intermediate_output = None - extra_kwargs = {"labels": labels} if labels is not None else {} - loss = loss_fn(out_s, out_t, **extra_kwargs) # Student is pred, Teacher is target + loss = loss_fn(out_s, out_t, **loss_fn_kwargs) # Student is pred, Teacher is target if loss_reduction_fn is not None: # Needed in cases where a loss mask is used on non-scalar loss-fn outputs, prior to # reducing to a scalar loss value. diff --git a/modelopt/torch/distill/plugins/megatron.py b/modelopt/torch/distill/plugins/megatron.py index bde873004..6e712fcbb 100644 --- a/modelopt/torch/distill/plugins/megatron.py +++ b/modelopt/torch/distill/plugins/megatron.py @@ -18,29 +18,67 @@ """Distillation loss function(s).""" import logging -import types +import re from abc import ABCMeta -from typing import Any +from collections.abc import Callable +from dataclasses import dataclass, field +from types import MethodType +from typing import TYPE_CHECKING import torch import torch.nn as nn import torch.nn.functional as F import yaml -from megatron.core.dist_checkpointing.mapping import ShardedStateDict -from megatron.core.parallel_state import get_tensor_model_parallel_group -from megatron.core.tensor_parallel import gather_from_sequence_parallel_region -from megatron.core.transformer import MegatronModule, TransformerConfig +from megatron.core import parallel_state +from megatron.core.pipeline_parallel.schedules import get_tensor_shapes +from megatron.core.transformer import MegatronModule, TransformerLayer +from megatron.core.utils import get_model_config from torch import Tensor from torch.nn.modules.loss import _Loss import modelopt.torch.distill as mtd +from modelopt.torch.distill.config import Criterion + +if TYPE_CHECKING: + from megatron.core.dist_checkpointing.mapping import ShardedStateDict + from megatron.core.transformer import TransformerConfig + logger = logging.getLogger(__name__) +@dataclass +class DistillationConfig: + """Knowledge-Distillation config. + + Args: + intermediate_layer_pairs: List of tuples of intermediate layer names. + logit_layers: Tuple of logit layer names. 
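+        criterion: Mapping from (student layer, teacher layer) name pairs to loss
+            modules. Normally left ``None`` and populated by ``load_distillation_config``.
+        loss_balancer: Balancer which reduces the loss dict to a single scalar. Normally
+            left ``None`` and populated by ``load_distillation_config``.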
+ skip_lm_loss: Whether to skip computing the standard language model loss (default: ``True``). + kd_loss_scale: Relative scaling factor for the distillation loss if ``skip_lm_loss`` is ``False``. + logit_kl_temperature: Temperature for the logit KL-divergence loss. + """ + + intermediate_layer_pairs: list[tuple[str, str]] = field(default_factory=list) + logit_layers: tuple[str, str] = ("output_layer", "output_layer") + skip_lm_loss: bool = True + kd_loss_scale: float = 1.0 + logit_kl_temperature: float = 1.0 + criterion: Criterion | None = None + loss_balancer: mtd.DistillationLossBalancer | None = None + + def __post_init__(self): + assert len(self.logit_layers) == 2, f"{self.logit_layers=}" + assert all(len(pair) == 2 for pair in self.intermediate_layer_pairs), ( + f"{self.intermediate_layer_pairs=}" + ) + assert self.kd_loss_scale > 0, f"{self.kd_loss_scale=}" + assert self.logit_kl_temperature > 0, f"{self.logit_kl_temperature=}" + + def load_distillation_config( - config_path: str | None, student_cfg: TransformerConfig, teacher_cfg: TransformerConfig -) -> dict[str, Any]: + config_path: str | None, student_cfg: "TransformerConfig", teacher_cfg: "TransformerConfig" +) -> DistillationConfig: """Read the distillation yaml config file specified by ``args.export_kd_cfg``. Args: @@ -51,43 +89,64 @@ def load_distillation_config( WARNING: Assumes intermediate hidden sizes are always that found in the model config's ``hidden_size`` attribute. """ - if not config_path: - logger.warning("Distillation config not provided. Using default.") - cfg = { - "logit_layers": ["output_layer", "output_layer"], - "intermediate_layer_pairs": [], - "skip_lm_loss": True, - "kd_loss_scale": 1.0, - } - else: + if config_path: with open(config_path) as f: cfg = yaml.safe_load(f) + cfg = DistillationConfig(**cfg) + else: + logger.warning("Distillation config not provided. Using default.") + cfg = DistillationConfig() - intermediate_pairs: list[str] = cfg["intermediate_layer_pairs"] - logit_pair: list[str] = cfg["logit_layers"] - skip_lm_loss: bool = cfg["skip_lm_loss"] - loss_scale: float = cfg["kd_loss_scale"] - - criterion = {tuple(logit_pair): LogitsKLLoss(student_cfg, teacher_cfg)} - for layer_names in intermediate_pairs: - if torch.distributed.get_rank() == 0: - print( - "Distillation: Adding intermediate loss between" - f" `{layer_names[0]}` of student (hidden size {student_cfg.hidden_size}) and" - f" `{layer_names[1]}` of teacher (hidden size {teacher_cfg.hidden_size})." + criterion = {} + if student_cfg.pipeline_model_parallel_size == 1 or parallel_state.is_pipeline_last_stage(): + criterion[tuple(cfg.logit_layers)] = LogitsKLLoss( + student_cfg, temperature=cfg.logit_kl_temperature + ) + # NOTE: Projection layer shared among intermediate layer pairs. + projection_layer = ProjectionLayer(student_cfg, teacher_cfg) + + for student_layer, teacher_layer in cfg.intermediate_layer_pairs: + if parallel_state.get_tensor_and_context_parallel_rank() == 0: + logger.info( + "Distillation: Adding intermediate loss between" + f" `{student_layer}` of student (hidden size {student_cfg.hidden_size}) and" + f" `{teacher_layer}` of teacher (hidden size {teacher_cfg.hidden_size})." 
+            )
+        student_layer = _adjust_layer_index_for_pp(student_layer, student_cfg)
+        teacher_layer = _adjust_layer_index_for_pp(teacher_layer, teacher_cfg)
+        criterion[(student_layer, teacher_layer)] = HiddenStateCosineLoss(
+            student_cfg, projection_layer=projection_layer
         )
-        criterion[tuple(layer_names)] = HiddenStateCosineLoss(student_cfg, teacher_cfg)
 
     loss_balancer = LogitsAndIntermediatesLossBalancer(
-        kd_loss_scale=loss_scale, skip_original_loss=skip_lm_loss
+        kd_loss_scale=cfg.kd_loss_scale, skip_original_loss=cfg.skip_lm_loss
     )
 
-    cfg["criterion"] = criterion
-    cfg["loss_balancer"] = loss_balancer
+    cfg.criterion = criterion
+    cfg.loss_balancer = loss_balancer
 
     return cfg
 
 
+def _adjust_layer_index_for_pp(submodule_name, model_cfg):
+    """Adjust any sequence-based layer indices found in a submodule name for Pipeline Parallelism."""
+    match = re.search(r"(?<=\.)\d+(?=\.)", submodule_name)
+    if not match:
+        return submodule_name
+
+    offset = TransformerLayer._get_layer_offset(model_cfg)
+    new_layer_idx = int(match.group(0)) - offset
+    if new_layer_idx < 0:
+        raise ValueError(f"Layer {submodule_name} does not fall on final PP rank.")
+
+    new_submodule_name = submodule_name.replace(match.group(0), str(new_layer_idx))
+    if parallel_state.get_tensor_and_context_parallel_rank() == 0:
+        logger.info(
+            f'Distillation: Renamed layer "{submodule_name}" on final PP rank to "{new_submodule_name}"'
+        )
+    return new_submodule_name
+
+
 ########################################################
 
 
@@ -95,27 +154,17 @@ class BaseLoss(_Loss, metaclass=ABCMeta):
     """Abstract base class for Megatron distillation losses."""
 
     def __init__(
-        self,
-        student_config: TransformerConfig,
-        teacher_config: TransformerConfig,
-        projection_layer: bool = False,
+        self, model_config: "TransformerConfig", projection_layer: nn.Module | None = None
    ):
         """Constructor.
 
         Args:
-            student_config: Student's MCore transformer config.
-            teacher_config: Teacher's MCore transformer config.
-            projection_layer: If True, create a linear layer to project student tensor to teacher's hidden dim.
+            model_config: MCore transformer config.
+            projection_layer: Module which projects student activations to teacher's hidden dim.
         """
         super().__init__()
-        self._config = student_config
-        self._tensor_parallel = self._config.tensor_model_parallel_size > 1
-        self._sequence_parallel = self._config.sequence_parallel
-
-        if projection_layer:
-            self._projection = ProjectionLayer(student_config, teacher_config)
-        else:
-            self._projection = None
+        self._config = model_config
+        self._projection = projection_layer
 
     def pre_forward(self, predictions: Tensor, targets: Tensor) -> tuple[Tensor, Tensor]:
         """Performs projection of student tensor to match teacher's size if necessary."""
@@ -129,23 +178,16 @@ def pre_forward(self, predictions: Tensor, targets: Tensor) -> tuple[Tensor, Ten
         return predictions, targets
 
-    def post_forward(self, loss: Tensor, tp_reduce: bool = False) -> Tensor:
-        """Reshapes tensor from [s, b] to [b, s] for upcoming loss masking."""
+    def post_forward(
+        self, loss: Tensor, tp_reduce: bool = False, is_sequence_parallel: bool = False
+    ) -> tuple[Tensor, bool, bool]:
+        """Reshapes tensor from [s, b] to [b, s] and returns it with flags for upcoming loss masking."""
         loss = loss.transpose(0, 1).contiguous()
 
-        return (loss, tp_reduce)
+        return (loss, tp_reduce, is_sequence_parallel)
 
 
 class MSELoss(BaseLoss):
-    """Calculates Mean Squared Error loss between two tensors without reducing the sequence dim."""
-
-    def __init__(self, student_config: TransformerConfig, teacher_config: TransformerConfig):
-        """Constructor. 
- - Args: - student_config: Student's MCore transformer config. - teacher_config: Teacher's MCore transformer config. - """ - super().__init__(student_config, teacher_config) + """Calculates MSE loss between two tensors without reducing the sequence dim.""" def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: """Forward function. @@ -159,7 +201,6 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: """ predictions, targets = self.pre_forward(predictions, targets) - # TP irrelevant since MSE loss gradients are per-input element. loss = F.mse_loss(predictions, targets, reduction="none") loss = loss.sum(dim=-1) @@ -169,22 +210,26 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: class HiddenStateCosineLoss(BaseLoss): """Calculates Cosine loss between two tensors without reducing the sequence dim. - The tensors are assumed to be intermediate activations, so extra restrictions are in place. + The tensors are assumed to be intermediate activations, with full hidden dimension size. + We recommend only applying this loss to LayerNorm outputs, which have full hidden dim even when TP is used. """ - def __init__(self, student_config: TransformerConfig, teacher_config: TransformerConfig): + def __init__( + self, model_config: "TransformerConfig", projection_layer: nn.Module | None = None + ): """Constructor. Args: - student_config: Student's MCore transformer config. - teacher_config: Teacher's MCore transformer config. + model_config: MCore transformer config. + projection_layer: Module which projects student activations to teacher's hidden dim. """ - super().__init__(student_config, teacher_config, projection_layer=True) + super().__init__(model_config, projection_layer=projection_layer) - if self._tensor_parallel and not self._sequence_parallel: + if self._config.tensor_model_parallel_size > 1: logger.warning( "``HiddenStateCosineLoss`` only works with tensors with full hidden dim. Ensure the " - "tensor inputs meet this requirement or use `--sequence_parallel` if tensor parallel is enabled." + "tensor inputs meet this requirement. We recommend only applying this loss to LayerNorm outputs, " + "which have full hidden dim even when TP is used." ) def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: @@ -207,33 +252,24 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: ) loss = loss.view(*predictions.shape[:2]) - if self._sequence_parallel: - # Can efficiently gather size [s, b] tensor now for loss-masking purposes. - # TODO(aanoosheh) Reconsider for memory savings by splitting loss mask instead. - loss = gather_from_sequence_parallel_region(loss) - - return self.post_forward(loss) + # NOTE: Tensor sequence length is still split among TP ranks. + return self.post_forward(loss, is_sequence_parallel=self._config.sequence_parallel) class LogitsKLLoss(BaseLoss): """Calculates KL-Divergence loss between two logits tensors without reducing the sequence dim.""" def __init__( - self, - student_config: TransformerConfig, - teacher_config: TransformerConfig, - temperature: float = 1.0, - reverse: bool = False, + self, model_config: "TransformerConfig", temperature: float = 1.0, reverse: bool = False ): """Constructor. Args: - student_config: Student's MCore transformer config. - teacher_config: Teacher's MCore transformer config. + model_config: MCore transformer config. temperature: Divide tensors by this value prior to calculating loss. 
             reverse: Whether to reverse the loss as KLD(teacher, student) instead of KLD(student, teacher)
         """
-        super().__init__(student_config, teacher_config)
+        super().__init__(model_config)
         self._temperature = temperature
         self._reverse = reverse
 
@@ -255,21 +291,21 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor:
         output_student = predictions.float() / self._temperature
 
-        # Compute local softmax, and the reweight to compute global softmax.
-        if self._tensor_parallel:
+        # Compute local softmax, then reweight to compute global softmax.
+        if self._config.tensor_model_parallel_size > 1:
             # Maximum value along vocab dimension across all GPUs.
             teacher_logits_max, _ = torch.max(output_teacher, dim=-1)
             torch.distributed.all_reduce(
                 teacher_logits_max,
                 op=torch.distributed.ReduceOp.MAX,
-                group=get_tensor_model_parallel_group(),
+                group=parallel_state.get_tensor_model_parallel_group(),
             )
             output_teacher = output_teacher - teacher_logits_max.unsqueeze(dim=-1)
 
             denom_teacher = torch.sum(torch.exp(output_teacher), dim=-1)
-            # We can't use `gather_from_tensor_model_parallel_region` here since it discards
-            # gradients from other ranks - we need to all_reduce the gradients as well.
+            # We can't use standard reduction function here since the computation
+            # that follows it isn't identical across TP ranks.
             denom_teacher = all_reduce_autograd(
-                denom_teacher, group=get_tensor_model_parallel_group()
+                denom_teacher, group=parallel_state.get_tensor_model_parallel_group()
             )
 
             # Maximum value along vocab dimension across all GPUs.
             student_logits_max, _ = torch.max(output_student, dim=-1)
             torch.distributed.all_reduce(
                 student_logits_max,
                 op=torch.distributed.ReduceOp.MAX,
-                group=get_tensor_model_parallel_group(),
+                group=parallel_state.get_tensor_model_parallel_group(),
             )
             output_student = output_student - student_logits_max.unsqueeze(dim=-1).detach()
 
             denom_student = torch.sum(torch.exp(output_student), dim=-1)
             denom_student = all_reduce_autograd(
-                denom_student, group=get_tensor_model_parallel_group()
+                denom_student, group=parallel_state.get_tensor_model_parallel_group()
             )
 
             slen, bsz, sharded_vocab_size = output_student.shape
@@ -327,9 +363,6 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor:
         return self.post_forward(loss, tp_reduce=True)
 
 
-########################################################
-
-
 class LogitsAndIntermediatesLossBalancer(mtd.DistillationLossBalancer):
     """LossBalancer implementation for Logit and Intermediate losses.
 
@@ -359,40 +392,39 @@ def forward(self, loss_dict: dict[str, Tensor]) -> Tensor:
-            Aggregate total scalar loss.
+            Dict containing the aggregate total loss (under ``kd_loss``) along with
+            the individual logits and intermediate losses, for logging purposes.
""" original_loss = loss_dict.pop(mtd.loss_balancers.STUDENT_LOSS_KEY) - for _key, _loss in loss_dict.items(): + for _key in loss_dict: if _key.startswith(LogitsKLLoss.__name__): - logits_loss = _loss # should only be one - intermediate_loss = sum(loss_dict.values()) + logits_key = _key # should only be one + logits_loss = loss_dict.pop(logits_key) + intermediate_loss = sum(loss_dict.values()) / max(len(loss_dict), 1) if intermediate_loss > 0: dynamic_scale = logits_loss.item() / intermediate_loss.item() - intermediate_loss *= dynamic_scale - kd_loss_scale = self._kd_loss_scale / 2.0 + intermediate_loss_scaled = intermediate_loss * dynamic_scale else: - kd_loss_scale = self._kd_loss_scale + intermediate_loss = logits_loss.new_tensor(intermediate_loss) + intermediate_loss_scaled = intermediate_loss if self._skip_original_loss: - kd_loss = logits_loss + intermediate_loss - total_loss = kd_loss + total_loss = logits_loss + intermediate_loss_scaled else: - kd_loss = (logits_loss + intermediate_loss) * kd_loss_scale - dynamic_scale = original_loss.item() / kd_loss.item() - total_loss = original_loss + kd_loss * dynamic_scale - - return total_loss - - -######################################################## + kd_loss = logits_loss + intermediate_loss_scaled + kd_loss *= original_loss.item() / kd_loss.item() + total_loss = original_loss + kd_loss * self._kd_loss_scale + + out_dict = { + "kd_loss": total_loss, + "logits_loss": logits_loss, + "intermediate_loss": intermediate_loss, + } + return out_dict class ProjectionLayer(MegatronModule): """Module to project student layer activations to teacher's size.""" - def __init__( - self, - student_config: TransformerConfig, - teacher_config: TransformerConfig, - ): + def __init__(self, student_config: "TransformerConfig", teacher_config: "TransformerConfig"): """Constructor. Args: @@ -405,6 +436,7 @@ def __init__( else: self._fit = nn.Linear(student_config.hidden_size, teacher_config.hidden_size) self.apply(self._init_weights) + # Attribute below needed to reduce gradients during backward properly. setattr(self._fit.weight, "sequence_parallel", self.config.sequence_parallel) setattr(self._fit.bias, "sequence_parallel", self.config.sequence_parallel) @@ -418,15 +450,10 @@ def forward(self, student_tensor: Tensor): def _init_weights(self, module): """Initialize the weights.""" - if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=0.01) - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() + if isinstance(module, nn.Linear): + self.config.init_method(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() class _AllReduce(torch.autograd.Function): @@ -447,30 +474,142 @@ def backward(ctx, grad_output): def all_reduce_autograd( tensor, op=torch.distributed.ReduceOp.SUM, group=torch.distributed.group.WORLD ): - """AllReduce with autograd.""" + """Custom all-reduce function. + + Needed instead of other all-reduce functions available when the computation following + the all-reduce call differs per rank. In KL loss, this corresponds to the different numerators. 
+ """ return _AllReduce.apply(op, group, tensor) ######################################################## -def adjust_distillation_model_for_mcore(model: mtd.DistillationModel, distill_cfg: dict[str, Any]): +def adjust_distillation_model_for_mcore( + model: mtd.DistillationModel, distill_cfg: DistillationConfig +): """Extra modifications to ``mtd.DistillationModel`` required for Megatron-Core.""" - # HACK: Hide teacher during `sharded_state_dict` method. - def _sharded_state_dict(self, *args, **kwargs) -> ShardedStateDict: + # Hide teacher during `sharded_state_dict` method. + def _sharded_state_dict(self, *args, **kwargs) -> "ShardedStateDict": with self.hide_teacher_model(): - return self._sharded_state_dict(*args, **kwargs) + return type(self).sharded_state_dict(self, *args, **kwargs) + + model.sharded_state_dict = MethodType(_sharded_state_dict, model) + + # Skip `lm_loss` bypassing it when training if not needed for backprop. + def _compute_student_lm_loss(self, labels, logits) -> Tensor: + if distill_cfg.skip_lm_loss and self.training: + return torch.zeros_like(labels, dtype=logits.dtype) + return type(self).compute_language_model_loss(self, labels, logits) + + model.compute_language_model_loss = MethodType(_compute_student_lm_loss, model) + + # Skip `lm_loss` always for teacher. + def _compute_teacher_lm_loss(self, labels, logits) -> Tensor: + return torch.zeros_like(labels, dtype=logits.dtype) + + model.teacher_model.compute_language_model_loss = MethodType( + _compute_teacher_lm_loss, model.teacher_model + ) + + # HACK: Pipeline-parallel Distillation requires splitting input tensor into student and teacher parts. + def _set_student_input_tensor_shape(self, shapes: list[tuple[int]]): + self._tensor_split_idx = shapes[0][-1] + + def _set_input_tensor(self, input_tensors: list[Tensor]): + teacher_inputs = [ + t[..., self._tensor_split_idx :] if t is not None else t for t in input_tensors + ] + student_inputs = [ + t[..., : self._tensor_split_idx] if t is not None else t for t in input_tensors + ] + type(self).set_input_tensor(self.teacher_model, teacher_inputs) + type(self).set_input_tensor(self, student_inputs) + + model.set_student_input_tensor_shape = MethodType(_set_student_input_tensor_shape, model) + model.set_input_tensor = MethodType(_set_input_tensor, model) + + # HACK: Concatenate output tensors when PP>1 so they can be passed between ranks. + def _forward(self, *args, **kwargs): + if not self.training: + with self.only_student_forward(): + return type(self).forward(self, *args, **kwargs) + + with torch.no_grad(): + self._teacher_model.eval() + teacher_output = self._teacher_model(*args, **kwargs) + with self.only_student_forward(): + student_output = type(self).forward(self, *args, **kwargs) + + if not parallel_state.is_pipeline_last_stage(): + return torch.cat([student_output, teacher_output], dim=-1) + else: + return student_output + + model.forward = MethodType(_forward, model) + + +def get_tensor_shapes_adjust_fn_for_distillation( + model: torch.nn.Module | list[torch.nn.Module], + seq_length: int, + micro_batch_size: int, + decoder_seq_length: int | None = None, + forward_only: bool = False, +) -> Callable | None: + """Return the function to adjust tensor shapes for Distillation in Megatron-Core's forward pass. + + Currently only used during non-interleaved pipelining for Distillation. + Concatenates sizes of student and teacher output tensors for inter-process communication. 
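+
+    Returns ``None`` when no shape adjustment is needed, i.e. when running forward-only,
+    without pipeline parallelism, with an interleaved schedule, or on a non-distillation model.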
+ """ + if ( + forward_only + or parallel_state.get_pipeline_model_parallel_world_size() == 1 + or parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None + ): + return None + # Unwrap + if isinstance(model, list): + model = model[0] + while hasattr(model, "module"): + model = model.module + if not isinstance(model, mtd.DistillationModel): + return None + + def adjust_tensor_shapes( + recv_tensor_shapes: list[tuple[int, ...]], send_tensor_shapes: list[tuple[int, ...]] + ): + teacher_config = get_model_config(model.teacher_model) + tp_group = parallel_state.get_tensor_model_parallel_group() + cp_group = parallel_state.get_context_parallel_group() + + teacher_recv_tensor_shapes = get_tensor_shapes( + seq_length=seq_length, + micro_batch_size=micro_batch_size, + decoder_seq_length=decoder_seq_length, + config=teacher_config, + tp_group=tp_group, + cp_group=cp_group, + ) + teacher_send_tensor_shapes = get_tensor_shapes( + seq_length=seq_length, + micro_batch_size=micro_batch_size, + decoder_seq_length=decoder_seq_length, + config=teacher_config, + tp_group=tp_group, + cp_group=cp_group, + ) + model.set_student_input_tensor_shape(recv_tensor_shapes) - model._sharded_state_dict = model.sharded_state_dict - model.sharded_state_dict = types.MethodType(_sharded_state_dict, model) + for i, shape in enumerate(recv_tensor_shapes): + shape = list(shape) + shape[-1] += teacher_recv_tensor_shapes[0][-1] # type: ignore[index] + recv_tensor_shapes[i] = tuple(shape) + for i, shape in enumerate(send_tensor_shapes): + shape = list(shape) + shape[-1] += teacher_send_tensor_shapes[0][-1] # type: ignore[index] + send_tensor_shapes[i] = tuple(shape) - # HACK: Skip `lm_loss` bypassing it when training if not needed for backprop. - def _compute_language_model_loss(self, labels, logits) -> Tensor: - if self.training: - return torch.zeros_like(labels) - return self._compute_language_model_loss(labels, logits) + return recv_tensor_shapes, send_tensor_shapes - if distill_cfg["skip_lm_loss"]: - model._compute_language_model_loss = model.compute_language_model_loss - model.compute_language_model_loss = types.MethodType(_compute_language_model_loss, model) + return adjust_tensor_shapes
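
Usage sketch (reviewer illustration, not part of this patch): how the pieces above fit
together in a Megatron-LM training script. ``student_provider``, ``teacher_provider``,
``lm_loss``, and the shape values are hypothetical stand-ins; the ``mtd`` and plugin
calls are the APIs added or changed in this diff.

    import modelopt.torch.distill as mtd
    from megatron.core.utils import get_model_config
    from modelopt.torch.distill.plugins.megatron import (
        adjust_distillation_model_for_mcore,
        get_tensor_shapes_adjust_fn_for_distillation,
        load_distillation_config,
    )

    student = student_provider()  # hypothetical: builds the student GPTModel
    teacher = teacher_provider()  # hypothetical: builds the teacher GPTModel

    # Build criterion + loss balancer (defaults are used when no YAML path is given).
    distill_cfg = load_distillation_config(
        None, get_model_config(student), get_model_config(teacher)
    )

    # Wrap the student as a DistillationModel, then apply the MCore-specific patches.
    model = mtd.convert(
        student,
        mode=[
            (
                "kd_loss",
                {
                    "teacher_model": teacher,
                    "criterion": distill_cfg.criterion,
                    "loss_balancer": distill_cfg.loss_balancer,
                },
            )
        ],
    )
    adjust_distillation_model_for_mcore(model, distill_cfg)

    # With non-interleaved PP > 1, hand this hook to the pipeline schedule so the
    # concatenated student+teacher activations are sized correctly between ranks.
    adjust_fn = get_tensor_shapes_adjust_fn_for_distillation(
        model, seq_length=4096, micro_batch_size=1
    )

    # Later, inside the loss function (lm_loss is a stand-in scalar loss):
    total_loss = model.compute_kd_loss(student_loss=lm_loss)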