From c335f6ef7751f379aac9f48f6c26cafc90f52103 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 16 Dec 2025 01:01:18 +0000 Subject: [PATCH 01/29] train with only layer distillation losses --- fast_llm/layers/language_model/head.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index b1d0c2acd..db768ca12 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -409,14 +409,23 @@ def _logits_cross_entropy_forward_backward( else: distillation_loss, distillation_grad = None, None - # TODO: de-allocate earlier. - del logits - # TODO: Accumulate grads in-place to reduce memory and compute overhead. grad = _add_tensors(dpo_grad, lm_grad, distillation_grad) # TODO: Return individual losses? loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) + + # When using only activation distillation, loss and grad are None. + # Create zero tensors to allow activation distillation gradients to flow through. + if loss is None: + loss = torch.zeros(1, device=input_.device, dtype=input_.dtype, requires_grad=True) + if grad is None: + # Zero gradient means no loss at the head, but activation distillation gradients + grad = torch.zeros_like(logits) + + # TODO: de-allocate earlier. + del logits + if self.training and losses is not None: if dpo_loss is not None: losses[self._dpo_loss_name].append(dpo_loss.detach()) @@ -502,11 +511,12 @@ def _format_name(name: str) -> str: return name.replace("_", " ") -def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor: +def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor | None: tensors = [tensor for tensor in tensors if tensor is not None] if len(tensors) > 1: return sum(tensors) elif len(tensors) == 1: return tensors[0] else: - raise RuntimeError() + # All tensors are None - this is valid when using only activation distillation + return None From e06a4b2ca02b22dc56e798aabf0b8c30fe280417 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 16 Dec 2025 14:15:45 +0000 Subject: [PATCH 02/29] unscaled loss llogging + training with distillation loss factor = 0 --- fast_llm/layers/language_model/head.py | 53 +++++++++++++++++++------- 1 file changed, 39 insertions(+), 14 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index db768ca12..733311d39 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -370,11 +370,13 @@ def _logits_cross_entropy_forward_backward( logits_scale_factor=self._config.logits_scale_factor, target_format=TargetFormat.labels, ) + if self.training and losses is not None: + losses[self._ce_loss_name_unscaled].append(lm_loss.detach()) lm_loss = lm_loss * self._config.language_model_loss_factor else: lm_loss, lm_grad = None, None - if distillation_target is not None and self._config.distillation_loss_factor > 0.0: + if distillation_target is not None: if self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: distillation_loss, distillation_grad = reverse_kl_forward_backward( logits.flatten(0, -2), @@ -405,9 +407,9 @@ def _logits_cross_entropy_forward_backward( raise ValueError( f"Invalid distillation loss implementation: {self._config.distillation_loss_implementation}" ) + if self.training and losses is not None: # we keep track of unscaled losses for model comparison purposes + 
losses[self._distillation_loss_name_unscaled].append(distillation_loss.detach()) distillation_loss = distillation_loss * self._config.distillation_loss_factor - else: - distillation_loss, distillation_grad = None, None # TODO: Accumulate grads in-place to reduce memory and compute overhead. grad = _add_tensors(dpo_grad, lm_grad, distillation_grad) @@ -415,14 +417,6 @@ def _logits_cross_entropy_forward_backward( # TODO: Return individual losses? loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) - # When using only activation distillation, loss and grad are None. - # Create zero tensors to allow activation distillation gradients to flow through. - if loss is None: - loss = torch.zeros(1, device=input_.device, dtype=input_.dtype, requires_grad=True) - if grad is None: - # Zero gradient means no loss at the head, but activation distillation gradients - grad = torch.zeros_like(logits) - # TODO: de-allocate earlier. del logits @@ -443,6 +437,13 @@ def _loss_name(self) -> str: name = f"{name}_{self._prediction_distance}" return name + @functools.cached_property + def _ce_loss_name_unscaled(self) -> str: + name = "language_model_loss_unscaled" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + @functools.cached_property def _z_loss_name(self) -> str: name = "z_loss" @@ -471,8 +472,24 @@ def _distillation_loss_name(self) -> str: name = f"{name}_{self._prediction_distance}" return name + @functools.cached_property + def _distillation_loss_name_unscaled(self) -> str: + name = "distillation_loss_unscaled" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: loss_defs = [LossDef(name=self._loss_name, formatted_name=_format_name(self._loss_name), count=count)] + if self._config.distillation_model is None or self._config.language_model_loss_factor > 0.0: + # unscaled CE loss (NTP) + loss_defs = [ + LossDef( + name=self._ce_loss_name_unscaled, + formatted_name=_format_name(self._ce_loss_name_unscaled), + count=count, + ) + ] if self._config.logit_z_loss: loss_defs.append( LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) @@ -490,6 +507,15 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: count=count, ) ) + # unscaled distillation loss for comparison purposes + loss_defs.append( + LossDef( + name=self._distillation_loss_name_unscaled, + formatted_name=_format_name(self._distillation_loss_name_unscaled), + count=count, + ) + ) + # if we mix distillation loss and CE loss for NTP, we want to log both if self._config.language_model_loss_factor > 0.0: loss_defs.append( LossDef( @@ -511,12 +537,11 @@ def _format_name(name: str) -> str: return name.replace("_", " ") -def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor | None: +def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor: tensors = [tensor for tensor in tensors if tensor is not None] if len(tensors) > 1: return sum(tensors) elif len(tensors) == 1: return tensors[0] else: - # All tensors are None - this is valid when using only activation distillation - return None + raise RuntimeError() From 179ae25e9db3ecda3c75762288abe824c31e65fd Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 17 Dec 2025 21:07:54 +0000 Subject: [PATCH 03/29] make logging more explicit --- fast_llm/layers/language_model/config.py | 12 ++ fast_llm/layers/language_model/head.py | 217 +++++++++++++++-------- 2 files changed, 153 
insertions(+), 76 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 53dac2892..13c6d87eb 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -168,11 +168,21 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Factor to scale the language modeling loss by when using distillation.", hint=FieldHint.feature, ) + track_language_model_loss: bool = Field( + default=False, + desc="Track the unscaled language modeling loss for logging purposes. Will always do if language_model_loss_factor > 0.", + hint=FieldHint.feature, + ) distillation_loss_factor: float = Field( default=1.0, desc="Factor to scale the distillation loss by when using distillation.", hint=FieldHint.feature, ) + track_distillation_loss: bool = Field( + default=False, + desc="Track the unscaled distillation loss for logging purposes. Will always do if distillation_loss_factor > 0.", + hint=FieldHint.feature, + ) logits_scale_factor: float = Field( default=1.0, desc="Multiply output logits by scale factor.", @@ -243,6 +253,8 @@ def _validate(self) -> None: else: self.language_model_loss_factor = 0.0 super()._validate() + if self.distillation_model is None: + Assert.is_(self.track_distillation_loss, False) assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both @property diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 733311d39..e785c09e5 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -113,6 +113,12 @@ def __init__( peft=self._peft, ) + self._compute_lm_loss = self.config.language_model_loss_factor > 0.0 or self.config.track_language_model_loss + self._compute_dpo_loss = self._config.enable_dpo + self._compute_distillation_loss = self._config.distillation_model is not None and ( + self._config.distillation_loss_factor > 0.0 or self._config.track_distillation_loss + ) + def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: @@ -137,8 +143,6 @@ def forward( # TODO: Drop autograd entirely. # TODO: Skip cross-entropy backward if not needed. language_model_loss = self._forward(input_, kwargs, losses) - if losses is not None and language_model_loss is not None: - losses[self._loss_name].append(language_model_loss.detach()) # TODO: Return the model output when needed. if self._is_last_head: # Last head should return the loss for backward. 
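(A minimal sketch of how the tracking flags introduced in this patch resolve, using the same head-config dict shape as the tests added later in the series; the concrete values and the "teacher" name are illustrative assumptions, not taken from the diff.)

    head_config = {
        "distillation_model": "teacher",      # assumed reference-model name
        "language_model_loss_factor": 0.0,    # LM cross-entropy excluded from the training objective
        "track_language_model_loss": True,    # ...but still computed and logged unscaled
        "distillation_loss_factor": 1.0,
    }
    # _compute_lm_loss resolves to True (tracking is requested even though the factor is 0) and
    # _compute_distillation_loss resolves to True (a distillation model is set and the factor is > 0),
    # so both losses are logged, but with a factor of 0 the LM term adds nothing to the combined loss.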
@@ -205,25 +209,22 @@ def _get_targets( if loss_mask is not None: loss_mask = loss_mask.flatten() - if self._config.distillation_model is None or self._config.language_model_loss_factor > 0.0: - lm_target = kwargs.get(LanguageModelKwargs.labels) - if lm_target is not None: - # MTP: Shift the labels - lm_target_sequence_length = ( - lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - self._prediction_heads - ) - if LanguageModelKwargs.sequence_q_dim in kwargs: - Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) - lm_target_slice = slice( - self._prediction_distance, self._prediction_distance + lm_target_sequence_length - ) - lm_target = ( - lm_target[lm_target_slice] - if kwargs[LanguageModelKwargs.sequence_first] - else lm_target[:, lm_target_slice] - ).flatten() - else: - lm_target = None + lm_target = kwargs.get(LanguageModelKwargs.labels) + if lm_target is not None: + # MTP: Shift the labels + lm_target_sequence_length = ( + lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - self._prediction_heads + ) + if LanguageModelKwargs.sequence_q_dim in kwargs: + Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) + lm_target_slice = slice( + self._prediction_distance, self._prediction_distance + lm_target_sequence_length + ) + lm_target = ( + lm_target[lm_target_slice] + if kwargs[LanguageModelKwargs.sequence_first] + else lm_target[:, lm_target_slice] + ).flatten() targets = (dpo_target, lm_target, distillation_target, loss_mask) if self._sequence_parallel_logits: @@ -246,7 +247,7 @@ def _logits_cross_entropy_forward_backward_split( losses: dict | None = None, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: if self._config.cross_entropy_splits is None or targets is None: - loss, logit_input_grad = self._logits_cross_entropy_forward_backward( + loss, logit_input_grad = self._logits_loss_forward_backward( input_, targets, weight, grad_output, kwargs, losses ) if targets is None: @@ -279,7 +280,7 @@ def _logits_cross_entropy_forward_backward_split( for tensor in [logit_input, *targets, logit_input_grad] ] for logit_input_, *targets_, logit_input_grad_ in zip(*tensors_split, strict=True): - loss_, grad_ = self._logits_cross_entropy_forward_backward( + loss_, grad_ = self._logits_loss_forward_backward( logit_input_, targets_, weight, @@ -301,7 +302,7 @@ def _logits_cross_entropy_forward_backward_split( all_reduce(loss, group=self._parallel_dim.group) return loss, logit_input_grad.view_as(input_) if logit_input_grad is not None else None - def _logits_cross_entropy_forward_backward( + def _logits_loss_forward_backward( self, input_: torch.Tensor, targets: tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None], @@ -359,7 +360,7 @@ def _logits_cross_entropy_forward_backward( else: dpo_loss, dpo_grad = None, None - if lm_target is not None: + if lm_target is not None and self._compute_lm_loss: lm_loss, lm_grad = cross_entropy_forward_backward( logits.flatten(0, -2), lm_target, @@ -370,13 +371,10 @@ def _logits_cross_entropy_forward_backward( logits_scale_factor=self._config.logits_scale_factor, target_format=TargetFormat.labels, ) - if self.training and losses is not None: - losses[self._ce_loss_name_unscaled].append(lm_loss.detach()) - lm_loss = lm_loss * self._config.language_model_loss_factor else: lm_loss, lm_grad = None, None - if distillation_target is not None: + if distillation_target is not None and self._compute_distillation_loss: if 
self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: distillation_loss, distillation_grad = reverse_kl_forward_backward( logits.flatten(0, -2), @@ -407,39 +405,121 @@ def _logits_cross_entropy_forward_backward( raise ValueError( f"Invalid distillation loss implementation: {self._config.distillation_loss_implementation}" ) - if self.training and losses is not None: # we keep track of unscaled losses for model comparison purposes - losses[self._distillation_loss_name_unscaled].append(distillation_loss.detach()) - distillation_loss = distillation_loss * self._config.distillation_loss_factor - - # TODO: Accumulate grads in-place to reduce memory and compute overhead. - grad = _add_tensors(dpo_grad, lm_grad, distillation_grad) - - # TODO: Return individual losses? - loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) + else: + distillation_loss, distillation_grad = None, None # TODO: de-allocate earlier. del logits + loss, grad = self._post_process_loss_and_grad( + dpo_loss, + dpo_grad, + lm_loss, + lm_grad, + distillation_loss, + distillation_grad, + losses, + loss_mask, + kwargs, + ) + + return loss, output_parallel_linear_backward(grad, context) if self.training else None - if self.training and losses is not None: - if dpo_loss is not None: + def _post_process_loss_and_grad( + self, + dpo_loss: torch.Tensor | None, + dpo_grad: torch.Tensor | None, + lm_loss: torch.Tensor | None, + lm_grad: torch.Tensor | None, + distillation_loss: torch.Tensor | None, + distillation_grad: torch.Tensor | None, + losses: dict | None, + loss_mask: torch.Tensor | None, + kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + If loss is provided (i.e. not None) it will be logged in scaled and unscaled version. The total loss is also logged. + + Arguments: + - Losses: unscaled losses from different components (DPO, LM CE, Distillation) + - Grads: gradients of the losses w.r.t. logits from different components, already scaled by loss factors. + """ + # Extremely explicit but easier to follow. + ############ + if dpo_loss is not None: + if self.training and losses is not None: losses[self._dpo_loss_name].append(dpo_loss.detach()) - if self._config.distillation_model is not None and distillation_loss is not None: + else: + Assert.is_(dpo_grad, None) + + if lm_loss is not None: + if self.training and losses is not None: + losses[self._lm_loss_name_unscaled].append(lm_loss.detach()) + lm_loss = lm_loss * self._config.language_model_loss_factor # does not need scaling by loss_scalor_df + if self.training and losses is not None: + losses[self._lm_loss_name].append(lm_loss.detach()) + else: + Assert.is_(lm_grad, None) + + if distillation_loss is not None: + # We need to scale the loss by (valid_tokens * num_micro_batches) / total_valid_tokens to correctly average the loss over micro-batches. + # The runner averages losses by dividing by num_micro_batches, so we need to account for that. + # Note: for grads this scaling is already in the 'grad_output' + total_valid_tokens = kwargs.get( + LanguageModelKwargs.total_valid_tokens + ) # number of not masked tokens across all micro-batches. 
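+            # Illustrative arithmetic (assumed numbers, not taken from this patch): with
+            # num_micro_batches = 2, total_valid_tokens = 300 and 100 valid tokens in this
+            # micro-batch, the scaling computed below is loss_scalor_df = 100 * 2 / 300 = 2/3.
+            # After the runner divides the accumulated loss by num_micro_batches, this
+            # micro-batch contributes with weight 1/3 = 100/300, i.e. its share of the
+            # valid tokens across all micro-batches.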
+ num_micro_batches = kwargs.get("num_micro_batches", 1) + + if loss_mask is None or total_valid_tokens is None: + loss_scalor_df = 1 + else: + valid_tokens = loss_mask.sum() + # Scale by (valid_tokens * num_micro_batches) / total_valid_tokens + # This accounts for the runner dividing by num_micro_batches + loss_scalor_df = (valid_tokens * num_micro_batches) / total_valid_tokens + distillation_loss = distillation_loss * loss_scalor_df + if self.training and losses is not None: + losses[self._distillation_loss_name_unscaled].append(distillation_loss.detach()) + distillation_loss = distillation_loss * self._config.distillation_loss_factor + if self.training and losses is not None: losses[self._distillation_loss_name].append(distillation_loss.detach()) - if self._config.distillation_model is not None and lm_loss is not None: - losses[self._distillation_language_model_loss_name].append(lm_loss.detach()) + else: + Assert.is_(distillation_grad, None) - return loss, output_parallel_linear_backward(grad, context) if self.training else None + ############ + # TODO: Accumulate grads in-place to reduce memory and compute overhead. + grad = _add_tensors(dpo_grad, lm_grad, distillation_grad) + total_loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) + if losses is not None and total_loss is not None: + losses[self._total_loss_name].append(total_loss.detach()) + + return total_loss, grad @functools.cached_property - def _loss_name(self) -> str: - name = "language_model_loss" + def _total_loss_name(self) -> str: + """ + Combined total scaled loss used for training. + """ + name = "lm_head_loss" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @functools.cached_property - def _ce_loss_name_unscaled(self) -> str: - name = "language_model_loss_unscaled" + def _lm_loss_name_unscaled(self) -> str: + """ + Unscaled language model cross-entropy loss. + """ + name = "lm_loss_unscaled" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + + @functools.cached_property + def _lm_loss_name(self) -> str: + """ + Scaled language model cross-entropy loss. 
+ """ + name = "lm_loss" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @@ -459,8 +539,8 @@ def _dpo_loss_name(self) -> str: return name @functools.cached_property - def _distillation_language_model_loss_name(self) -> str: - name = "distillation_language_model_loss" + def _distillation_loss_name_unscaled(self) -> str: + name = "distillation_loss_unscaled" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @@ -472,34 +552,28 @@ def _distillation_loss_name(self) -> str: name = f"{name}_{self._prediction_distance}" return name - @functools.cached_property - def _distillation_loss_name_unscaled(self) -> str: - name = "distillation_loss_unscaled" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - loss_defs = [LossDef(name=self._loss_name, formatted_name=_format_name(self._loss_name), count=count)] - if self._config.distillation_model is None or self._config.language_model_loss_factor > 0.0: - # unscaled CE loss (NTP) - loss_defs = [ + loss_defs = [ + LossDef(name=self._total_loss_name, formatted_name=_format_name(self._total_loss_name), count=count) + ] + if self._compute_lm_loss: + loss_defs.append( LossDef( - name=self._ce_loss_name_unscaled, - formatted_name=_format_name(self._ce_loss_name_unscaled), + name=self._lm_loss_name_unscaled, + formatted_name=_format_name(self._lm_loss_name_unscaled), count=count, ) - ] + ) if self._config.logit_z_loss: loss_defs.append( LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) ) - if self._config.enable_dpo: + if self._compute_dpo_loss: loss_defs.append( LossDef(name=self._dpo_loss_name, formatted_name=_format_name(self._dpo_loss_name), count=count) ) - if self._config.distillation_model is not None: + if self._compute_distillation_loss: loss_defs.append( LossDef( name=self._distillation_loss_name, @@ -515,15 +589,6 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: count=count, ) ) - # if we mix distillation loss and CE loss for NTP, we want to log both - if self._config.language_model_loss_factor > 0.0: - loss_defs.append( - LossDef( - name=self._distillation_language_model_loss_name, - formatted_name=_format_name(self._distillation_language_model_loss_name), - count=count, - ) - ) return loss_defs @@ -544,4 +609,4 @@ def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor: elif len(tensors) == 1: return tensors[0] else: - raise RuntimeError() + raise RuntimeError("No tensors to add.") From 9968aac14c439823c6850e0dcc4e2210b5ad2cf3 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 17 Dec 2025 22:38:28 +0000 Subject: [PATCH 04/29] clean + tests --- fast_llm/layers/language_model/head.py | 24 ++---- tests/layers/test_lm_head.py | 107 +++++++++++++++++++++---- 2 files changed, 98 insertions(+), 33 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index e785c09e5..8a4601941 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -461,22 +461,7 @@ def _post_process_loss_and_grad( Assert.is_(lm_grad, None) if distillation_loss is not None: - # We need to scale the loss by (valid_tokens * num_micro_batches) / total_valid_tokens to correctly average the loss over micro-batches. - # The runner averages losses by dividing by num_micro_batches, so we need to account for that. 
- # Note: for grads this scaling is already in the 'grad_output' - total_valid_tokens = kwargs.get( - LanguageModelKwargs.total_valid_tokens - ) # number of not masked tokens across all micro-batches. - num_micro_batches = kwargs.get("num_micro_batches", 1) - - if loss_mask is None or total_valid_tokens is None: - loss_scalor_df = 1 - else: - valid_tokens = loss_mask.sum() - # Scale by (valid_tokens * num_micro_batches) / total_valid_tokens - # This accounts for the runner dividing by num_micro_batches - loss_scalor_df = (valid_tokens * num_micro_batches) / total_valid_tokens - distillation_loss = distillation_loss * loss_scalor_df + distillation_loss = distillation_loss if self.training and losses is not None: losses[self._distillation_loss_name_unscaled].append(distillation_loss.detach()) distillation_loss = distillation_loss * self._config.distillation_loss_factor @@ -564,6 +549,13 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: count=count, ) ) + loss_defs.append( + LossDef( + name=self._lm_loss_name, + formatted_name=_format_name(self._lm_loss_name), + count=count, + ) + ) if self._config.logit_z_loss: loss_defs.append( LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 623a30d82..88ff9d612 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -55,6 +55,8 @@ def _lm_head( logit_scale_factor: float = 1.0, logit_z_loss=0.0, distillation_loss_implementation: DistillationLossImpl = DistillationLossImpl.cross_entropy, + language_model_loss_factor: float = 1.0, + distillation_loss_factor: float = 1.0, ): hidden = torch.rms_norm( input_.to(rms_weight.dtype), @@ -69,23 +71,31 @@ def _lm_head( loss = _reverse_kl_loss( (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask ) - loss.backward(torch.full_like(loss, grad_output)) - return loss, None + # Apply distillation_loss_factor to grad_output for backward + loss.backward(torch.full_like(loss, grad_output * distillation_loss_factor)) + # Return scaled loss + return loss * distillation_loss_factor, None if logit_scale_factor != 1.0: logits *= logit_scale_factor z_loss = torch.mean(torch.logsumexp(logits, dim=-1) ** 2) if logit_z_loss > 0 else None if target.ndim == logits.ndim: + # Distillation loss (cross-entropy with soft targets) loss = torch.nn.functional.cross_entropy( logits.flatten(0, -2), target.float().softmax(-1).flatten(0, -2), reduction="none" ) if loss_mask is not None: loss = loss * loss_mask.flatten() loss = loss.mean() + # Apply distillation_loss_factor + loss.backward(torch.full_like(loss, grad_output * distillation_loss_factor)) + return loss * distillation_loss_factor, z_loss else: + # Language model loss (cross-entropy with hard labels) loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) - loss.backward(torch.full_like(loss, grad_output)) - return loss, z_loss + # Apply language_model_loss_factor + loss.backward(torch.full_like(loss, grad_output * language_model_loss_factor)) + return loss * language_model_loss_factor, z_loss SEQUENCE_LENGTH = 200 @@ -154,6 +164,54 @@ def _lm_head( True, 1, ), + pytest.param( + { + "head": { + "distillation_model": "distillation", + "language_model_loss_factor": 0.0, + "track_language_model_loss": True, + "distillation_loss_factor": 1.0, + } + }, + {}, + False, + 1, + id="track_lm_zero_factor", + ), + pytest.param( + { + "head": { + 
"distillation_model": "distillation", + "language_model_loss_factor": 0.0, + "distillation_loss_factor": 0.0, + "track_language_model_loss": True, + "track_distillation_loss": True, + } + }, + {}, + False, + 1, + id="track_both_zero_factors", + ), + pytest.param( + { + "head": { + "distillation_model": "distillation", + "language_model_loss_factor": 0.0, + "distillation_loss_factor": 0.0, + "track_language_model_loss": False, + "track_distillation_loss": False, + } + }, + {}, + False, + 1, + marks=pytest.mark.xfail( + reason="No losses computed when all factors=0 and tracking=False, raises RuntimeError in _add_tensors", + strict=True, + ), + id="zero_factors_no_tracking", + ), ), ) def test_lm_head( @@ -292,6 +350,10 @@ def test_lm_head( logit_scale_factor=head_config.logits_scale_factor, logit_z_loss=head_config.logit_z_loss, distillation_loss_implementation=head_config.distillation_loss_implementation, + language_model_loss_factor=( + head_config.language_model_loss_factor if head_config.language_model_loss_factor is not None else 1.0 + ), + distillation_loss_factor=head_config.distillation_loss_factor, ) # Prepare LM head inputs @@ -303,20 +365,27 @@ def test_lm_head( head_input = torch.stack((shared_hidden, input_.detach())).requires_grad_() output_grad = torch.randn_like(shared_hidden) - loss_name = f"language_model_loss_{prediction_distance}" if prediction_distance > 0 else "language_model_loss" - loss_keys = {loss_name} + lm_head_loss_name = f"lm_head_loss_{prediction_distance}" if prediction_distance > 0 else "lm_head_loss" + expected_loss_keys = {lm_head_loss_name} + if head._compute_lm_loss: + lm_loss_name_unscaled = ( + f"lm_loss_unscaled_{prediction_distance}" if prediction_distance > 0 else "lm_loss_unscaled" + ) + lm_loss_name = f"lm_loss_{prediction_distance}" if prediction_distance > 0 else "lm_loss" + + expected_loss_keys.add(lm_loss_name_unscaled) + expected_loss_keys.add(lm_loss_name) if ref_z_loss is not None: - loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") - if head_config.distillation_model is not None: - loss_keys.add("distillation_loss") - if head_config.language_model_loss_factor > 0: - loss_keys.add("distillation_language_model_loss") + expected_loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") + if head._compute_distillation_loss: + expected_loss_keys.add("distillation_loss") + expected_loss_keys.add("distillation_loss_unscaled") Assert.eq( {loss_definition.name: loss_definition.count for loss_definition in head.get_loss_definitions()}, - {loss_key: 1 for loss_key in loss_keys}, + {loss_key: 1 for loss_key in expected_loss_keys}, ) - losses = {key: [] for key in loss_keys} + losses = {key: [] for key in expected_loss_keys} output, context = stage.forward(head_input, kwargs, losses) stage.backward(output_grad, context) @@ -325,16 +394,16 @@ def test_lm_head( 1e-5 if distributed.config.compute_dtype == DataType.float32 else 1e-4 ) * head_config.logits_scale_factor - Assert.eq(losses.keys(), loss_keys) - Assert.eq(len(losses[loss_name]), 1) + Assert.eq(losses.keys(), expected_loss_keys) + Assert.eq(len(losses[lm_head_loss_name]), 1) if ref_z_loss is not None: Assert.eq(len(losses["z_loss"]), 1) Assert.rms_close_relative(losses["z_loss"][0], ref_z_loss, threshold, min_threshold) - Assert.rms_close_relative(losses[loss_name][0], ref_loss, threshold, min_threshold) + Assert.rms_close_relative(losses[lm_head_loss_name][0], ref_loss, threshold, min_threshold) if head._is_last_head: - 
Assert.all_equal(output, losses[loss_name][0]) + Assert.all_equal(output, losses[lm_head_loss_name][0]) input_grad = head_input.grad else: Assert.all_equal(output, shared_hidden) @@ -344,3 +413,7 @@ def test_lm_head( Assert.rms_close_relative(input_grad, ref_input.grad, threshold, min_threshold) Assert.rms_close_relative(head.final_norm.weight.grad_buffer, ref_rms_weight.grad, threshold, min_threshold) Assert.rms_close_relative(logit_weight.grad_buffer, ref_logit_weight.grad, threshold, min_threshold) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 945c5a774bf30fbb088a818f12f5510e98f99bbb Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 17 Dec 2025 22:38:54 +0000 Subject: [PATCH 05/29] nvm --- tests/layers/test_lm_head.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 88ff9d612..c6d806db8 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -413,7 +413,3 @@ def test_lm_head( Assert.rms_close_relative(input_grad, ref_input.grad, threshold, min_threshold) Assert.rms_close_relative(head.final_norm.weight.grad_buffer, ref_rms_weight.grad, threshold, min_threshold) Assert.rms_close_relative(logit_weight.grad_buffer, ref_logit_weight.grad, threshold, min_threshold) - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) From 4b6e3d7503b0cf8a93aef156a0328c2b6dc67cc8 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 19 Dec 2025 21:28:55 +0000 Subject: [PATCH 06/29] forward KL --- fast_llm/functional/config.py | 1 + fast_llm/functional/cross_entropy.py | 128 +++++++++++++++++++++++++ fast_llm/layers/language_model/head.py | 21 +++- 3 files changed, 149 insertions(+), 1 deletion(-) diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 4cfc3b61d..20ed99fde 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -102,6 +102,7 @@ class CrossEntropyImpl(str, enum.Enum): class DistillationLossImpl(str, enum.Enum): reverse_kl = "reverse_kl" + forward_kl = "forward_kl" cross_entropy = "cross_entropy" diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 8c9ea9399..5a618eea0 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -359,3 +359,131 @@ def reverse_kl_forward_backward( group=group, ) return distillation_loss, distillation_grad + + +@torch.compile +def _forward_kl_forward_backward( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + group: ProcessGroup | None = None, + logits_scale_factor: float = 1.0, + teacher_softmax_temperature: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Forward KL: KL(p||q) where p=teacher, q=student. + This is reverse KL with roles swapped in the loss computation. + + Key insight: KL(p||q) = sum_i p_i * log(p_i/q_i) + = sum_i p_i * (log(p_i) - log(q_i)) + which is reverse KL with p and q swapped. + + However, we still need grad w.r.t. 
student logits, so gradient is different: + d/d(student_logits) KL(p||q) = student_probs - teacher_probs + """ + Assert.eq( + teacher_softmax_temperature, + 1, + msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel forward KL", + ) + Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel forward KL") + Assert.eq(target.shape, logits.shape) + assert target.dtype.is_floating_point, target.dtype + if loss_mask is not None: + Assert.eq(loss_mask.shape, logits.shape[:-1]) + + # Compute log softmax for both teacher and student + teacher_log_probs = distributed_log_softmax(target.float(), group=group) + student_log_probs = distributed_log_softmax(logits, group=group) + + teacher_probs = teacher_log_probs.exp() + # Forward KL: p * log(p/q) = p * (log_p - log_q) + log_ratio = teacher_log_probs - student_log_probs + del teacher_log_probs + + # Compute loss: sum over vocab of teacher_probs * log_ratio + loss_terms = (teacher_probs * log_ratio).sum(dim=-1) + del log_ratio + + if loss_mask is not None: + valid = loss_mask.to(loss_terms.dtype) + loss_terms = loss_terms * valid + valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype)) + loss = loss_terms.sum() + + if group is not None: + all_reduce(loss, op=ReduceOp.SUM, group=group) + loss /= valid_tokens + + if grad_output is not None: + # Gradient: d/d(student_logits) KL(p||q) = student_probs - teacher_probs + student_probs = student_log_probs.exp() + grad_base = student_probs - teacher_probs + del student_probs, teacher_probs, student_log_probs + + if loss_mask is not None: + grad_base.mul_(loss_mask.to(logits.dtype).unsqueeze(-1)) + + grad_base.mul_(grad_output / valid_tokens) + grad = grad_base.to(logits.dtype) + else: + grad = None + + return loss.detach_(), grad + + +def forward_kl_forward_backward( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + group: ProcessGroup | None = None, + logits_scale_factor: float = 1.0, + teacher_softmax_temperature: float = 1.0, + target_format: TargetFormat = TargetFormat.labels, + sequence_parallel_logits: bool = False, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Compute forward KL divergence: KL(p||q) where p is the target distribution (teacher) and q is the predicted (student). + This is mode-covering (vs. mode-seeking for reverse KL) and useful for: + - Encouraging the model to cover all modes of the target distribution + - Spreading probability mass broadly across the target support + - Standard distillation scenarios where you want to match the full teacher distribution + + Key differences from reverse KL: + - Forward KL: KL(p||q) = mode-covering (spreads mass broadly) + - Reverse KL: KL(q||p) = mode-seeking (focuses on target modes) + + Takes: + logits: [BxS, V] or [B, S, V], where V is local vocab size + target: [BxS, V] or [B, S, V] (logits format) + loss_mask: [BxS] or [B, S] or None + ... + + Returns: + loss: Forward KL divergence loss + grad: Gradients w.r.t. 
logits + """ + + if sequence_parallel_logits: + # TODO: see hybrid dev branch where it is implemented + raise NotImplementedError("Sequence-parallel forward KL is not implemented yet, set vocab_parallel true") + + Assert.eq(target_format, TargetFormat.logits, msg="Forward KL only supports logits format") + Assert.eq(target.shape, logits.shape) + assert target.dtype.is_floating_point, target.dtype + if loss_mask is not None: + Assert.eq(loss_mask.shape, logits.shape[:-1]) + + # TODO: implement fused? + distillation_loss, distillation_grad = _forward_kl_forward_backward( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=grad_output, + logits_scale_factor=logits_scale_factor, + teacher_softmax_temperature=teacher_softmax_temperature, + group=group, + ) + return distillation_loss, distillation_grad diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 8a4601941..b8a8f0cbb 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -14,7 +14,11 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl, TargetFormat, TritonConfig -from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward +from fast_llm.functional.cross_entropy import ( + cross_entropy_forward_backward, + forward_kl_forward_backward, + reverse_kl_forward_backward, +) from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.block.block import Block @@ -390,6 +394,21 @@ def _logits_loss_forward_backward( sequence_parallel_logits=self._sequence_parallel_logits, ) + elif self._config.distillation_loss_implementation == DistillationLossImpl.forward_kl: + distillation_loss, distillation_grad = forward_kl_forward_backward( + logits.flatten(0, -2), + distillation_target, + loss_mask, + grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, + group=group, + logits_scale_factor=self._config.logits_scale_factor, + teacher_softmax_temperature=self._config.teacher_softmax_temperature, + target_format=( + TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits + ), + sequence_parallel_logits=self._sequence_parallel_logits, + ) + elif self._config.distillation_loss_implementation == DistillationLossImpl.cross_entropy: distillation_loss, distillation_grad = cross_entropy_forward_backward( logits.flatten(0, -2), From c5fefa0a13b1903bf88e7187790a94211b8d40cb Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 19 Dec 2025 22:19:52 +0000 Subject: [PATCH 07/29] test forward kl --- tests/functional/test_cross_entropy.py | 43 ++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py index 72644d061..716c56ba3 100644 --- a/tests/functional/test_cross_entropy.py +++ b/tests/functional/test_cross_entropy.py @@ -8,7 +8,11 @@ import torch from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig -from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward +from fast_llm.functional.cross_entropy import ( + cross_entropy_forward_backward, + 
forward_kl_forward_backward, + reverse_kl_forward_backward, +) from fast_llm.utils import Assert from tests.utils.utils import requires_cuda @@ -127,6 +131,41 @@ def test_reverse_kl(loss_masking, target_format): _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref, 1e-3) +def _forward_kl_forward_backward_torch(logits: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor | None): + # Manual reference: sum over vocab then average over all tokens (not just valid ones). + # Forward KL: KL(p||q) where p=teacher, q=student + logits = logits.detach().requires_grad_(True) + per_sample = torch.nn.functional.kl_div( + torch.log_softmax(logits.float(), dim=-1), + torch.log_softmax(target.float(), dim=-1), + reduction="none", + log_target=True, + ).sum(dim=-1) + if loss_mask is not None: + per_sample = per_sample * loss_mask + output = per_sample.sum() / per_sample.numel() + output.backward() + return output, logits.grad + + +@requires_cuda +@pytest.mark.slow +# TODO: Support the same parameterization as above in the reference implementation. +@pytest.mark.parametrize("loss_masking", [False, True]) +@pytest.mark.parametrize("target_format", (TargetFormat.logits,)) +def test_forward_kl(loss_masking, target_format): + logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format) + out_ref, grad_ref = _forward_kl_forward_backward_torch(logits, target, loss_mask) + out, grad = forward_kl_forward_backward( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=1.0, + target_format=TargetFormat.logits, + ) + _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref, 1e-3) + + def _mp_worker(rank: int, world_size: int, init_method: str, fn_args: tuple): try: torch.distributed.init_process_group(backend="gloo", rank=rank, world_size=world_size, init_method=init_method) @@ -189,7 +228,7 @@ def _compare_parallel_cross_entropy( def compare_parallel_cross_entropy(rank: int, group: torch.distributed.ProcessGroup): success = True - for function in (reverse_kl_forward_backward, cross_entropy_forward_backward): + for function in (reverse_kl_forward_backward, forward_kl_forward_backward, cross_entropy_forward_backward): for target_format in (TargetFormat.logits,): for loss_masking in [False, True]: try: From 411959616793a78f49e76b9c0767d055ba2c1971 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 19 Dec 2025 22:48:44 +0000 Subject: [PATCH 08/29] wip: report unscaled + kl loss --- fast_llm/layers/language_model/config.py | 35 ++++- fast_llm/layers/language_model/head.py | 158 +++++++++++++---------- 2 files changed, 122 insertions(+), 71 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 13c6d87eb..807b39703 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -173,16 +173,37 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Track the unscaled language modeling loss for logging purposes. Will always do if language_model_loss_factor > 0.", hint=FieldHint.feature, ) - distillation_loss_factor: float = Field( - default=1.0, - desc="Factor to scale the distillation loss by when using distillation.", + track_forward_kl_loss: bool = Field( + default=False, + desc="Track the unscaled forward KL loss for logging purposes. 
Will always do if distillation_loss_implementation is forward_kl.", + hint=FieldHint.feature, + ) + track_reverse_kl_loss: bool = Field( + default=False, + desc="Track the unscaled reverse KL loss for logging purposes. Will always do if distillation_loss_implementation is reverse_kl.", hint=FieldHint.feature, ) - track_distillation_loss: bool = Field( + track_distillation_ce_loss: bool = Field( default=False, - desc="Track the unscaled distillation loss for logging purposes. Will always do if distillation_loss_factor > 0.", + desc="Track the unscaled distillation cross-entropy loss for logging purposes. Will always do if distillation_loss_implementation is cross_entropy.", + hint=FieldHint.feature, + ) + forward_kl_loss_factor: float = Field( + default=0.0, + desc="Factor to scale the forward KL loss by when using distillation with forward KL.", hint=FieldHint.feature, ) + reverse_kl_loss_factor: float = Field( + default=1.0, + desc="Factor to scale the reverse KL loss by when using distillation with reverse KL.", + hint=FieldHint.feature, + ) + distillation_ce_loss_factor: float = Field( + default=0.0, + desc="Factor to scale the distillation cross-entropy loss by when using distillation with cross-entropy.", + hint=FieldHint.feature, + ) + logits_scale_factor: float = Field( default=1.0, desc="Multiply output logits by scale factor.", @@ -254,7 +275,9 @@ def _validate(self) -> None: self.language_model_loss_factor = 0.0 super()._validate() if self.distillation_model is None: - Assert.is_(self.track_distillation_loss, False) + Assert.is_(self.track_forward_kl_loss, False) + Assert.is_(self.track_reverse_kl_loss, False) + Assert.is_(self.track_distillation_ce_loss, False) assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both @property diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index b8a8f0cbb..040dc55dc 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -13,7 +13,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward -from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl, TargetFormat, TritonConfig +from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig from fast_llm.functional.cross_entropy import ( cross_entropy_forward_backward, forward_kl_forward_backward, @@ -119,8 +119,18 @@ def __init__( self._compute_lm_loss = self.config.language_model_loss_factor > 0.0 or self.config.track_language_model_loss self._compute_dpo_loss = self._config.enable_dpo - self._compute_distillation_loss = self._config.distillation_model is not None and ( - self._config.distillation_loss_factor > 0.0 or self._config.track_distillation_loss + self._compute_rkl_loss = self._config.distillation_model is not None and ( + self._config.reverse_kl_loss_factor > 0.0 or self._config.track_reverse_kl_loss + ) + self._compute_kl_loss = self._config.distillation_model is not None and ( + self._config.forward_kl_loss_factor > 0.0 or self._config.track_forward_kl_loss + ) + self._compute_dist_ce_loss = self._config.distillation_model is not None and ( + self._config.distillation_ce_loss_factor > 0.0 or self._config.track_distillation_ce_loss + ) + + self._compute_distillation_loss = any( + [self._compute_rkl_loss, self._compute_kl_loss, 
self._compute_dist_ce_loss] ) def forward( @@ -378,13 +388,16 @@ def _logits_loss_forward_backward( else: lm_loss, lm_grad = None, None + distillation_rkl_grad, distillation_kl_grad, distillation_ce_grad = None, None, None + distillation_rkl_loss, distillation_kl_loss, distillation_ce_loss = None, None, None + if distillation_target is not None and self._compute_distillation_loss: - if self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: - distillation_loss, distillation_grad = reverse_kl_forward_backward( + if self._compute_rkl_loss: + distillation_rkl_loss, distillation_rkl_grad = reverse_kl_forward_backward( logits.flatten(0, -2), distillation_target, loss_mask, - grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, + grad_output=grad_output * self._loss_coefficient * self._config.reverse_kl_loss_factor, group=group, logits_scale_factor=self._config.logits_scale_factor, teacher_softmax_temperature=self._config.teacher_softmax_temperature, @@ -394,12 +407,12 @@ def _logits_loss_forward_backward( sequence_parallel_logits=self._sequence_parallel_logits, ) - elif self._config.distillation_loss_implementation == DistillationLossImpl.forward_kl: - distillation_loss, distillation_grad = forward_kl_forward_backward( + if self._compute_kl_loss: + distillation_kl_loss, distillation_kl_grad = forward_kl_forward_backward( logits.flatten(0, -2), distillation_target, loss_mask, - grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, + grad_output=grad_output * self._loss_coefficient * self._config.forward_kl_loss_factor, group=group, logits_scale_factor=self._config.logits_scale_factor, teacher_softmax_temperature=self._config.teacher_softmax_temperature, @@ -409,13 +422,13 @@ def _logits_loss_forward_backward( sequence_parallel_logits=self._sequence_parallel_logits, ) - elif self._config.distillation_loss_implementation == DistillationLossImpl.cross_entropy: - distillation_loss, distillation_grad = cross_entropy_forward_backward( + if self._compute_dist_ce_loss: + distillation_ce_loss, distillation_ce_grad = cross_entropy_forward_backward( logits.flatten(0, -2), distillation_target, loss_mask, group=group, - grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, + grad_output=grad_output * self._loss_coefficient * self._config.distillation_ce_loss_factor, implementation=self._cross_entropy_impl, logits_scale_factor=self._config.logits_scale_factor, target_format=TargetFormat.logits, @@ -424,8 +437,6 @@ def _logits_loss_forward_backward( raise ValueError( f"Invalid distillation loss implementation: {self._config.distillation_loss_implementation}" ) - else: - distillation_loss, distillation_grad = None, None # TODO: de-allocate earlier. 
del logits @@ -434,10 +445,13 @@ def _logits_loss_forward_backward( dpo_grad, lm_loss, lm_grad, - distillation_loss, - distillation_grad, + distillation_rkl_loss, + distillation_rkl_grad, + distillation_kl_loss, + distillation_kl_grad, + distillation_ce_loss, + distillation_ce_grad, losses, - loss_mask, kwargs, ) @@ -449,10 +463,13 @@ def _post_process_loss_and_grad( dpo_grad: torch.Tensor | None, lm_loss: torch.Tensor | None, lm_grad: torch.Tensor | None, - distillation_loss: torch.Tensor | None, - distillation_grad: torch.Tensor | None, + distillation_rkl_loss: torch.Tensor | None, + distillation_rkl_grad: torch.Tensor | None, + distillation_kl_loss: torch.Tensor | None, + distillation_kl_grad: torch.Tensor | None, + distillation_ce_loss: torch.Tensor | None, + distillation_ce_grad: torch.Tensor | None, losses: dict | None, - loss_mask: torch.Tensor | None, kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -463,6 +480,7 @@ def _post_process_loss_and_grad( - Grads: gradients of the losses w.r.t. logits from different components, already scaled by loss factors. """ # Extremely explicit but easier to follow. + # TODO: simplify / shrten / make seperate dataclass? ############ if dpo_loss is not None: if self.training and losses is not None: @@ -471,28 +489,38 @@ def _post_process_loss_and_grad( Assert.is_(dpo_grad, None) if lm_loss is not None: - if self.training and losses is not None: - losses[self._lm_loss_name_unscaled].append(lm_loss.detach()) - lm_loss = lm_loss * self._config.language_model_loss_factor # does not need scaling by loss_scalor_df if self.training and losses is not None: losses[self._lm_loss_name].append(lm_loss.detach()) + lm_loss = lm_loss * self._config.language_model_loss_factor else: Assert.is_(lm_grad, None) - if distillation_loss is not None: - distillation_loss = distillation_loss + if distillation_rkl_loss is not None: + distillation_rkl_loss = distillation_rkl_loss if self.training and losses is not None: - losses[self._distillation_loss_name_unscaled].append(distillation_loss.detach()) - distillation_loss = distillation_loss * self._config.distillation_loss_factor + losses[self._distillation_rkl_loss_name].append(distillation_rkl_loss.detach()) + distillation_rkl_loss = distillation_rkl_loss * self._config.distillation_loss_factor + else: + Assert.is_(distillation_rkl_grad, None) + if distillation_kl_loss is not None: + distillation_kl_loss = distillation_kl_loss + if self.training and losses is not None: + losses[self._distillation_kl_loss_name].append(distillation_kl_loss.detach()) + distillation_kl_loss = distillation_kl_loss * self._config.distillation_loss_factor + else: + Assert.is_(distillation_kl_grad, None) + if distillation_ce_loss is not None: + distillation_ce_loss = distillation_ce_loss if self.training and losses is not None: - losses[self._distillation_loss_name].append(distillation_loss.detach()) + losses[self._distillation_ce_loss_name].append(distillation_ce_loss.detach()) + distillation_ce_loss = distillation_ce_loss * self._config.distillation_loss_factor else: - Assert.is_(distillation_grad, None) + Assert.is_(distillation_ce_grad, None) ############ # TODO: Accumulate grads in-place to reduce memory and compute overhead. 
- grad = _add_tensors(dpo_grad, lm_grad, distillation_grad) - total_loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) + grad = _add_tensors(dpo_grad, lm_grad, distillation_rkl_grad, distillation_kl_grad, distillation_ce_grad) + total_loss = _add_tensors(dpo_loss, lm_loss, distillation_rkl_loss, distillation_kl_loss, distillation_ce_loss) if losses is not None and total_loss is not None: losses[self._total_loss_name].append(total_loss.detach()) @@ -509,7 +537,7 @@ def _total_loss_name(self) -> str: return name @functools.cached_property - def _lm_loss_name_unscaled(self) -> str: + def _lm_loss_name(self) -> str: """ Unscaled language model cross-entropy loss. """ @@ -519,39 +547,36 @@ def _lm_loss_name_unscaled(self) -> str: return name @functools.cached_property - def _lm_loss_name(self) -> str: - """ - Scaled language model cross-entropy loss. - """ - name = "lm_loss" + def _z_loss_name(self) -> str: + name = "z_loss" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @functools.cached_property - def _z_loss_name(self) -> str: - name = "z_loss" + def _dpo_loss_name(self) -> str: + name = "dpo_loss" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @functools.cached_property - def _dpo_loss_name(self) -> str: - name = "dpo_loss" + def _distillation_kl_loss_name(self) -> str: + name = "distillation_kl_loss_unscaled" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @functools.cached_property - def _distillation_loss_name_unscaled(self) -> str: - name = "distillation_loss_unscaled" + def _distillation_rkl_loss_name(self) -> str: + name = "distillation_rkl_loss_unscaled" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @functools.cached_property - def _distillation_loss_name(self) -> str: - name = "distillation_loss" + def _distillation_ce_loss_name(self) -> str: + name = "distillation_ce_loss_unscaled" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @@ -568,13 +593,6 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: count=count, ) ) - loss_defs.append( - LossDef( - name=self._lm_loss_name, - formatted_name=_format_name(self._lm_loss_name), - count=count, - ) - ) if self._config.logit_z_loss: loss_defs.append( LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) @@ -585,21 +603,31 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: ) if self._compute_distillation_loss: - loss_defs.append( - LossDef( - name=self._distillation_loss_name, - formatted_name=_format_name(self._distillation_loss_name), - count=count, - ) - ) # unscaled distillation loss for comparison purposes - loss_defs.append( - LossDef( - name=self._distillation_loss_name_unscaled, - formatted_name=_format_name(self._distillation_loss_name_unscaled), - count=count, + if self._compute_kl_loss: + loss_defs.append( + LossDef( + name=self._distillation_kl_loss_name, + formatted_name=_format_name(self._distillation_kl_loss_name), + count=count, + ) + ) + if self._compute_rkl_loss: + loss_defs.append( + LossDef( + name=self._distillation_rkl_loss_name, + formatted_name=_format_name(self._distillation_rkl_loss_name), + count=count, + ) + ) + if self._compute_dist_ce_loss: + loss_defs.append( + LossDef( + name=self._distillation_ce_loss_name, + formatted_name=_format_name(self._distillation_ce_loss_name), + count=count, + ) ) - ) return loss_defs From 
b55a0a428fb85dc3ce16ec061d1bed5ea2ac619a Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 22 Dec 2025 13:42:48 +0000 Subject: [PATCH 09/29] loss config --- fast_llm/functional/cross_entropy.py | 2 + fast_llm/layers/language_model/config.py | 97 +---- fast_llm/layers/language_model/head.py | 408 +++++------------- .../layers/language_model/lm_head_losses.py | 280 ++++++++++++ 4 files changed, 405 insertions(+), 382 deletions(-) create mode 100644 fast_llm/layers/language_model/lm_head_losses.py diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 5a618eea0..f534d8a78 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -314,6 +314,7 @@ def reverse_kl_forward_backward( teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, sequence_parallel_logits: bool = False, + **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher). @@ -443,6 +444,7 @@ def forward_kl_forward_backward( teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, sequence_parallel_logits: bool = False, + **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute forward KL divergence: KL(p||q) where p is the target distribution (teacher) and q is the predicted (student). diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 807b39703..6fc92eaa4 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -5,11 +5,11 @@ from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig +from fast_llm.layers.language_model.lm_head_losses import LossConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -135,75 +135,22 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Configuration for the final normalization layer.", hint=FieldHint.architecture, ) + losses: dict[str, LossConfig] = Field( + default_factory=dict, + desc="A dictionary of loss names and their configurations.", + hint=FieldHint.core, + ) # TODO: Cleanup output_weight: ParameterConfig = Field( desc="Configuration for the LM output layer (weight). 
Ignored for tied embeddings", hint=FieldHint.architecture, ) - cross_entropy_implementation: CrossEntropyImpl = Field( - default=CrossEntropyImpl.auto, - desc="Implementation for the cross-entropy computation.", - hint=FieldHint.performance, - ) - distillation_loss_implementation: DistillationLossImpl = Field( - default=DistillationLossImpl.cross_entropy, - desc="Implementation for the distillation cross-entropy computation.", - hint=FieldHint.performance, - ) cross_entropy_splits: int | None = Field( default=None, desc="Split the logit and cross-entropy computation into this many fragment, to reduce memory usage.", hint=FieldHint.feature, valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) - logit_z_loss: float = Field( - default=0.0, - desc="Regularize the logits with Z-loss.", - doc="We recommend 1e-4 for stability, as used for training PaLM.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - language_model_loss_factor: float = Field( - default=None, - desc="Factor to scale the language modeling loss by when using distillation.", - hint=FieldHint.feature, - ) - track_language_model_loss: bool = Field( - default=False, - desc="Track the unscaled language modeling loss for logging purposes. Will always do if language_model_loss_factor > 0.", - hint=FieldHint.feature, - ) - track_forward_kl_loss: bool = Field( - default=False, - desc="Track the unscaled forward KL loss for logging purposes. Will always do if distillation_loss_implementation is forward_kl.", - hint=FieldHint.feature, - ) - track_reverse_kl_loss: bool = Field( - default=False, - desc="Track the unscaled reverse KL loss for logging purposes. Will always do if distillation_loss_implementation is reverse_kl.", - hint=FieldHint.feature, - ) - track_distillation_ce_loss: bool = Field( - default=False, - desc="Track the unscaled distillation cross-entropy loss for logging purposes. 
Will always do if distillation_loss_implementation is cross_entropy.", - hint=FieldHint.feature, - ) - forward_kl_loss_factor: float = Field( - default=0.0, - desc="Factor to scale the forward KL loss by when using distillation with forward KL.", - hint=FieldHint.feature, - ) - reverse_kl_loss_factor: float = Field( - default=1.0, - desc="Factor to scale the reverse KL loss by when using distillation with reverse KL.", - hint=FieldHint.feature, - ) - distillation_ce_loss_factor: float = Field( - default=0.0, - desc="Factor to scale the distillation cross-entropy loss by when using distillation with cross-entropy.", - hint=FieldHint.feature, - ) - logits_scale_factor: float = Field( default=1.0, desc="Multiply output logits by scale factor.", @@ -212,10 +159,10 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - teacher_softmax_temperature: float = Field( - default=1.0, - desc="Divides distillation target logits by this factor.", - doc="Divides distillation target logits by this factor.", + logit_z_loss: float = Field( + default=0.0, + desc="Regularize the logits with Z-loss.", + doc="We recommend 1e-4 for stability, as used for training PaLM.", hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) @@ -224,11 +171,6 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Name of the reference model to use for dpo.", hint=FieldHint.feature, ) - dpo_beta: float | None = Field( - default=1.0, - desc="Beta value for DPO loss.", - hint=FieldHint.feature, - ) distillation_model: str | None = Field( default=None, desc="Name of the reference model to use for knowledge distillation." @@ -268,16 +210,17 @@ def layer_class(self) -> "type[LanguageModelHead]": def _validate(self) -> None: with self._set_implicit_default(): - if self.language_model_loss_factor is None: - if self.distillation_model is None: - self.language_model_loss_factor = 1.0 - else: - self.language_model_loss_factor = 0.0 + if not self.losses: + self.losses = { + "lm_loss": LossConfig._from_dict( + {"type": "cross_entropy_lm_loss", "weight_scalor": 1.0, "log_it": True} + ) + } + + for loss_config in self.losses.values(): + if "dist" in loss_config.type: + assert self.distillation_model is not None, "Distillation loss requires a distillation model." 
super()._validate() - if self.distillation_model is None: - Assert.is_(self.track_forward_kl_loss, False) - Assert.is_(self.track_reverse_kl_loss, False) - Assert.is_(self.track_distillation_ce_loss, False) assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both @property diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 040dc55dc..f23bb6f1c 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -13,13 +13,6 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward -from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig -from fast_llm.functional.cross_entropy import ( - cross_entropy_forward_backward, - forward_kl_forward_backward, - reverse_kl_forward_backward, -) -from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockDimNames @@ -31,6 +24,7 @@ LanguageModelHeadConfig, LanguageModelKwargs, ) +from fast_llm.layers.language_model.lm_head_losses import Targets, _format_name from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert, div, get_unique @@ -91,16 +85,6 @@ def __init__( if self._config.cross_entropy_splits is not None and self._sequence_parallel: assert not self._vocab_parallel - if not self._config.enable_dpo: - self._cross_entropy_impl = self._config.cross_entropy_implementation - if self._cross_entropy_impl == CrossEntropyImpl.auto: - if self._vocab_parallel: - self._cross_entropy_impl = CrossEntropyImpl.fused - elif TritonConfig.TRITON_ENABLED: - self._cross_entropy_impl = CrossEntropyImpl.triton - else: - self._cross_entropy_impl = CrossEntropyImpl.fused - self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) self.final_norm = self._config.normalization.get_layer( @@ -116,22 +100,10 @@ def __init__( lr_scale=self._lr_scale, peft=self._peft, ) - - self._compute_lm_loss = self.config.language_model_loss_factor > 0.0 or self.config.track_language_model_loss - self._compute_dpo_loss = self._config.enable_dpo - self._compute_rkl_loss = self._config.distillation_model is not None and ( - self._config.reverse_kl_loss_factor > 0.0 or self._config.track_reverse_kl_loss - ) - self._compute_kl_loss = self._config.distillation_model is not None and ( - self._config.forward_kl_loss_factor > 0.0 or self._config.track_forward_kl_loss - ) - self._compute_dist_ce_loss = self._config.distillation_model is not None and ( - self._config.distillation_ce_loss_factor > 0.0 or self._config.track_distillation_ce_loss - ) - - self._compute_distillation_loss = any( - [self._compute_rkl_loss, self._compute_kl_loss, self._compute_dist_ce_loss] - ) + self._formatted_loss_names = { + loss_name: loss_config.get_formatted_name(loss_name, self._prediction_distance) + for loss_name, loss_config in self._config.losses.items() + } def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None @@ -203,22 +175,25 @@ def _forward_backward( else: return loss, None - def _get_targets( - self, kwargs: dict - ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, 
torch.Tensor | None] | None: - # Loss mask for distillation. (Labels are already masked.) + def _get_targets(self, kwargs: dict) -> Targets | None: + ( + lm_target, + dpo_target, + reference_model_logits, + loss_mask, + chosen_spans, + rejected_spans, + dpo_reference_model_logits, + ) = (None, None, None, None, None, None, None) if self._config.enable_dpo: dpo_target = kwargs.get(LanguageModelKwargs.labels) - lm_target = None - distillation_target = None - loss_mask = None + chosen_spans = kwargs.get(LanguageModelKwargs.chosen_spans) + rejected_spans = kwargs.get(LanguageModelKwargs.rejected_spans) + dpo_reference_model_logits = (kwargs.get(f"{self._config.dpo_reference_model}_logits"),) else: - dpo_target = None - if self._config.distillation_model is None: - distillation_target, loss_mask = None, None - else: + if self._config.distillation_model is not None: # Target is reference model logits. - distillation_target = kwargs[f"{self._config.distillation_model}_logits"].flatten(0, -2) + reference_model_logits = kwargs[f"{self._config.distillation_model}_logits"].flatten(0, -2) loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) if loss_mask is not None: loss_mask = loss_mask.flatten() @@ -240,12 +215,29 @@ def _get_targets( else lm_target[:, lm_target_slice] ).flatten() - targets = (dpo_target, lm_target, distillation_target, loss_mask) if self._sequence_parallel_logits: - targets = [None if target is None else split_op(target, self._parallel_dim.group, 0) for target in targets] - if not any(target is not None for target in targets): - # Simplify so we don't have to check every time. - targets = None + if dpo_target is not None: + dpo_target = split_op(dpo_target, self._parallel_dim.group, 0) + if lm_target is not None: + lm_target = split_op(lm_target, self._parallel_dim.group, 0) + if loss_mask is not None: + loss_mask = split_op(loss_mask, self._parallel_dim.group, 0) + if reference_model_logits is not None: + reference_model_logits = split_op(reference_model_logits, self._parallel_dim.group, 0) + + targets = Targets( + dpo_target=dpo_target, + lm_target=lm_target, + loss_mask=loss_mask, + chosen_spans=chosen_spans, + rejected_spans=rejected_spans, + reference_model_logits=reference_model_logits, + dpo_reference_model_logits=dpo_reference_model_logits, + ) + + # Return None if no targets are set + if not targets.has_any_target(): + return None return targets def get_output_weights(self) -> list[torch.Tensor]: @@ -254,7 +246,7 @@ def get_output_weights(self) -> list[torch.Tensor]: def _logits_cross_entropy_forward_backward_split( self, input_: torch.Tensor, - targets: tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | None, + targets: Targets | None, weight: torch.Tensor, grad_output: float, kwargs: dict, @@ -285,15 +277,34 @@ def _logits_cross_entropy_forward_backward_split( logit_input_grad = torch.empty_like(logit_input) else: logit_input_grad = None + + # Extract target tensors for splitting (keep same order as original tuple) + target_tensors = [ + targets.lm_target, + targets.dpo_target, + targets.reference_model_logits, + targets.loss_mask, + ] split_size = div( - get_unique(target.size(0) for target in targets if target is not None), + get_unique(target.size(0) for target in target_tensors if target is not None), self._config.cross_entropy_splits, ) tensors_split = [ [None] * self._config.cross_entropy_splits if tensor is None else tensor.split(split_size) - for tensor in [logit_input, *targets, logit_input_grad] + for tensor in 
[logit_input, *target_tensors, logit_input_grad] ] - for logit_input_, *targets_, logit_input_grad_ in zip(*tensors_split, strict=True): + for logit_input_, lm_target_, dpo_target_, reference_model_logits_, loss_mask_, logit_input_grad_ in zip( + *tensors_split, strict=True + ): + targets_ = Targets( + lm_target=lm_target_, + dpo_target=dpo_target_, + reference_model_logits=reference_model_logits_, + loss_mask=loss_mask_, + chosen_spans=targets.chosen_spans, + rejected_spans=targets.rejected_spans, + dpo_reference_model_logits=targets.dpo_reference_model_logits, + ) loss_, grad_ = self._logits_loss_forward_backward( logit_input_, targets_, @@ -319,7 +330,7 @@ def _logits_cross_entropy_forward_backward_split( def _logits_loss_forward_backward( self, input_: torch.Tensor, - targets: tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None], + targets: Targets | None, weight: torch.Tensor, grad_output: float, kwargs: dict, @@ -334,6 +345,7 @@ def _logits_loss_forward_backward( sequence_parallel=self._sequence_parallel and self._vocab_parallel, ) + # TODO: also move to lm_head_losses? if self._config.logit_z_loss > 0.0: logits = z_loss( logits, @@ -359,175 +371,48 @@ def _logits_loss_forward_backward( if targets is None: return logits * self._config.logits_scale_factor, None - dpo_target, lm_target, distillation_target, loss_mask = targets - if dpo_target is not None: - dpo_loss, dpo_grad = compute_dpo_loss( + total_loss, grad = None, None + for loss_name, loss_config in self._config.losses.items(): + if loss_config.weight_scalor == 0.0 and not loss_config.log_it: + continue + # losses are returned unscaled but the grads are already scaled + # we log unscaled losses seperately and the scaled total loss + loss_unscaled_, grad_ = loss_config.compute_loss( logits, - dpo_target, - kwargs.get(f"{self._config.dpo_reference_model}_logits"), - kwargs[LanguageModelKwargs.chosen_spans], - kwargs[LanguageModelKwargs.rejected_spans], - self._config.dpo_beta, - grad_output * self._loss_coefficient, - ) - else: - dpo_loss, dpo_grad = None, None - - if lm_target is not None and self._compute_lm_loss: - lm_loss, lm_grad = cross_entropy_forward_backward( - logits.flatten(0, -2), - lm_target, - None, + targets, + grad_output=( + grad_output * self._loss_coefficient * loss_config.weight_scalor + if grad_output is not None + else None + ), group=group, - grad_output=grad_output * self._loss_coefficient * self._config.language_model_loss_factor, - implementation=self._cross_entropy_impl, logits_scale_factor=self._config.logits_scale_factor, - target_format=TargetFormat.labels, + vocab_parallel=self._vocab_parallel, ) - else: - lm_loss, lm_grad = None, None - - distillation_rkl_grad, distillation_kl_grad, distillation_ce_grad = None, None, None - distillation_rkl_loss, distillation_kl_loss, distillation_ce_loss = None, None, None - - if distillation_target is not None and self._compute_distillation_loss: - if self._compute_rkl_loss: - distillation_rkl_loss, distillation_rkl_grad = reverse_kl_forward_backward( - logits.flatten(0, -2), - distillation_target, - loss_mask, - grad_output=grad_output * self._loss_coefficient * self._config.reverse_kl_loss_factor, - group=group, - logits_scale_factor=self._config.logits_scale_factor, - teacher_softmax_temperature=self._config.teacher_softmax_temperature, - target_format=( - TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits - ), - sequence_parallel_logits=self._sequence_parallel_logits, - ) + loss_ = 
loss_unscaled_ * loss_config.weight_scalor * self._loss_coefficient - if self._compute_kl_loss: - distillation_kl_loss, distillation_kl_grad = forward_kl_forward_backward( - logits.flatten(0, -2), - distillation_target, - loss_mask, - grad_output=grad_output * self._loss_coefficient * self._config.forward_kl_loss_factor, - group=group, - logits_scale_factor=self._config.logits_scale_factor, - teacher_softmax_temperature=self._config.teacher_softmax_temperature, - target_format=( - TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits - ), - sequence_parallel_logits=self._sequence_parallel_logits, - ) + if losses is not None and loss_config.log_it: + losses[self._formatted_loss_names[loss_name]].append(loss_unscaled_.detach()) - if self._compute_dist_ce_loss: - distillation_ce_loss, distillation_ce_grad = cross_entropy_forward_backward( - logits.flatten(0, -2), - distillation_target, - loss_mask, - group=group, - grad_output=grad_output * self._loss_coefficient * self._config.distillation_ce_loss_factor, - implementation=self._cross_entropy_impl, - logits_scale_factor=self._config.logits_scale_factor, - target_format=TargetFormat.logits, - ) + if total_loss is None: + total_loss = loss_ else: - raise ValueError( - f"Invalid distillation loss implementation: {self._config.distillation_loss_implementation}" - ) - - # TODO: de-allocate earlier. - del logits - loss, grad = self._post_process_loss_and_grad( - dpo_loss, - dpo_grad, - lm_loss, - lm_grad, - distillation_rkl_loss, - distillation_rkl_grad, - distillation_kl_loss, - distillation_kl_grad, - distillation_ce_loss, - distillation_ce_grad, - losses, - kwargs, - ) - - return loss, output_parallel_linear_backward(grad, context) if self.training else None - - def _post_process_loss_and_grad( - self, - dpo_loss: torch.Tensor | None, - dpo_grad: torch.Tensor | None, - lm_loss: torch.Tensor | None, - lm_grad: torch.Tensor | None, - distillation_rkl_loss: torch.Tensor | None, - distillation_rkl_grad: torch.Tensor | None, - distillation_kl_loss: torch.Tensor | None, - distillation_kl_grad: torch.Tensor | None, - distillation_ce_loss: torch.Tensor | None, - distillation_ce_grad: torch.Tensor | None, - losses: dict | None, - kwargs, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - If loss is provided (i.e. not None) it will be logged in scaled and unscaled version. The total loss is also logged. - - Arguments: - - Losses: unscaled losses from different components (DPO, LM CE, Distillation) - - Grads: gradients of the losses w.r.t. logits from different components, already scaled by loss factors. - """ - # Extremely explicit but easier to follow. - # TODO: simplify / shrten / make seperate dataclass? 
- ############ - if dpo_loss is not None: - if self.training and losses is not None: - losses[self._dpo_loss_name].append(dpo_loss.detach()) - else: - Assert.is_(dpo_grad, None) + total_loss = total_loss + loss_ - if lm_loss is not None: - if self.training and losses is not None: - losses[self._lm_loss_name].append(lm_loss.detach()) - lm_loss = lm_loss * self._config.language_model_loss_factor - else: - Assert.is_(lm_grad, None) - - if distillation_rkl_loss is not None: - distillation_rkl_loss = distillation_rkl_loss - if self.training and losses is not None: - losses[self._distillation_rkl_loss_name].append(distillation_rkl_loss.detach()) - distillation_rkl_loss = distillation_rkl_loss * self._config.distillation_loss_factor - else: - Assert.is_(distillation_rkl_grad, None) - if distillation_kl_loss is not None: - distillation_kl_loss = distillation_kl_loss - if self.training and losses is not None: - losses[self._distillation_kl_loss_name].append(distillation_kl_loss.detach()) - distillation_kl_loss = distillation_kl_loss * self._config.distillation_loss_factor - else: - Assert.is_(distillation_kl_grad, None) - if distillation_ce_loss is not None: - distillation_ce_loss = distillation_ce_loss - if self.training and losses is not None: - losses[self._distillation_ce_loss_name].append(distillation_ce_loss.detach()) - distillation_ce_loss = distillation_ce_loss * self._config.distillation_loss_factor - else: - Assert.is_(distillation_ce_grad, None) + if grad_ is not None: + if grad is None: + grad = grad_ + else: + grad = grad + grad_ - ############ - # TODO: Accumulate grads in-place to reduce memory and compute overhead. - grad = _add_tensors(dpo_grad, lm_grad, distillation_rkl_grad, distillation_kl_grad, distillation_ce_grad) - total_loss = _add_tensors(dpo_loss, lm_loss, distillation_rkl_loss, distillation_kl_loss, distillation_ce_loss) if losses is not None and total_loss is not None: - losses[self._total_loss_name].append(total_loss.detach()) + losses[self._total_head_loss_name].append(total_loss.detach()) - return total_loss, grad + return total_loss, output_parallel_linear_backward(grad, context) if self.training else None @functools.cached_property - def _total_loss_name(self) -> str: + def _total_head_loss_name(self) -> str: """ Combined total scaled loss used for training. """ @@ -536,16 +421,6 @@ def _total_loss_name(self) -> str: name = f"{name}_{self._prediction_distance}" return name - @functools.cached_property - def _lm_loss_name(self) -> str: - """ - Unscaled language model cross-entropy loss. 
- """ - name = "lm_loss_unscaled" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - @functools.cached_property def _z_loss_name(self) -> str: name = "z_loss" @@ -553,81 +428,18 @@ def _z_loss_name(self) -> str: name = f"{name}_{self._prediction_distance}" return name - @functools.cached_property - def _dpo_loss_name(self) -> str: - name = "dpo_loss" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - - @functools.cached_property - def _distillation_kl_loss_name(self) -> str: - name = "distillation_kl_loss_unscaled" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - - @functools.cached_property - def _distillation_rkl_loss_name(self) -> str: - name = "distillation_rkl_loss_unscaled" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - - @functools.cached_property - def _distillation_ce_loss_name(self) -> str: - name = "distillation_ce_loss_unscaled" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: loss_defs = [ - LossDef(name=self._total_loss_name, formatted_name=_format_name(self._total_loss_name), count=count) - ] - if self._compute_lm_loss: - loss_defs.append( - LossDef( - name=self._lm_loss_name_unscaled, - formatted_name=_format_name(self._lm_loss_name_unscaled), - count=count, - ) + LossDef( + name=self._total_head_loss_name, formatted_name=_format_name(self._total_head_loss_name), count=count ) - if self._config.logit_z_loss: - loss_defs.append( - LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) - ) - if self._compute_dpo_loss: - loss_defs.append( - LossDef(name=self._dpo_loss_name, formatted_name=_format_name(self._dpo_loss_name), count=count) - ) - - if self._compute_distillation_loss: - # unscaled distillation loss for comparison purposes - if self._compute_kl_loss: - loss_defs.append( - LossDef( - name=self._distillation_kl_loss_name, - formatted_name=_format_name(self._distillation_kl_loss_name), - count=count, - ) - ) - if self._compute_rkl_loss: - loss_defs.append( - LossDef( - name=self._distillation_rkl_loss_name, - formatted_name=_format_name(self._distillation_rkl_loss_name), - count=count, - ) - ) - if self._compute_dist_ce_loss: - loss_defs.append( - LossDef( - name=self._distillation_ce_loss_name, - formatted_name=_format_name(self._distillation_ce_loss_name), - count=count, - ) + ] + for loss_name, loss_config in self._config.losses.items(): + if loss_config.log_it: + loss_def: LossDef = loss_config.get_loss_def( + name=loss_name, count=count, prediction_distance=self._prediction_distance ) + loss_defs.append(loss_def) return loss_defs @@ -635,17 +447,3 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: def heads(self): # For compatibility with MTP. 
return [self] - - -def _format_name(name: str) -> str: - return name.replace("_", " ") - - -def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor: - tensors = [tensor for tensor in tensors if tensor is not None] - if len(tensors) > 1: - return sum(tensors) - elif len(tensors) == 1: - return tensors[0] - else: - raise RuntimeError("No tensors to add.") diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py new file mode 100644 index 000000000..cc8e5ebc5 --- /dev/null +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -0,0 +1,280 @@ +import abc +import dataclasses +import logging +import typing + +import torch + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.core.distributed import ProcessGroup +from fast_llm.engine.base_model.config import LossDef +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + +# +# CE loss on lm_targets for standard LM training. Here targets are already masked. +# CE loss for distillation: cross entropy that uses reference_model_logits as soft targets, not implemented, TODO. +# Forward KL divergence loss on reference_model_logits for distillation (mode-covering). +# Reverse KL divergence loss on reference_model_logits for distillation (mode-seeking). +# DPO loss for alignment using chosen and rejected spans. +# + + +def _format_name(name: str) -> str: + return name.replace("_", " ") + + +@dataclasses.dataclass +class Targets: + lm_target: torch.Tensor | None = None + dpo_target: torch.Tensor | None = None + loss_mask: torch.Tensor | None = None + chosen_spans: list[list[tuple[int, int]]] | None = None + rejected_spans: list[list[tuple[int, int]]] | None = None + reference_model_logits: torch.Tensor | None = None + dpo_reference_model_logits: torch.Tensor | None = None + + def has_any_target(self) -> bool: + return any(getattr(self, field.name) is not None for field in dataclasses.fields(self)) + + +@config_class(registry=True) +class LossConfig(Config): + """ + Losses can register themselves + using @config_class(dynamic_type={LossConfig: "loss_type_name"}) + """ + + _name: typing.ClassVar[str] + _abstract: typing.ClassVar[bool] = True + + weight_scalor: float = Field( + default=1.0, + hint=FieldHint.core, + desc="Weight for this loss in the total loss computation.", + valid=check_field(Assert.geq, 0.0), + ) + + log_it: bool = Field( + default=True, + hint=FieldHint.optional, + desc="Whether to log this loss.", + ) + + @abc.abstractmethod + def compute_loss( + self, + logits: torch.Tensor, + target: Targets, + grad_output: float | None = None, + group: ProcessGroup | None = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + pass + + def get_loss_def(self, name: str, count: int = 1, prediction_distance: int | None = None) -> LossDef: + name = self.get_formatted_name(name, prediction_distance) + return LossDef( + name=name, + formatted_name=_format_name(name), + count=count, + dtype=DataType.float32, + ) + + def _validate(self): + Assert.geq(self.weight_scalor, 0.0) + if self.weight_scalor > 0.0: + with self._set_implicit_default(): + if "log_it" not in self._explicit_fields: + self.log_it = True + super()._validate() + + def get_formatted_name(self, 
name=None, prediction_distance: int | None = None) -> str: + name = f"{self._name}({name})" + if prediction_distance is not None: + name = f"{name}_{prediction_distance}" + return name + + +@config_class(dynamic_type={LossConfig: "cross_entropy_lm_loss"}) +class CrossEntropyLMLossConfig(LossConfig): + _name: typing.ClassVar[str] = "CE" + _abstract: typing.ClassVar[bool] = False + + implementation: CrossEntropyImpl = Field( + default=CrossEntropyImpl.auto, + desc="Implementation for the cross-entropy computation.", + hint=FieldHint.performance, + ) + + teacher_softmax_temperature: float = Field( + default=1.0, + hint=FieldHint.optional, + desc="Temperature for teacher softmax (used in distillation losses).", + valid=check_field(Assert.gt, 0.0), + ) + + def compute_loss( + self, + logits: torch.Tensor, + targets: Targets, + grad_output: float | None = None, + group: ProcessGroup | None = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + from fast_llm.functional.cross_entropy import cross_entropy_forward_backward + + target = targets.lm_target + if target is None: + raise ValueError("CrossEntropyLoss requires lm_target to be set in Targets") + implementation = self.implementation + if implementation == CrossEntropyImpl.auto: + if vocab_parallel: + implementation = CrossEntropyImpl.fused + elif TritonConfig.TRITON_ENABLED: + implementation = CrossEntropyImpl.triton + else: + implementation = CrossEntropyImpl.fused + + return cross_entropy_forward_backward( + logits=logits.flatten(0, -2), + target=target, + loss_mask=None, # Labels are already masked + grad_output=grad_output, + group=group, + implementation=implementation, + logits_scale_factor=logits_scale_factor, + teacher_softmax_temperature=self.teacher_softmax_temperature, + target_format=TargetFormat.labels, + **kwargs, + ) + + +@config_class(dynamic_type={LossConfig: "fkl_dist"}) +class ForwardKLLossConfig(LossConfig): + """Forward KL divergence KL(p||q) for distillation (mode-covering).""" + + _name: typing.ClassVar[str] = "FwdKL" + _abstract: typing.ClassVar[bool] = False + + teacher_softmax_temperature: float = Field( + default=1.0, + hint=FieldHint.optional, + desc="Temperature for teacher softmax.", + valid=check_field(Assert.gt, 0.0), + ) + + def compute_loss( + self, + logits: torch.Tensor, + targets: Targets, + grad_output: float | None = None, + group: ProcessGroup | None = None, + logits_scale_factor: float | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + from fast_llm.functional.cross_entropy import forward_kl_forward_backward + + target = targets.reference_model_logits + if target is None: + raise ValueError("ForwardKLLoss requires distillation_target to be set in Targets") + + return forward_kl_forward_backward( + logits=logits.flatten(0, -2), + target=target, + loss_mask=targets.loss_mask, + grad_output=grad_output, + group=group, + logits_scale_factor=logits_scale_factor, + teacher_softmax_temperature=self.teacher_softmax_temperature, + target_format=TargetFormat.logits, + **kwargs, + ) + + +@config_class(dynamic_type={LossConfig: "revkl_dist"}) +class ReverseKLLossConfig(LossConfig): + """Reverse KL divergence KL(q||p) for distillation (mode-seeking).""" + + _name: typing.ClassVar[str] = "RevKL" + _abstract: typing.ClassVar[bool] = False + + teacher_softmax_temperature: float = Field( + default=1.0, + hint=FieldHint.optional, + desc="Temperature for teacher softmax.", + valid=check_field(Assert.gt, 
0.0), + ) + + def compute_loss( + self, + logits: torch.Tensor, + targets: Targets, + grad_output: float | None = None, + group: ProcessGroup | None = None, + logits_scale_factor: float | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + from fast_llm.functional.cross_entropy import reverse_kl_forward_backward + + # Use distillation_target for KL losses + target = targets.reference_model_logits + if target is None: + raise ValueError("ReverseKLLoss requires distillation_target to be set in Targets") + + return reverse_kl_forward_backward( + logits=logits.flatten(0, -2), + target=target, + loss_mask=targets.loss_mask, + grad_output=grad_output, + group=group, + logits_scale_factor=logits_scale_factor, + teacher_softmax_temperature=self.teacher_softmax_temperature, + target_format=TargetFormat.logits, + **kwargs, + ) + + +@config_class(dynamic_type={LossConfig: "dpo"}) +class DPOLossConfig(LossConfig): + """Direct Preference Optimization (DPO) loss for alignment.""" + + _name: typing.ClassVar[str] = "DPO" + _abstract: typing.ClassVar[bool] = False + + beta: float = Field( + default=1.0, + hint=FieldHint.core, + desc="Beta parameter for DPO loss (controls strength of preference optimization).", + valid=check_field(Assert.gt, 0.0), + ) + + def compute_loss( + self, + logits: torch.Tensor, + targets: Targets, + grad_output: float | None = None, + group: ProcessGroup | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + from fast_llm.functional.dpo import compute_dpo_loss + + return compute_dpo_loss( + logits=logits, + targets=targets.dpo_target, + reference_model_logits=targets.dpo_reference_model_logits, + chosen_spans=targets.chosen_spans, + rejected_spans=targets.rejected_spans, + beta=self.beta, + grad_output=grad_output, + ) From 097baeb4c2396575066f96ced831771e0054ea76 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 22 Dec 2025 14:24:57 +0000 Subject: [PATCH 10/29] wip --- fast_llm/functional/config.py | 6 - fast_llm/layers/language_model/config.py | 4 +- fast_llm/layers/language_model/head.py | 8 +- .../layers/language_model/lm_head_losses.py | 6 +- tests/layers/test_lm_head.py | 188 +++++++++--------- tests/utils/model_configs.py | 8 +- 6 files changed, 108 insertions(+), 112 deletions(-) diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 20ed99fde..511c2d9f3 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -100,12 +100,6 @@ class CrossEntropyImpl(str, enum.Enum): triton = "triton" -class DistillationLossImpl(str, enum.Enum): - reverse_kl = "reverse_kl" - forward_kl = "forward_kl" - cross_entropy = "cross_entropy" - - class TargetFormat(enum.StrEnum): labels = "labels" logits = "logits" diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 6fc92eaa4..786d312d8 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -212,9 +212,7 @@ def _validate(self) -> None: with self._set_implicit_default(): if not self.losses: self.losses = { - "lm_loss": LossConfig._from_dict( - {"type": "cross_entropy_lm_loss", "weight_scalor": 1.0, "log_it": True} - ) + "lm_loss": LossConfig._from_dict({"type": "cross_entropy_lm_loss", "factor": 1.0, "log_it": True}) } for loss_config in self.losses.values(): diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index f23bb6f1c..c8c3be797 100644 --- a/fast_llm/layers/language_model/head.py +++ 
b/fast_llm/layers/language_model/head.py @@ -374,7 +374,7 @@ def _logits_loss_forward_backward( total_loss, grad = None, None for loss_name, loss_config in self._config.losses.items(): - if loss_config.weight_scalor == 0.0 and not loss_config.log_it: + if loss_config.factor == 0.0 and not loss_config.log_it: continue # losses are returned unscaled but the grads are already scaled # we log unscaled losses seperately and the scaled total loss @@ -382,15 +382,13 @@ def _logits_loss_forward_backward( logits, targets, grad_output=( - grad_output * self._loss_coefficient * loss_config.weight_scalor - if grad_output is not None - else None + grad_output * self._loss_coefficient * loss_config.factor if grad_output is not None else None ), group=group, logits_scale_factor=self._config.logits_scale_factor, vocab_parallel=self._vocab_parallel, ) - loss_ = loss_unscaled_ * loss_config.weight_scalor * self._loss_coefficient + loss_ = loss_unscaled_ * loss_config.factor * self._loss_coefficient if losses is not None and loss_config.log_it: losses[self._formatted_loss_names[loss_name]].append(loss_unscaled_.detach()) diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index cc8e5ebc5..a231efa5a 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -54,7 +54,7 @@ class LossConfig(Config): _name: typing.ClassVar[str] _abstract: typing.ClassVar[bool] = True - weight_scalor: float = Field( + factor: float = Field( default=1.0, hint=FieldHint.core, desc="Weight for this loss in the total loss computation.", @@ -90,8 +90,8 @@ def get_loss_def(self, name: str, count: int = 1, prediction_distance: int | Non ) def _validate(self): - Assert.geq(self.weight_scalor, 0.0) - if self.weight_scalor > 0.0: + Assert.geq(self.factor, 0.0) + if self.factor > 0.0: with self._set_implicit_default(): if "log_it" not in self._explicit_fields: self.log_it = True diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index c6d806db8..917bb7efd 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -5,7 +5,7 @@ from fast_llm.config import UpdateType from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl +from fast_llm.functional.config import CrossEntropyImpl from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs from fast_llm.layers.language_model.head import LanguageModelHead @@ -119,99 +119,99 @@ def _lm_head( ({"tied_embedding_weight": True}, {}, False, 1), ({}, {}, False, 2), ({}, {}, True, 1), - ( - { - "head": { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.cross_entropy, - } - }, - {}, - False, - 1, - ), - ( - { - "head": { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.reverse_kl, - } - }, - {}, - False, - 1, - ), - ( - { - "head": { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.cross_entropy, - "language_model_loss_factor": 1.0, - } - }, - {}, - True, - 1, - ), - ( - { - "head": { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.reverse_kl, - } - }, - {}, - True, - 1, - ), - pytest.param( - { - "head": { - "distillation_model": "distillation", - 
"language_model_loss_factor": 0.0, - "track_language_model_loss": True, - "distillation_loss_factor": 1.0, - } - }, - {}, - False, - 1, - id="track_lm_zero_factor", - ), - pytest.param( - { - "head": { - "distillation_model": "distillation", - "language_model_loss_factor": 0.0, - "distillation_loss_factor": 0.0, - "track_language_model_loss": True, - "track_distillation_loss": True, - } - }, - {}, - False, - 1, - id="track_both_zero_factors", - ), - pytest.param( - { - "head": { - "distillation_model": "distillation", - "language_model_loss_factor": 0.0, - "distillation_loss_factor": 0.0, - "track_language_model_loss": False, - "track_distillation_loss": False, - } - }, - {}, - False, - 1, - marks=pytest.mark.xfail( - reason="No losses computed when all factors=0 and tracking=False, raises RuntimeError in _add_tensors", - strict=True, - ), - id="zero_factors_no_tracking", - ), + # ( + # { + # "head": { + # "distillation_model": "distillation", + # "distillation_loss_implementation": DistillationLossImpl.cross_entropy, + # } + # }, + # {}, + # False, + # 1, + # ), + # ( + # { + # "head": { + # "distillation_model": "distillation", + # "distillation_loss_implementation": DistillationLossImpl.reverse_kl, + # } + # }, + # {}, + # False, + # 1, + # ), + # ( + # { + # "head": { + # "distillation_model": "distillation", + # "distillation_loss_implementation": DistillationLossImpl.cross_entropy, + # "language_model_loss_factor": 1.0, + # } + # }, + # {}, + # True, + # 1, + # ), + # ( + # { + # "head": { + # "distillation_model": "distillation", + # "distillation_loss_implementation": DistillationLossImpl.reverse_kl, + # } + # }, + # {}, + # True, + # 1, + # ), + # pytest.param( + # { + # "head": { + # "distillation_model": "distillation", + # "language_model_loss_factor": 0.0, + # "track_language_model_loss": True, + # "distillation_loss_factor": 1.0, + # } + # }, + # {}, + # False, + # 1, + # id="track_lm_zero_factor", + # ), + # pytest.param( + # { + # "head": { + # "distillation_model": "distillation", + # "language_model_loss_factor": 0.0, + # "distillation_loss_factor": 0.0, + # "track_language_model_loss": True, + # "track_distillation_loss": True, + # } + # }, + # {}, + # False, + # 1, + # id="track_both_zero_factors", + # ), + # pytest.param( + # { + # "head": { + # "distillation_model": "distillation", + # "language_model_loss_factor": 0.0, + # "distillation_loss_factor": 0.0, + # "track_language_model_loss": False, + # "track_distillation_loss": False, + # } + # }, + # {}, + # False, + # 1, + # marks=pytest.mark.xfail( + # reason="No losses computed when all factors=0 and tracking=False, raises RuntimeError in _add_tensors", + # strict=True, + # ), + # id="zero_factors_no_tracking", + # ), ), ) def test_lm_head( diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 6156cb709..f4e3ecea7 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -552,6 +552,12 @@ def _update_and_add_testing_config( "mistral_distill_logits", updates={ ("model", "base_model", "head", "distillation_model"): "teacher", + ("model", "base_model", "head", "losses"): { + "distillation_loss": { + "type": "revkl_dist", + "factor": 1.0, + }, + }, ("batch", "use_loss_masking_spans"): True, ("reference_models"): { "teacher": { @@ -599,7 +605,7 @@ def _update_and_add_testing_config( "mistral_distill_logits", "mistral_distill_activations", updates={ - ("model", "base_model", "head", "distillation_loss_factor"): 0.001, + ("model", "base_model", "head", "losses", 
"distillation_loss", "factor"): 0.001, ("model", "base_model", "decoder", "block", "distillation_model"): "teacher", ("model", "base_model", "decoder", "block", "activation_distillation_factor"): 0.1, ("reference_models"): { From d773d986d54ed3cc1729d9bd8992af116c8f20de Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 22 Dec 2025 16:47:11 +0000 Subject: [PATCH 11/29] tests --- fast_llm/layers/language_model/head.py | 4 + tests/layers/test_lm_head.py | 340 +++++++++++++++---------- 2 files changed, 214 insertions(+), 130 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index c8c3be797..c47a87de1 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -432,6 +432,10 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: name=self._total_head_loss_name, formatted_name=_format_name(self._total_head_loss_name), count=count ) ] + if self._config.logit_z_loss > 0.0: + loss_defs.append( + LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) + ) for loss_name, loss_config in self._config.losses.items(): if loss_config.log_it: loss_def: LossDef = loss_config.get_loss_def( diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 917bb7efd..5835b6673 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -9,6 +9,7 @@ from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs from fast_llm.layers.language_model.head import LanguageModelHead +from fast_llm.layers.language_model.lm_head_losses import LossConfig from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda @@ -43,6 +44,20 @@ def _reverse_kl_loss( return loss +def _kl_loss( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + teacher_softmax_temperature: float = 1.0, +): + return _reverse_kl_loss( + target, + logits, + loss_mask, + teacher_softmax_temperature, + ) + + def _lm_head( input_: torch.Tensor, target: torch.Tensor, @@ -54,9 +69,7 @@ def _lm_head( grad_output: float = 1.0, logit_scale_factor: float = 1.0, logit_z_loss=0.0, - distillation_loss_implementation: DistillationLossImpl = DistillationLossImpl.cross_entropy, - language_model_loss_factor: float = 1.0, - distillation_loss_factor: float = 1.0, + losses: dict[str, LossConfig], ): hidden = torch.rms_norm( input_.to(rms_weight.dtype), @@ -66,36 +79,34 @@ def _lm_head( ) logits = torch.nn.functional.linear(hidden, logit_weight).float() - if distillation_loss_implementation == DistillationLossImpl.reverse_kl: - Assert.eq(logits.shape, target.shape) - loss = _reverse_kl_loss( - (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask - ) - # Apply distillation_loss_factor to grad_output for backward - loss.backward(torch.full_like(loss, grad_output * distillation_loss_factor)) - # Return scaled loss - return loss * distillation_loss_factor, None + if "dist_loss" in losses: + if losses["dist_loss"].type == "revkl_dist": + Assert.eq(logits.shape, target.shape) + loss = _reverse_kl_loss( + (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask + ) + # Apply distillation_loss_factor to grad_output for backward + loss.backward(torch.full_like(loss, 
grad_output * losses["dist_loss"].factor)) + # Return scaled loss + return loss * losses["dist_loss"].factor, None + elif losses["dist_loss"].type == "fkl_dist": + Assert.eq(logits.shape, target.shape) + loss = _kl_loss( + (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask + ) + # Apply distillation_loss_factor to grad_output for backward + loss.backward(torch.full_like(loss, grad_output * losses["dist_loss"].factor)) + # Return scaled loss + return loss * losses["dist_loss"].factor, None if logit_scale_factor != 1.0: logits *= logit_scale_factor z_loss = torch.mean(torch.logsumexp(logits, dim=-1) ** 2) if logit_z_loss > 0 else None - if target.ndim == logits.ndim: - # Distillation loss (cross-entropy with soft targets) - loss = torch.nn.functional.cross_entropy( - logits.flatten(0, -2), target.float().softmax(-1).flatten(0, -2), reduction="none" - ) - if loss_mask is not None: - loss = loss * loss_mask.flatten() - loss = loss.mean() - # Apply distillation_loss_factor - loss.backward(torch.full_like(loss, grad_output * distillation_loss_factor)) - return loss * distillation_loss_factor, z_loss - else: - # Language model loss (cross-entropy with hard labels) - loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) - # Apply language_model_loss_factor - loss.backward(torch.full_like(loss, grad_output * language_model_loss_factor)) - return loss * language_model_loss_factor, z_loss + # Language model loss (cross-entropy with hard labels) + loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) + # Apply language_model_loss_factor + loss.backward(torch.full_like(loss, grad_output * losses["lm_loss"].factor)) + return loss * losses["lm_loss"].factor, z_loss SEQUENCE_LENGTH = 200 @@ -119,99 +130,169 @@ def _lm_head( ({"tied_embedding_weight": True}, {}, False, 1), ({}, {}, False, 2), ({}, {}, True, 1), + # Skip CE distillation for now - not yet implemented in new losses system # ( # { # "head": { # "distillation_model": "distillation", - # "distillation_loss_implementation": DistillationLossImpl.cross_entropy, - # } - # }, - # {}, - # False, - # 1, - # ), - # ( - # { - # "head": { - # "distillation_model": "distillation", - # "distillation_loss_implementation": DistillationLossImpl.reverse_kl, + # "losses": { + # "lm_loss": { + # "type": "cross_entropy_lm_loss", + # "weight_scalor": 0.0, + # "log_it": False, + # }, + # "dist_loss": { + # "type": "cross_entropy_dist", # TODO: Not implemented yet + # "weight_scalor": 1.0, + # "log_it": True, + # } + # } # } # }, # {}, # False, # 1, # ), + ( + { + "head": { + "distillation_model": "distillation", + "losses": { + "lm_loss": { + "type": "cross_entropy_lm_loss", + "factor": 0.0, + "log_it": False, + }, + "dist_loss": { + "type": "revkl_dist", + "factor": 1.0, + "log_it": True, + }, + }, + } + }, + {}, + False, + 1, + ), + # Skip - CE distillation not implemented # ( # { # "head": { # "distillation_model": "distillation", - # "distillation_loss_implementation": DistillationLossImpl.cross_entropy, - # "language_model_loss_factor": 1.0, - # } - # }, - # {}, - # True, - # 1, - # ), - # ( - # { - # "head": { - # "distillation_model": "distillation", - # "distillation_loss_implementation": DistillationLossImpl.reverse_kl, + # "losses": { + # "lm_loss": { + # "type": "cross_entropy_lm_loss", + # "weight_scalor": 1.0, + # "log_it": True, + # }, + # "dist_loss": { + # "type": "cross_entropy_dist", # TODO + # "weight_scalor": 1.0, + # "log_it": True, + # } + 
# } # } # }, # {}, # True, # 1, # ), - # pytest.param( - # { - # "head": { - # "distillation_model": "distillation", - # "language_model_loss_factor": 0.0, - # "track_language_model_loss": True, - # "distillation_loss_factor": 1.0, - # } - # }, - # {}, - # False, - # 1, - # id="track_lm_zero_factor", - # ), - # pytest.param( - # { - # "head": { - # "distillation_model": "distillation", - # "language_model_loss_factor": 0.0, - # "distillation_loss_factor": 0.0, - # "track_language_model_loss": True, - # "track_distillation_loss": True, - # } - # }, - # {}, - # False, - # 1, - # id="track_both_zero_factors", - # ), - # pytest.param( - # { - # "head": { - # "distillation_model": "distillation", - # "language_model_loss_factor": 0.0, - # "distillation_loss_factor": 0.0, - # "track_language_model_loss": False, - # "track_distillation_loss": False, - # } - # }, - # {}, - # False, - # 1, - # marks=pytest.mark.xfail( - # reason="No losses computed when all factors=0 and tracking=False, raises RuntimeError in _add_tensors", - # strict=True, - # ), - # id="zero_factors_no_tracking", - # ), + ( + { + "head": { + "distillation_model": "distillation", + "losses": { + "lm_loss": { + "type": "cross_entropy_lm_loss", + "factor": 0.0, + "log_it": False, + }, + "dist_loss": { + "type": "revkl_dist", + "factor": 1.0, + "log_it": True, + }, + }, + } + }, + {}, + True, + 1, + ), + pytest.param( + { + "head": { + "distillation_model": "distillation", + "losses": { + "lm_loss": { + "type": "cross_entropy_lm_loss", + "factor": 0.0, + "log_it": True, # tracking even with zero weight + }, + "dist_loss": { + "type": "revkl_dist", + "factor": 1.0, + "log_it": True, + }, + }, + } + }, + {}, + False, + 1, + id="track_lm_zero_factor", + ), + pytest.param( + { + "head": { + "distillation_model": "distillation", + "losses": { + "lm_loss": { + "type": "cross_entropy_lm_loss", + "factor": 0.0, + "log_it": True, # tracking with zero weight + }, + "dist_loss": { + "type": "revkl_dist", + "factor": 0.0, + "log_it": True, # tracking with zero weight + }, + }, + } + }, + {}, + False, + 1, + id="track_both_zero_factors", + ), + pytest.param( + { + "head": { + "distillation_model": "distillation", + "losses": { + "lm_loss": { + "type": "cross_entropy_lm_loss", + "factor": 0.0, + "log_it": False, # not tracking with zero weight + }, + "dist_loss": { + "type": "revkl_dist", + "factor": 0.0, + "log_it": False, # not tracking with zero weight + }, + }, + } + }, + {}, + False, + 1, + marks=pytest.mark.xfail( + reason="No losses computed when all factors=0 and log_it=False", + strict=True, + ), + id="zero_factors_no_tracking", + ), ), ) def test_lm_head( @@ -222,8 +303,15 @@ def test_lm_head( prediction_heads: int, ): head_config = { - "cross_entropy_implementation": cross_entropy_impl, "normalization": {"type": "rms_norm"}, + "losses": { + "lm_loss": { + "type": "cross_entropy_lm_loss", + "implementation": cross_entropy_impl, + "factor": 1.0, + "log_it": True, + } + }, } config = GPTBaseModelConfig.from_dict( { @@ -280,19 +368,19 @@ def test_lm_head( AttentionKwargs.sequence_first: sequence_first, AttentionKwargs.grad_output: 1.0, } - if head_config.distillation_model is None: - target = torch.randint( - 0, - VOCAB_SIZE, - label_shape, - dtype=torch.int64, - device=distributed.device, - ) - if loss_mask is not None: - target *= loss_mask + # always set lm targets + target = torch.randint( + 0, + VOCAB_SIZE, + label_shape, + dtype=torch.int64, + device=distributed.device, + ) + if loss_mask is not None: + target *= loss_mask - 
kwargs[LanguageModelKwargs.labels] = target - else: + kwargs[LanguageModelKwargs.labels] = target + if head_config.distillation_model is not None: assert config.head.max_prediction_distance == 1 target = torch.randn( input_.shape[:-1] + (VOCAB_SIZE,), @@ -349,11 +437,7 @@ def test_lm_head( logit_weight=ref_logit_weight, logit_scale_factor=head_config.logits_scale_factor, logit_z_loss=head_config.logit_z_loss, - distillation_loss_implementation=head_config.distillation_loss_implementation, - language_model_loss_factor=( - head_config.language_model_loss_factor if head_config.language_model_loss_factor is not None else 1.0 - ), - distillation_loss_factor=head_config.distillation_loss_factor, + losses=head_config.losses, ) # Prepare LM head inputs @@ -367,19 +451,15 @@ def test_lm_head( lm_head_loss_name = f"lm_head_loss_{prediction_distance}" if prediction_distance > 0 else "lm_head_loss" expected_loss_keys = {lm_head_loss_name} - if head._compute_lm_loss: - lm_loss_name_unscaled = ( - f"lm_loss_unscaled_{prediction_distance}" if prediction_distance > 0 else "lm_loss_unscaled" - ) - lm_loss_name = f"lm_loss_{prediction_distance}" if prediction_distance > 0 else "lm_loss" - expected_loss_keys.add(lm_loss_name_unscaled) - expected_loss_keys.add(lm_loss_name) + # Get expected loss names from the loss configs + for loss_name, loss_config in head._config.losses.items(): + if loss_config.log_it: + formatted_name = loss_config.get_formatted_name(loss_name, prediction_distance) + expected_loss_keys.add(formatted_name) + if ref_z_loss is not None: expected_loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") - if head._compute_distillation_loss: - expected_loss_keys.add("distillation_loss") - expected_loss_keys.add("distillation_loss_unscaled") Assert.eq( {loss_definition.name: loss_definition.count for loss_definition in head.get_loss_definitions()}, From 282925c5bcd6f3b2648aa1cfd4d40bed4058a739 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 22 Dec 2025 16:51:37 +0000 Subject: [PATCH 12/29] test --- tests/layers/test_lm_head.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 5835b6673..6bdaf3f67 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -293,6 +293,32 @@ def _lm_head( ), id="zero_factors_no_tracking", ), + pytest.param( + { + "head": { + "losses": { + "lm_loss": { + "type": "cross_entropy_lm_loss", + "factor": 1.0, + "log_it": False, # not tracking with zero weight + }, + "dist_loss": { + "type": "revkl_dist", + "factor": 1.0, + "log_it": True, # not tracking with zero weight + }, + }, + } + }, + {}, + False, + 1, + marks=pytest.mark.xfail( + reason="Cannot track distillation loss without distillation model being set", + strict=True, + ), + id="track_distillation_without_model", + ), ), ) def test_lm_head( From 0f73ea23d62e43c41c45a9e755e9e3db38a3a5a3 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 22 Dec 2025 17:54:53 +0000 Subject: [PATCH 13/29] tests --- fast_llm/layers/language_model/config.py | 13 ++--- fast_llm/layers/language_model/head.py | 1 + .../layers/language_model/lm_head_losses.py | 47 +++++++++---------- tests/test_config.py | 1 + tests/utils/model_configs.py | 28 +++-------- 5 files changed, 35 insertions(+), 55 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 786d312d8..411e98f4c 100644 --- a/fast_llm/layers/language_model/config.py +++ 
b/fast_llm/layers/language_model/config.py @@ -209,17 +209,12 @@ def layer_class(self) -> "type[LanguageModelHead]": return LanguageModelHead def _validate(self) -> None: - with self._set_implicit_default(): - if not self.losses: - self.losses = { - "lm_loss": LossConfig._from_dict({"type": "cross_entropy_lm_loss", "factor": 1.0, "log_it": True}) - } - - for loss_config in self.losses.values(): - if "dist" in loss_config.type: - assert self.distillation_model is not None, "Distillation loss requires a distillation model." + for loss_config in self.losses.values(): + if "dist" in loss_config.type: + assert self.distillation_model is not None, "Distillation loss requires a distillation model." super()._validate() assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both + # Note: Default loss is handled at runtime in head.py if losses dict is empty @property def max_prediction_distance(self) -> int: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index c47a87de1..e1f303323 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -100,6 +100,7 @@ def __init__( lr_scale=self._lr_scale, peft=self._peft, ) + assert self._config.losses, "At least one loss must be configured." self._formatted_loss_names = { loss_name: loss_config.get_formatted_name(loss_name, self._prediction_distance) for loss_name, loss_config in self._config.losses.items() diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index a231efa5a..9fd946625 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -3,17 +3,16 @@ import logging import typing -import torch - from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.core.distributed import ProcessGroup from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: - pass + import torch + + from fast_llm.core.distributed import ProcessGroup logger = logging.getLogger(__name__) @@ -32,13 +31,13 @@ def _format_name(name: str) -> str: @dataclasses.dataclass class Targets: - lm_target: torch.Tensor | None = None - dpo_target: torch.Tensor | None = None - loss_mask: torch.Tensor | None = None + lm_target: "torch.Tensor | None" = None + dpo_target: "torch.Tensor | None" = None + loss_mask: "torch.Tensor | None" = None chosen_spans: list[list[tuple[int, int]]] | None = None rejected_spans: list[list[tuple[int, int]]] | None = None - reference_model_logits: torch.Tensor | None = None - dpo_reference_model_logits: torch.Tensor | None = None + reference_model_logits: "torch.Tensor | None" = None + dpo_reference_model_logits: "torch.Tensor | None" = None def has_any_target(self) -> bool: return any(getattr(self, field.name) is not None for field in dataclasses.fields(self)) @@ -70,14 +69,14 @@ class LossConfig(Config): @abc.abstractmethod def compute_loss( self, - logits: torch.Tensor, + logits: "torch.Tensor", target: Targets, grad_output: float | None = None, - group: ProcessGroup | None = None, + group: "ProcessGroup" = None, logits_scale_factor: float | None = None, vocab_parallel: bool = False, **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> "tuple[torch.Tensor, torch.Tensor | 
None]": pass def get_loss_def(self, name: str, count: int = 1, prediction_distance: int | None = None) -> LossDef: @@ -124,14 +123,14 @@ class CrossEntropyLMLossConfig(LossConfig): def compute_loss( self, - logits: torch.Tensor, + logits: "torch.Tensor", targets: Targets, grad_output: float | None = None, - group: ProcessGroup | None = None, + group: "ProcessGroup" = None, logits_scale_factor: float | None = None, vocab_parallel: bool = False, **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.cross_entropy import cross_entropy_forward_backward target = targets.lm_target @@ -176,13 +175,13 @@ class ForwardKLLossConfig(LossConfig): def compute_loss( self, - logits: torch.Tensor, + logits: "torch.Tensor", targets: Targets, grad_output: float | None = None, - group: ProcessGroup | None = None, + group: "ProcessGroup" = None, logits_scale_factor: float | None = None, **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.cross_entropy import forward_kl_forward_backward target = targets.reference_model_logits @@ -218,13 +217,13 @@ class ReverseKLLossConfig(LossConfig): def compute_loss( self, - logits: torch.Tensor, + logits: "torch.Tensor", targets: Targets, grad_output: float | None = None, - group: ProcessGroup | None = None, + group: "ProcessGroup" = None, logits_scale_factor: float | None = None, **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.cross_entropy import reverse_kl_forward_backward # Use distillation_target for KL losses @@ -261,12 +260,12 @@ class DPOLossConfig(LossConfig): def compute_loss( self, - logits: torch.Tensor, + logits: "torch.Tensor", targets: Targets, grad_output: float | None = None, - group: ProcessGroup | None = None, + group: "ProcessGroup" = None, **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.dpo import compute_dpo_loss return compute_dpo_loss( diff --git a/tests/test_config.py b/tests/test_config.py index 4020b6fbc..8d6f39249 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -147,6 +147,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): "normalization": {"implementation": "triton"}, }, "num_blocks": 12, + "head": {}, }, "hidden_size": 512, "tied_embedding_weight": False, diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index f4e3ecea7..3cadb4e20 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -240,7 +240,12 @@ def _update_and_add_testing_config( }, "num_blocks": 2, }, - "head": {"output_weight": init_1}, + "head": { + "output_weight": init_1, + "losses": { + "lm_loss": {"type": "cross_entropy_lm_loss", "factor": 1.0, "log_it": True}, + }, + }, "hidden_size": 256, "tied_embedding_weight": True, }, @@ -580,27 +585,6 @@ def _update_and_add_testing_config( skip_tests=("ms", "pp2s1_bf4", "pp2s2_bf4", "sdp2"), ) -_update_and_add_testing_config( - "mistral_distill_logits", - "mistral_reverse_kl", - updates={ - ("model", "base_model", "head", "distillation_loss_implementation"): "reverse_kl", - }, - megatron_args=None, - checkpoint_format=MistralCheckpointFormat, - groups={ - ModelTestingGroup.basic: ModelTestingGroupAction.normal, - ModelTestingGroup.checkpoint: ModelTestingGroupAction.unimportant, - ModelTestingGroup.convert: 
ModelTestingGroupAction.unimportant, - ModelTestingGroup.generate: ModelTestingGroupAction.unimportant, - ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.broken, # failing: fp16, tp2, stp2, stp2_ce4 - }, - compare_factor=2, - # Modes not supported with reference models - skip_tests=("sdp", "ms", "pp"), -) - _update_and_add_testing_config( "mistral_distill_logits", "mistral_distill_activations", From fa85c415abd4481baba7ac9b9e037854e72cea82 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 22 Dec 2025 22:27:28 +0000 Subject: [PATCH 14/29] wip --- fast_llm/functional/cross_entropy.py | 104 +++----------- fast_llm/layers/language_model/config.py | 4 +- fast_llm/layers/language_model/head.py | 13 +- .../layers/language_model/lm_head_losses.py | 30 ++-- tests/layers/test_lm_head.py | 132 +++--------------- tests/utils/model_configs.py | 4 +- 6 files changed, 55 insertions(+), 232 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index f534d8a78..06c85848c 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -85,6 +85,7 @@ def _fused_cross_entropy_forward_backward( target_format: TargetFormat, group: ProcessGroup | None = None, teacher_softmax_temperature: float = 1.0, + return_target_entropy: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A fused implementation of cross-entropy with torch compile. @@ -158,6 +159,16 @@ def _fused_cross_entropy_forward_backward( loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: all_reduce(loss, op=ReduceOp.AVG, group=group) + if return_target_entropy and target_format == TargetFormat.logits: + # Compute teacher entropy + teacher_log_prob = torch.log(target + 1e-20) + target_entropy = -(target * teacher_log_prob).sum(dim=-1) + if loss_mask is not None: + target_entropy = target_entropy * loss_mask.squeeze(-1) + target_entropy = target_entropy.mean() + if group is not None: + all_reduce(target_entropy, op=ReduceOp.SUM, group=group) + return loss, grad, target_entropy return loss, grad @@ -362,78 +373,6 @@ def reverse_kl_forward_backward( return distillation_loss, distillation_grad -@torch.compile -def _forward_kl_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - grad_output: float | None, - group: ProcessGroup | None = None, - logits_scale_factor: float = 1.0, - teacher_softmax_temperature: float = 1.0, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Forward KL: KL(p||q) where p=teacher, q=student. - This is reverse KL with roles swapped in the loss computation. - - Key insight: KL(p||q) = sum_i p_i * log(p_i/q_i) - = sum_i p_i * (log(p_i) - log(q_i)) - which is reverse KL with p and q swapped. - - However, we still need grad w.r.t. 
student logits, so gradient is different: - d/d(student_logits) KL(p||q) = student_probs - teacher_probs - """ - Assert.eq( - teacher_softmax_temperature, - 1, - msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel forward KL", - ) - Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel forward KL") - Assert.eq(target.shape, logits.shape) - assert target.dtype.is_floating_point, target.dtype - if loss_mask is not None: - Assert.eq(loss_mask.shape, logits.shape[:-1]) - - # Compute log softmax for both teacher and student - teacher_log_probs = distributed_log_softmax(target.float(), group=group) - student_log_probs = distributed_log_softmax(logits, group=group) - - teacher_probs = teacher_log_probs.exp() - # Forward KL: p * log(p/q) = p * (log_p - log_q) - log_ratio = teacher_log_probs - student_log_probs - del teacher_log_probs - - # Compute loss: sum over vocab of teacher_probs * log_ratio - loss_terms = (teacher_probs * log_ratio).sum(dim=-1) - del log_ratio - - if loss_mask is not None: - valid = loss_mask.to(loss_terms.dtype) - loss_terms = loss_terms * valid - valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype)) - loss = loss_terms.sum() - - if group is not None: - all_reduce(loss, op=ReduceOp.SUM, group=group) - loss /= valid_tokens - - if grad_output is not None: - # Gradient: d/d(student_logits) KL(p||q) = student_probs - teacher_probs - student_probs = student_log_probs.exp() - grad_base = student_probs - teacher_probs - del student_probs, teacher_probs, student_log_probs - - if loss_mask is not None: - grad_base.mul_(loss_mask.to(logits.dtype).unsqueeze(-1)) - - grad_base.mul_(grad_output / valid_tokens) - grad = grad_base.to(logits.dtype) - else: - grad = None - - return loss.detach_(), grad - - def forward_kl_forward_backward( logits: torch.Tensor, target: torch.Tensor, @@ -467,25 +406,20 @@ def forward_kl_forward_backward( loss: Forward KL divergence loss grad: Gradients w.r.t. logits """ - - if sequence_parallel_logits: - # TODO: see hybrid dev branch where it is implemented - raise NotImplementedError("Sequence-parallel forward KL is not implemented yet, set vocab_parallel true") - - Assert.eq(target_format, TargetFormat.logits, msg="Forward KL only supports logits format") + assert target_format == TargetFormat.logits, "Forward KL only supports logits format" Assert.eq(target.shape, logits.shape) - assert target.dtype.is_floating_point, target.dtype - if loss_mask is not None: - Assert.eq(loss_mask.shape, logits.shape[:-1]) - - # TODO: implement fused? 
- distillation_loss, distillation_grad = _forward_kl_forward_backward( + distillation_loss, distillation_grad, teacher_entropy = _fused_cross_entropy_forward_backward( logits=logits, target=target, loss_mask=loss_mask, grad_output=grad_output, logits_scale_factor=logits_scale_factor, - teacher_softmax_temperature=teacher_softmax_temperature, + target_format=target_format, group=group, + teacher_softmax_temperature=teacher_softmax_temperature, + return_target_entropy=True, + **kwargs, ) + distillation_loss -= teacher_entropy + return distillation_loss, distillation_grad diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 411e98f4c..e2ce6ae19 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -9,7 +9,7 @@ from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig -from fast_llm.layers.language_model.lm_head_losses import LossConfig +from fast_llm.layers.language_model.lm_head_losses import LanguageModelLossConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -135,7 +135,7 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Configuration for the final normalization layer.", hint=FieldHint.architecture, ) - losses: dict[str, LossConfig] = Field( + losses: dict[str, LanguageModelLossConfig] = Field( default_factory=dict, desc="A dictionary of loss names and their configurations.", hint=FieldHint.core, diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index e1f303323..6ba45c242 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -375,7 +375,7 @@ def _logits_loss_forward_backward( total_loss, grad = None, None for loss_name, loss_config in self._config.losses.items(): - if loss_config.factor == 0.0 and not loss_config.log_it: + if loss_config.factor == 0.0: continue # losses are returned unscaled but the grads are already scaled # we log unscaled losses seperately and the scaled total loss @@ -391,7 +391,7 @@ def _logits_loss_forward_backward( ) loss_ = loss_unscaled_ * loss_config.factor * self._loss_coefficient - if losses is not None and loss_config.log_it: + if losses is not None: losses[self._formatted_loss_names[loss_name]].append(loss_unscaled_.detach()) if total_loss is None: @@ -438,11 +438,10 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) ) for loss_name, loss_config in self._config.losses.items(): - if loss_config.log_it: - loss_def: LossDef = loss_config.get_loss_def( - name=loss_name, count=count, prediction_distance=self._prediction_distance - ) - loss_defs.append(loss_def) + loss_def: LossDef = loss_config.get_loss_def( + name=loss_name, count=count, prediction_distance=self._prediction_distance + ) + loss_defs.append(loss_def) return loss_defs diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index 9fd946625..3695954bd 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -44,10 +44,10 @@ def has_any_target(self) -> bool: @config_class(registry=True) -class LossConfig(Config): +class LanguageModelLossConfig(Config): """ Losses canm register themselves - using 
@config_class(dynamic_type={LossConfig: "loss_type_name"}) + using @config_class(dynamic_type= {LanguageModelLossConfig: "loss_type_name"}) """ _name: typing.ClassVar[str] @@ -60,12 +60,6 @@ class LossConfig(Config): valid=check_field(Assert.geq, 0.0), ) - log_it: bool = Field( - default=True, - hint=FieldHint.optional, - desc="Whether to log this loss.", - ) - @abc.abstractmethod def compute_loss( self, @@ -90,10 +84,6 @@ def get_loss_def(self, name: str, count: int = 1, prediction_distance: int | Non def _validate(self): Assert.geq(self.factor, 0.0) - if self.factor > 0.0: - with self._set_implicit_default(): - if "log_it" not in self._explicit_fields: - self.log_it = True super()._validate() def get_formatted_name(self, name=None, prediction_distance: int | None = None) -> str: @@ -103,8 +93,8 @@ def get_formatted_name(self, name=None, prediction_distance: int | None = None) return name -@config_class(dynamic_type={LossConfig: "cross_entropy_lm_loss"}) -class CrossEntropyLMLossConfig(LossConfig): +@config_class(dynamic_type={LanguageModelLossConfig: "cross_entropy"}) +class CrossEntropyLMLossConfig(LanguageModelLossConfig): _name: typing.ClassVar[str] = "CE" _abstract: typing.ClassVar[bool] = False @@ -159,8 +149,8 @@ def compute_loss( ) -@config_class(dynamic_type={LossConfig: "fkl_dist"}) -class ForwardKLLossConfig(LossConfig): +@config_class(dynamic_type={LanguageModelLossConfig: "forward_kl_distillation"}) +class ForwardKLLossConfig(LanguageModelLossConfig): """Forward KL divergence KL(p||q) for distillation (mode-covering).""" _name: typing.ClassVar[str] = "FwdKL" @@ -201,8 +191,8 @@ def compute_loss( ) -@config_class(dynamic_type={LossConfig: "revkl_dist"}) -class ReverseKLLossConfig(LossConfig): +@config_class(dynamic_type={LanguageModelLossConfig: "reverse_kl_distillation"}) +class ReverseKLLossConfig(LanguageModelLossConfig): """Reverse KL divergence KL(q||p) for distillation (mode-seeking).""" _name: typing.ClassVar[str] = "RevKL" @@ -244,8 +234,8 @@ def compute_loss( ) -@config_class(dynamic_type={LossConfig: "dpo"}) -class DPOLossConfig(LossConfig): +@config_class(dynamic_type={LanguageModelLossConfig: "dpo"}) +class DPOLossConfig(LanguageModelLossConfig): """Direct Preference Optimization (DPO) loss for alignment.""" _name: typing.ClassVar[str] = "DPO" diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 6bdaf3f67..ddfc2fc12 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -9,7 +9,7 @@ from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.layers.language_model.lm_head_losses import LossConfig +from fast_llm.layers.language_model.lm_head_losses import LanguageModelLossConfig from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda @@ -69,7 +69,7 @@ def _lm_head( grad_output: float = 1.0, logit_scale_factor: float = 1.0, logit_z_loss=0.0, - losses: dict[str, LossConfig], + losses: dict[str, LanguageModelLossConfig], ): hidden = torch.rms_norm( input_.to(rms_weight.dtype), @@ -80,7 +80,7 @@ def _lm_head( logits = torch.nn.functional.linear(hidden, logit_weight).float() if "dist_loss" in losses: - if losses["dist_loss"].type == "revkl_dist": + if losses["dist_loss"].type == "reverse_kl_distillation": 
Assert.eq(logits.shape, target.shape) loss = _reverse_kl_loss( (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask @@ -89,7 +89,7 @@ def _lm_head( loss.backward(torch.full_like(loss, grad_output * losses["dist_loss"].factor)) # Return scaled loss return loss * losses["dist_loss"].factor, None - elif losses["dist_loss"].type == "fkl_dist": + elif losses["dist_loss"].type == "forward_kl_distillation": Assert.eq(logits.shape, target.shape) loss = _kl_loss( (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask @@ -137,14 +137,12 @@ def _lm_head( # "distillation_model": "distillation", # "losses": { # "lm_loss": { - # "type": "cross_entropy_lm_loss", + # "type": "cross_entropy", # "weight_scalor": 0.0, - # "log_it": False, # }, # "dist_loss": { # "type": "cross_entropy_dist", # TODO: Not implemented yet # "weight_scalor": 1.0, - # "log_it": True, # } # } # } @@ -153,87 +151,18 @@ def _lm_head( # False, # 1, # ), - ( - { - "head": { - "distillation_model": "distillation", - "losses": { - "lm_loss": { - "type": "cross_entropy_lm_loss", - "factor": 0.0, - "log_it": False, - }, - "dist_loss": { - "type": "revkl_dist", - "factor": 1.0, - "log_it": True, - }, - }, - } - }, - {}, - False, - 1, - ), - # Skip - CE distillation not implemented - # ( - # { - # "head": { - # "distillation_model": "distillation", - # "losses": { - # "lm_loss": { - # "type": "cross_entropy_lm_loss", - # "weight_scalor": 1.0, - # "log_it": True, - # }, - # "dist_loss": { - # "type": "cross_entropy_dist", # TODO - # "weight_scalor": 1.0, - # "log_it": True, - # } - # } - # } - # }, - # {}, - # True, - # 1, - # ), - ( - { - "head": { - "distillation_model": "distillation", - "losses": { - "lm_loss": { - "type": "cross_entropy_lm_loss", - "factor": 0.0, - "log_it": False, - }, - "dist_loss": { - "type": "revkl_dist", - "factor": 1.0, - "log_it": True, - }, - }, - } - }, - {}, - True, - 1, - ), pytest.param( { "head": { "distillation_model": "distillation", "losses": { "lm_loss": { - "type": "cross_entropy_lm_loss", + "type": "cross_entropy", "factor": 0.0, - "log_it": True, # tracking even with zero weight }, "dist_loss": { - "type": "revkl_dist", + "type": "reverse_kl_distillation", "factor": 1.0, - "log_it": True, }, }, } @@ -249,37 +178,12 @@ def _lm_head( "distillation_model": "distillation", "losses": { "lm_loss": { - "type": "cross_entropy_lm_loss", + "type": "cross_entropy", "factor": 0.0, - "log_it": True, # tracking with zero weight }, "dist_loss": { - "type": "revkl_dist", + "type": "reverse_kl_distillation", "factor": 0.0, - "log_it": True, # tracking with zero weight - }, - }, - } - }, - {}, - False, - 1, - id="track_both_zero_factors", - ), - pytest.param( - { - "head": { - "distillation_model": "distillation", - "losses": { - "lm_loss": { - "type": "cross_entropy_lm_loss", - "factor": 0.0, - "log_it": False, # not tracking with zero weight - }, - "dist_loss": { - "type": "revkl_dist", - "factor": 0.0, - "log_it": False, # not tracking with zero weight }, }, } @@ -288,24 +192,22 @@ def _lm_head( False, 1, marks=pytest.mark.xfail( - reason="No losses computed when all factors=0 and log_it=False", + reason="Cannot track both losses with zero factor", strict=True, ), - id="zero_factors_no_tracking", + id="track_both_zero_factors", ), pytest.param( { "head": { "losses": { "lm_loss": { - "type": "cross_entropy_lm_loss", + "type": "cross_entropy", "factor": 1.0, - "log_it": False, # not tracking with zero weight }, 
"dist_loss": { - "type": "revkl_dist", + "type": "reverse_kl_distillation", "factor": 1.0, - "log_it": True, # not tracking with zero weight }, }, } @@ -332,10 +234,9 @@ def test_lm_head( "normalization": {"type": "rms_norm"}, "losses": { "lm_loss": { - "type": "cross_entropy_lm_loss", + "type": "cross_entropy", "implementation": cross_entropy_impl, "factor": 1.0, - "log_it": True, } }, } @@ -480,9 +381,8 @@ def test_lm_head( # Get expected loss names from the loss configs for loss_name, loss_config in head._config.losses.items(): - if loss_config.log_it: - formatted_name = loss_config.get_formatted_name(loss_name, prediction_distance) - expected_loss_keys.add(formatted_name) + formatted_name = loss_config.get_formatted_name(loss_name, prediction_distance) + expected_loss_keys.add(formatted_name) if ref_z_loss is not None: expected_loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 3cadb4e20..93c78b58f 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -243,7 +243,7 @@ def _update_and_add_testing_config( "head": { "output_weight": init_1, "losses": { - "lm_loss": {"type": "cross_entropy_lm_loss", "factor": 1.0, "log_it": True}, + "lm_loss": {"type": "cross_entropy", "factor": 1.0}, }, }, "hidden_size": 256, @@ -559,7 +559,7 @@ def _update_and_add_testing_config( ("model", "base_model", "head", "distillation_model"): "teacher", ("model", "base_model", "head", "losses"): { "distillation_loss": { - "type": "revkl_dist", + "type": "reverse_kl_distillation", "factor": 1.0, }, }, From 31cfb84dd2081c0d1c40f31dee20859105e50146 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 23 Dec 2025 02:22:15 +0000 Subject: [PATCH 15/29] wip --- fast_llm/data/dataset/gpt/config.py | 1 - fast_llm/layers/language_model/config.py | 14 ++++++++++++-- fast_llm/layers/language_model/head.py | 2 +- tests/test_config.py | 8 +++++++- 4 files changed, 20 insertions(+), 5 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 41a2fe7ff..5e978ac2b 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -65,7 +65,6 @@ def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleTy def _load_config(self) -> SampledDatasetConfig[SampleType]: assert self.path.is_file(), f"File {self.path} does not exist." config = yaml.safe_load(self.path.open("r")) - Assert.eq(config.keys(), {"config", "metadata"}) if config.keys() == {"config", "metadata"}: # Newer format with metadata config = config["config"] diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index e2ce6ae19..58e85f5d8 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -209,12 +209,22 @@ def layer_class(self) -> "type[LanguageModelHead]": return LanguageModelHead def _validate(self) -> None: + with self._set_implicit_default(): + if not self.losses: + if "losses" not in self._explicit_fields: + self.losses = { + "lm_loss": LanguageModelLossConfig._from_dict( + { + "type": "cross_entropy", + "factor": 1.0, + } + ) + } for loss_config in self.losses.values(): - if "dist" in loss_config.type: + if "distillation" in loss_config.type: assert self.distillation_model is not None, "Distillation loss requires a distillation model." 
super()._validate() assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both - # Note: Default loss is handled at runtime in head.py if losses dict is empty @property def max_prediction_distance(self) -> int: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 6ba45c242..a67869f8b 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -100,7 +100,7 @@ def __init__( lr_scale=self._lr_scale, peft=self._peft, ) - assert self._config.losses, "At least one loss must be configured." + self._formatted_loss_names = { loss_name: loss_config.get_formatted_name(loss_name, self._prediction_distance) for loss_name, loss_config in self._config.losses.items() diff --git a/tests/test_config.py b/tests/test_config.py index 8d6f39249..81137b587 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -147,14 +147,16 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): "normalization": {"implementation": "triton"}, }, "num_blocks": 12, - "head": {}, }, + "head": {"losses": {"lm_loss": {"type": "cross_entropy", "factor": 1.0}}}, "hidden_size": 512, "tied_embedding_weight": False, "peft": {"freeze_others": False}, } else: expected_config["base_model"] = base_model_update + # added by default + expected_config["base_model"]["head"] = {"losses": {"lm_loss": {"type": "cross_entropy", "factor": 1.0}}} check_equal_nested(_trim_type(serialized_config), _trim_type(expected_config)) @@ -297,3 +299,7 @@ def test_distributed_global_ranks(bdp: int, sdp: int, tp: int, pp: int, pipeline Assert.eq(len({global_rank for global_ranks in global_ranks_set for global_rank in global_ranks}), world_size) Assert.eq(len(rank_breakdowns), world_size) + + +if __name__ == "__main__": + pytest.main([__file__]) From 24fe67bbebbdd9a8aa5ad1393b43250ced3b8629 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 23 Dec 2025 15:43:26 +0000 Subject: [PATCH 16/29] no grad if factor 0 --- fast_llm/layers/language_model/head.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index a67869f8b..50240f49c 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -383,7 +383,9 @@ def _logits_loss_forward_backward( logits, targets, grad_output=( - grad_output * self._loss_coefficient * loss_config.factor if grad_output is not None else None + (grad_output * self._loss_coefficient * loss_config.factor if grad_output is not None else None) + if loss_config.factor != 0.0 + else None ), group=group, logits_scale_factor=self._config.logits_scale_factor, From 0e562e99198e8414b1c026d17cd3383c7acc2f55 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 23 Dec 2025 17:00:00 +0000 Subject: [PATCH 17/29] addressed comments --- fast_llm/layers/language_model/config.py | 2 +- fast_llm/layers/language_model/head.py | 8 +++--- .../layers/language_model/lm_head_losses.py | 4 +-- tests/layers/test_lm_head.py | 26 +++++++++---------- tests/test_config.py | 4 +-- 5 files changed, 22 insertions(+), 22 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 58e85f5d8..4bd8a592c 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -216,7 +216,7 @@ def _validate(self) -> None: "lm_loss": LanguageModelLossConfig._from_dict( { "type": "cross_entropy", - "factor": 
1.0, + "weight": 1.0, } ) } diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 50240f49c..40c099617 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -375,7 +375,7 @@ def _logits_loss_forward_backward( total_loss, grad = None, None for loss_name, loss_config in self._config.losses.items(): - if loss_config.factor == 0.0: + if loss_config.weight == 0.0: continue # losses are returned unscaled but the grads are already scaled # we log unscaled losses seperately and the scaled total loss @@ -383,15 +383,15 @@ def _logits_loss_forward_backward( logits, targets, grad_output=( - (grad_output * self._loss_coefficient * loss_config.factor if grad_output is not None else None) - if loss_config.factor != 0.0 + (grad_output * self._loss_coefficient * loss_config.weight if grad_output is not None else None) + if loss_config.weight != 0.0 else None ), group=group, logits_scale_factor=self._config.logits_scale_factor, vocab_parallel=self._vocab_parallel, ) - loss_ = loss_unscaled_ * loss_config.factor * self._loss_coefficient + loss_ = loss_unscaled_ * loss_config.weight * self._loss_coefficient if losses is not None: losses[self._formatted_loss_names[loss_name]].append(loss_unscaled_.detach()) diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index 3695954bd..dc367be65 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -53,7 +53,7 @@ class LanguageModelLossConfig(Config): _name: typing.ClassVar[str] _abstract: typing.ClassVar[bool] = True - factor: float = Field( + weight: float = Field( default=1.0, hint=FieldHint.core, desc="Weight for this loss in the total loss computation.", @@ -83,7 +83,7 @@ def get_loss_def(self, name: str, count: int = 1, prediction_distance: int | Non ) def _validate(self): - Assert.geq(self.factor, 0.0) + Assert.geq(self.weight, 0.0) super()._validate() def get_formatted_name(self, name=None, prediction_distance: int | None = None) -> str: diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index ddfc2fc12..7f9e55b79 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -86,18 +86,18 @@ def _lm_head( (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask ) # Apply distillation_loss_factor to grad_output for backward - loss.backward(torch.full_like(loss, grad_output * losses["dist_loss"].factor)) + loss.backward(torch.full_like(loss, grad_output * losses["dist_loss"].weight)) # Return scaled loss - return loss * losses["dist_loss"].factor, None + return loss * losses["dist_loss"].weight, None elif losses["dist_loss"].type == "forward_kl_distillation": Assert.eq(logits.shape, target.shape) loss = _kl_loss( (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask ) # Apply distillation_loss_factor to grad_output for backward - loss.backward(torch.full_like(loss, grad_output * losses["dist_loss"].factor)) + loss.backward(torch.full_like(loss, grad_output * losses["dist_loss"].weight)) # Return scaled loss - return loss * losses["dist_loss"].factor, None + return loss * losses["dist_loss"].weight, None if logit_scale_factor != 1.0: logits *= logit_scale_factor @@ -105,8 +105,8 @@ def _lm_head( # Language model loss (cross-entropy with hard labels) loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), 
target.flatten()) # Apply language_model_loss_factor - loss.backward(torch.full_like(loss, grad_output * losses["lm_loss"].factor)) - return loss * losses["lm_loss"].factor, z_loss + loss.backward(torch.full_like(loss, grad_output * losses["lm_loss"].weight)) + return loss * losses["lm_loss"].weight, z_loss SEQUENCE_LENGTH = 200 @@ -158,11 +158,11 @@ def _lm_head( "losses": { "lm_loss": { "type": "cross_entropy", - "factor": 0.0, + "weight": 0.0, }, "dist_loss": { "type": "reverse_kl_distillation", - "factor": 1.0, + "weight": 1.0, }, }, } @@ -179,11 +179,11 @@ def _lm_head( "losses": { "lm_loss": { "type": "cross_entropy", - "factor": 0.0, + "weight": 0.0, }, "dist_loss": { "type": "reverse_kl_distillation", - "factor": 0.0, + "weight": 0.0, }, }, } @@ -203,11 +203,11 @@ def _lm_head( "losses": { "lm_loss": { "type": "cross_entropy", - "factor": 1.0, + "weight": 1.0, }, "dist_loss": { "type": "reverse_kl_distillation", - "factor": 1.0, + "weight": 1.0, }, }, } @@ -236,7 +236,7 @@ def test_lm_head( "lm_loss": { "type": "cross_entropy", "implementation": cross_entropy_impl, - "factor": 1.0, + "weight": 1.0, } }, } diff --git a/tests/test_config.py b/tests/test_config.py index 81137b587..3c6a76a35 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -148,7 +148,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "num_blocks": 12, }, - "head": {"losses": {"lm_loss": {"type": "cross_entropy", "factor": 1.0}}}, + "head": {"losses": {"lm_loss": {"type": "cross_entropy", "weight": 1.0}}}, "hidden_size": 512, "tied_embedding_weight": False, "peft": {"freeze_others": False}, @@ -156,7 +156,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): else: expected_config["base_model"] = base_model_update # added by default - expected_config["base_model"]["head"] = {"losses": {"lm_loss": {"type": "cross_entropy", "factor": 1.0}}} + expected_config["base_model"]["head"] = {"losses": {"lm_loss": {"type": "cross_entropy", "weight": 1.0}}} check_equal_nested(_trim_type(serialized_config), _trim_type(expected_config)) From 52c1c113d1fe32732b7bc2c666c0cfd6303abca8 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 23 Dec 2025 17:44:53 +0000 Subject: [PATCH 18/29] addressed comments --- fast_llm/functional/cross_entropy.py | 4 --- fast_llm/layers/language_model/head.py | 11 ++----- .../layers/language_model/lm_head_losses.py | 29 ++++++++++--------- tests/utils/model_configs.py | 2 +- 4 files changed, 19 insertions(+), 27 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 06c85848c..03f7a88ef 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -247,7 +247,6 @@ def _reverse_kl_forward_backward( group: ProcessGroup | None = None, logits_scale_factor: float = 1.0, teacher_softmax_temperature: float = 1.0, - **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Reverse KL using PyTorch's native kl_div function. @@ -325,7 +324,6 @@ def reverse_kl_forward_backward( teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, sequence_parallel_logits: bool = False, - **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher). 
@@ -383,7 +381,6 @@ def forward_kl_forward_backward( teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, sequence_parallel_logits: bool = False, - **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute forward KL divergence: KL(p||q) where p is the target distribution (teacher) and q is the predicted (student). @@ -418,7 +415,6 @@ def forward_kl_forward_backward( group=group, teacher_softmax_temperature=teacher_softmax_temperature, return_target_entropy=True, - **kwargs, ) distillation_loss -= teacher_entropy diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 40c099617..bce20c83f 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -182,14 +182,10 @@ def _get_targets(self, kwargs: dict) -> Targets | None: dpo_target, reference_model_logits, loss_mask, - chosen_spans, - rejected_spans, dpo_reference_model_logits, - ) = (None, None, None, None, None, None, None) + ) = (None, None, None, None, None) if self._config.enable_dpo: dpo_target = kwargs.get(LanguageModelKwargs.labels) - chosen_spans = kwargs.get(LanguageModelKwargs.chosen_spans) - rejected_spans = kwargs.get(LanguageModelKwargs.rejected_spans) dpo_reference_model_logits = (kwargs.get(f"{self._config.dpo_reference_model}_logits"),) else: if self._config.distillation_model is not None: @@ -230,8 +226,6 @@ def _get_targets(self, kwargs: dict) -> Targets | None: dpo_target=dpo_target, lm_target=lm_target, loss_mask=loss_mask, - chosen_spans=chosen_spans, - rejected_spans=rejected_spans, reference_model_logits=reference_model_logits, dpo_reference_model_logits=dpo_reference_model_logits, ) @@ -302,8 +296,6 @@ def _logits_cross_entropy_forward_backward_split( dpo_target=dpo_target_, reference_model_logits=reference_model_logits_, loss_mask=loss_mask_, - chosen_spans=targets.chosen_spans, - rejected_spans=targets.rejected_spans, dpo_reference_model_logits=targets.dpo_reference_model_logits, ) loss_, grad_ = self._logits_loss_forward_backward( @@ -390,6 +382,7 @@ def _logits_loss_forward_backward( group=group, logits_scale_factor=self._config.logits_scale_factor, vocab_parallel=self._vocab_parallel, + kwargs=kwargs, ) loss_ = loss_unscaled_ * loss_config.weight * self._loss_coefficient diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index dc367be65..4be129a28 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -34,8 +34,6 @@ class Targets: lm_target: "torch.Tensor | None" = None dpo_target: "torch.Tensor | None" = None loss_mask: "torch.Tensor | None" = None - chosen_spans: list[list[tuple[int, int]]] | None = None - rejected_spans: list[list[tuple[int, int]]] | None = None reference_model_logits: "torch.Tensor | None" = None dpo_reference_model_logits: "torch.Tensor | None" = None @@ -64,12 +62,12 @@ class LanguageModelLossConfig(Config): def compute_loss( self, logits: "torch.Tensor", - target: Targets, + targets: Targets, grad_output: float | None = None, group: "ProcessGroup" = None, logits_scale_factor: float | None = None, vocab_parallel: bool = False, - **kwargs, + kwargs: dict | None = None, ) -> "tuple[torch.Tensor, torch.Tensor | None]": pass @@ -119,7 +117,7 @@ def compute_loss( group: "ProcessGroup" = None, logits_scale_factor: float | None = None, vocab_parallel: bool = False, - **kwargs, + kwargs: dict | None = None, ) -> "tuple[torch.Tensor, 
torch.Tensor | None]":
         from fast_llm.functional.cross_entropy import cross_entropy_forward_backward
 
-        target = targets.lm_target
@@ -145,7 +143,6 @@ def compute_loss(
             logits_scale_factor=logits_scale_factor,
             teacher_softmax_temperature=self.teacher_softmax_temperature,
             target_format=TargetFormat.labels,
-            **kwargs,
         )
 
 
@@ -170,7 +167,8 @@ def compute_loss(
         grad_output: float | None = None,
         group: "ProcessGroup" = None,
         logits_scale_factor: float | None = None,
-        **kwargs,
+        vocab_parallel: bool = False,
+        kwargs: dict | None = None,
     ) -> "tuple[torch.Tensor, torch.Tensor | None]":
         from fast_llm.functional.cross_entropy import forward_kl_forward_backward
 
@@ -187,7 +185,6 @@ def compute_loss(
             logits_scale_factor=logits_scale_factor,
             teacher_softmax_temperature=self.teacher_softmax_temperature,
             target_format=TargetFormat.logits,
-            **kwargs,
         )
 
 
@@ -212,7 +209,8 @@ def compute_loss(
         grad_output: float | None = None,
         group: "ProcessGroup" = None,
         logits_scale_factor: float | None = None,
-        **kwargs,
+        vocab_parallel: bool = False,
+        kwargs: dict | None = None,
     ) -> "tuple[torch.Tensor, torch.Tensor | None]":
         from fast_llm.functional.cross_entropy import reverse_kl_forward_backward
 
@@ -230,7 +228,6 @@ def compute_loss(
             logits_scale_factor=logits_scale_factor,
             teacher_softmax_temperature=self.teacher_softmax_temperature,
             target_format=TargetFormat.logits,
-            **kwargs,
         )
 
 
@@ -254,16 +251,22 @@ def compute_loss(
         targets: Targets,
         grad_output: float | None = None,
         group: "ProcessGroup" = None,
-        **kwargs,
+        logits_scale_factor: float | None = None,
+        vocab_parallel: bool = False,
+        kwargs: dict | None = None,
     ) -> "tuple[torch.Tensor, torch.Tensor | None]":
         from fast_llm.functional.dpo import compute_dpo_loss
+        from fast_llm.layers.language_model.config import LanguageModelKwargs
+
+        chosen_spans = kwargs.get(LanguageModelKwargs.chosen_spans)
+        rejected_spans = kwargs.get(LanguageModelKwargs.rejected_spans)
 
         return compute_dpo_loss(
             logits=logits,
             targets=targets.dpo_target,
             reference_model_logits=targets.dpo_reference_model_logits,
-            chosen_spans=targets.chosen_spans,
-            rejected_spans=targets.rejected_spans,
+            chosen_spans=chosen_spans,
+            rejected_spans=rejected_spans,
             beta=self.beta,
             grad_output=grad_output,
         )
diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py
index 6cda07ad0..f3d4659cd 100644
--- a/tests/utils/model_configs.py
+++ b/tests/utils/model_configs.py
@@ -247,7 +247,7 @@ def _update_and_add_testing_config(
             "head": {
                 "output_weight": init_1,
                 "losses": {
-                    "lm_loss": {"type": "cross_entropy", "factor": 1.0},
+                    "lm_loss": {"type": "cross_entropy", "weight": 1.0},
                 },
             },
             "hidden_size": 256,

From 406d0a2eaf355488a699220ad4198371585effa2 Mon Sep 17 00:00:00 2001
From: oleksost
Date: Tue, 30 Dec 2025 20:13:50 +0000
Subject: [PATCH 19/29] Removed Targets class

Removed the Targets class, moved targets processing to losses, made loss masks more explicit
---
 fast_llm/layers/language_model/config.py    |  17 +-
 fast_llm/layers/language_model/embedding.py |   3 +-
 fast_llm/layers/language_model/head.py      | 139 ++++++-----------
 fast_llm/layers/language_model/kwargs.py    |  23 +++
 .../layers/language_model/lm_head_losses.py | 147 +++++++++++++-----
 fast_llm/models/gpt/model.py                |   2 +-
 fast_llm/models/multimodal/model.py         |   2 +-
 tests/layers/test_lm_head.py                |   3 +-
 8 files changed, 185 insertions(+), 151 deletions(-)
 create mode 100644 fast_llm/layers/language_model/kwargs.py

diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py
index 4bd8a592c..9f6cbf4ca 100644
--- 
a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -5,7 +5,7 @@ from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig +from fast_llm.layers.block.config import BlockConfig, BlockSequenceConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig @@ -19,21 +19,6 @@ from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction -class LanguageModelKwargs(BlockKwargs): - token_ids = "token_ids" - position_ids = "position_ids" - token_map = "token_map" - sample_map = "sample_map" - embedding_map = "embedding_map" - # TODO: These are generic - labels = "labels" - phase = "phase" - chosen_spans = "chosen_spans" - rejected_spans = "rejected_spans" - loss_mask = "loss_mask" - mask_inputs = "mask_inputs" - - @config_class() class LanguageModelEmbeddingsConfig(BlockConfig): _abstract = False diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 93850d24c..fda5e3387 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -10,7 +10,8 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.layers.block.block import Block from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, LanguageModelKwargs +from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index bce20c83f..27b090c1f 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -22,9 +22,9 @@ LanguageModelEmbeddingsConfig, LanguageModelHeadBaseConfig, LanguageModelHeadConfig, - LanguageModelKwargs, ) -from fast_llm.layers.language_model.lm_head_losses import Targets, _format_name +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs +from fast_llm.layers.language_model.lm_head_losses import _format_name from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert, div, get_unique @@ -101,10 +101,12 @@ def __init__( peft=self._peft, ) - self._formatted_loss_names = { - loss_name: loss_config.get_formatted_name(loss_name, self._prediction_distance) - for loss_name, loss_config in self._config.losses.items() - } + self._formatted_loss_names = {} + for loss_name, loss_config in self._config.losses.items(): + if loss_config.weight > 0.0: + self._formatted_loss_names[loss_name] = loss_config.get_formatted_name( + loss_name, self._prediction_distance + ) def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None @@ -154,6 +156,12 @@ def _forward_backward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None ) -> tuple[torch.Tensor, torch.Tensor | None]: targets = self._get_targets(kwargs) + loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) + if 
loss_mask is not None: + loss_mask = loss_mask.flatten() + if self._sequence_parallel_logits: + loss_mask = split_op(loss_mask, self._parallel_dim.group, 0) + input_ = input_.detach().requires_grad_(do_grad := targets is not None and self.training) with torch.enable_grad(): ln_output = self.final_norm(input_) @@ -167,7 +175,7 @@ def _forward_backward( output_weights = self.output_weights loss, ln_output_grad = self._logits_cross_entropy_forward_backward_split( - ln_output.detach(), targets, output_weights, grad_output, kwargs, losses + ln_output.detach(), targets, loss_mask, output_weights, grad_output, kwargs, losses ) if do_grad: @@ -176,62 +184,20 @@ def _forward_backward( else: return loss, None - def _get_targets(self, kwargs: dict) -> Targets | None: - ( - lm_target, - dpo_target, - reference_model_logits, - loss_mask, - dpo_reference_model_logits, - ) = (None, None, None, None, None) - if self._config.enable_dpo: - dpo_target = kwargs.get(LanguageModelKwargs.labels) - dpo_reference_model_logits = (kwargs.get(f"{self._config.dpo_reference_model}_logits"),) - else: - if self._config.distillation_model is not None: - # Target is reference model logits. - reference_model_logits = kwargs[f"{self._config.distillation_model}_logits"].flatten(0, -2) - loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) - if loss_mask is not None: - loss_mask = loss_mask.flatten() - - lm_target = kwargs.get(LanguageModelKwargs.labels) - if lm_target is not None: - # MTP: Shift the labels - lm_target_sequence_length = ( - lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - self._prediction_heads - ) - if LanguageModelKwargs.sequence_q_dim in kwargs: - Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) - lm_target_slice = slice( - self._prediction_distance, self._prediction_distance + lm_target_sequence_length - ) - lm_target = ( - lm_target[lm_target_slice] - if kwargs[LanguageModelKwargs.sequence_first] - else lm_target[:, lm_target_slice] - ).flatten() - - if self._sequence_parallel_logits: - if dpo_target is not None: - dpo_target = split_op(dpo_target, self._parallel_dim.group, 0) - if lm_target is not None: - lm_target = split_op(lm_target, self._parallel_dim.group, 0) - if loss_mask is not None: - loss_mask = split_op(loss_mask, self._parallel_dim.group, 0) - if reference_model_logits is not None: - reference_model_logits = split_op(reference_model_logits, self._parallel_dim.group, 0) - - targets = Targets( - dpo_target=dpo_target, - lm_target=lm_target, - loss_mask=loss_mask, - reference_model_logits=reference_model_logits, - dpo_reference_model_logits=dpo_reference_model_logits, - ) - - # Return None if no targets are set - if not targets.has_any_target(): + def _get_targets(self, kwargs: dict) -> dict | None: + targets = {} + for loss_config in self._config.losses.values(): + if loss_config.weight == 0.0: + continue + loss_targets = loss_config.extract_targets_from_global_kwargs( + kwargs, + prediction_distance=self._prediction_distance, + prediction_heads=self._prediction_heads, + head_config=self._config, + sequence_parallel_logits=self._sequence_parallel_logits, + ) + targets.update({k: v for k, v in loss_targets.items() if v is not None}) + if len(targets) == 0: return None return targets @@ -241,15 +207,16 @@ def get_output_weights(self) -> list[torch.Tensor]: def _logits_cross_entropy_forward_backward_split( self, input_: torch.Tensor, - targets: Targets | None, + targets: dict[str, "torch.Tensor"] | None, + loss_mask: torch.Tensor | None, 
weight: torch.Tensor, grad_output: float, kwargs: dict, losses: dict | None = None, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: - if self._config.cross_entropy_splits is None or targets is None: + if self._config.cross_entropy_splits is None: loss, logit_input_grad = self._logits_loss_forward_backward( - input_, targets, weight, grad_output, kwargs, losses + input_, targets, loss_mask, weight, grad_output, kwargs, losses ) if targets is None: # TODO: Make a proper way of returning the model output. @@ -273,34 +240,28 @@ def _logits_cross_entropy_forward_backward_split( else: logit_input_grad = None - # Extract target tensors for splitting (keep same order as original tuple) - target_tensors = [ - targets.lm_target, - targets.dpo_target, - targets.reference_model_logits, - targets.loss_mask, - ] split_size = div( - get_unique(target.size(0) for target in target_tensors if target is not None), + get_unique(target.size(0) for target in targets.values() if target is not None), self._config.cross_entropy_splits, ) tensors_split = [ [None] * self._config.cross_entropy_splits if tensor is None else tensor.split(split_size) - for tensor in [logit_input, *target_tensors, logit_input_grad] + for tensor in [logit_input, loss_mask, logit_input_grad] ] - for logit_input_, lm_target_, dpo_target_, reference_model_logits_, loss_mask_, logit_input_grad_ in zip( - *tensors_split, strict=True - ): - targets_ = Targets( - lm_target=lm_target_, - dpo_target=dpo_target_, - reference_model_logits=reference_model_logits_, - loss_mask=loss_mask_, - dpo_reference_model_logits=targets.dpo_reference_model_logits, + target_split = { + name: ( + [None] * self._config.cross_entropy_splits + if targets[name] is None + else targets[name].split(split_size) ) + for name in targets + } + + for i, (logit_input_, loss_mask_, logit_input_grad_) in enumerate(zip(*tensors_split, strict=True)): loss_, grad_ = self._logits_loss_forward_backward( logit_input_, - targets_, + {name: target_split[name][i] for name in target_split}, + loss_mask_, weight, grad_output, kwargs, @@ -323,7 +284,8 @@ def _logits_cross_entropy_forward_backward_split( def _logits_loss_forward_backward( self, input_: torch.Tensor, - targets: Targets | None, + targets: dict[str, "torch.Tensor"] | None, + loss_mask: torch.Tensor | None, weight: torch.Tensor, grad_output: float, kwargs: dict, @@ -370,10 +332,9 @@ def _logits_loss_forward_backward( if loss_config.weight == 0.0: continue # losses are returned unscaled but the grads are already scaled - # we log unscaled losses seperately and the scaled total loss loss_unscaled_, grad_ = loss_config.compute_loss( logits, - targets, + loss_mask, grad_output=( (grad_output * self._loss_coefficient * loss_config.weight if grad_output is not None else None) if loss_config.weight != 0.0 @@ -382,7 +343,7 @@ def _logits_loss_forward_backward( group=group, logits_scale_factor=self._config.logits_scale_factor, vocab_parallel=self._vocab_parallel, - kwargs=kwargs, + kwargs={**kwargs, **targets}, ) loss_ = loss_unscaled_ * loss_config.weight * self._loss_coefficient diff --git a/fast_llm/layers/language_model/kwargs.py b/fast_llm/layers/language_model/kwargs.py new file mode 100644 index 000000000..4f6203881 --- /dev/null +++ b/fast_llm/layers/language_model/kwargs.py @@ -0,0 +1,23 @@ +from fast_llm.layers.block.config import BlockKwargs + + +class TargetsKwargs: + lm_target = "preprocessed_lm_target" + dpo_target = "preprocessed_dpo_target" + reference_model_logits = "reference_model_logits" + 
dpo_reference_model_logits = "dpo_reference_model_logits" + + +class LanguageModelKwargs(BlockKwargs): + token_ids = "token_ids" + position_ids = "position_ids" + token_map = "token_map" + sample_map = "sample_map" + embedding_map = "embedding_map" + # TODO: These are generic + labels = "labels" + phase = "phase" + chosen_spans = "chosen_spans" + rejected_spans = "rejected_spans" + loss_mask = "loss_mask" + mask_inputs = "mask_inputs" diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index 4be129a28..088e55042 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -1,18 +1,20 @@ import abc -import dataclasses import logging import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.core.ops import split_op from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs, TargetsKwargs from fast_llm.utils import Assert if typing.TYPE_CHECKING: import torch from fast_llm.core.distributed import ProcessGroup + from fast_llm.layers.language_model.config import LanguageModelHeadConfig logger = logging.getLogger(__name__) @@ -29,23 +31,10 @@ def _format_name(name: str) -> str: return name.replace("_", " ") -@dataclasses.dataclass -class Targets: - lm_target: "torch.Tensor | None" = None - dpo_target: "torch.Tensor | None" = None - loss_mask: "torch.Tensor | None" = None - reference_model_logits: "torch.Tensor | None" = None - dpo_reference_model_logits: "torch.Tensor | None" = None - - def has_any_target(self) -> bool: - return any(getattr(self, field.name) is not None for field in dataclasses.fields(self)) - - @config_class(registry=True) class LanguageModelLossConfig(Config): """ - Losses canm register themselves - using @config_class(dynamic_type= {LanguageModelLossConfig: "loss_type_name"}) + Losses can register themselves using @config_class(dynamic_type= {LanguageModelLossConfig: "loss_type_name"}). 
""" _name: typing.ClassVar[str] @@ -62,7 +51,7 @@ class LanguageModelLossConfig(Config): def compute_loss( self, logits: "torch.Tensor", - targets: Targets, + loss_mask: "torch.Tensor | None", grad_output: float | None = None, group: "ProcessGroup" = None, logits_scale_factor: float | None = None, @@ -90,6 +79,18 @@ def get_formatted_name(self, name=None, prediction_distance: int | None = None) name = f"{name}_{prediction_distance}" return name + @abc.abstractmethod + def extract_targets_from_global_kwargs( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + head_config: "LanguageModelHeadConfig | None" = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + pass + @config_class(dynamic_type={LanguageModelLossConfig: "cross_entropy"}) class CrossEntropyLMLossConfig(LanguageModelLossConfig): @@ -109,10 +110,40 @@ class CrossEntropyLMLossConfig(LanguageModelLossConfig): valid=check_field(Assert.gt, 0.0), ) + def extract_targets_from_global_kwargs( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + head_config: "LanguageModelHeadConfig | None" = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + if kwargs is None: + kwargs = {} + + lm_target = kwargs.get(LanguageModelKwargs.labels) + if lm_target is not None: + # MTP: Shift the labels + lm_target_sequence_length = ( + lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - prediction_heads + ) + if LanguageModelKwargs.sequence_q_dim in kwargs: + Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) + lm_target_slice = slice(prediction_distance, prediction_distance + lm_target_sequence_length) + lm_target = ( + lm_target[lm_target_slice] + if kwargs[LanguageModelKwargs.sequence_first] + else lm_target[:, lm_target_slice] + ).flatten() + if sequence_parallel_logits: + lm_target = split_op(lm_target, group, 0) + return {TargetsKwargs.lm_target: lm_target} + def compute_loss( self, logits: "torch.Tensor", - targets: Targets, + loss_mask: "torch.Tensor | None", grad_output: float | None = None, group: "ProcessGroup" = None, logits_scale_factor: float | None = None, @@ -121,9 +152,7 @@ def compute_loss( ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.cross_entropy import cross_entropy_forward_backward - target = targets.lm_target - if target is None: - raise ValueError("CrossEntropyLoss requires lm_target to be set in Targets") + target = kwargs.get(TargetsKwargs.lm_target) implementation = self.implementation if implementation == CrossEntropyImpl.auto: if vocab_parallel: @@ -160,10 +189,29 @@ class ForwardKLLossConfig(LanguageModelLossConfig): valid=check_field(Assert.gt, 0.0), ) + def extract_targets_from_global_kwargs( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + head_config: "LanguageModelHeadConfig | None" = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + if kwargs is None: + kwargs = {} + + reference_model_logits = kwargs.get(f"{head_config.distillation_model}_logits") + if reference_model_logits is not None: + reference_model_logits = reference_model_logits.flatten(0, -2) + if sequence_parallel_logits: + reference_model_logits = 
split_op(reference_model_logits, group, 0) + return {TargetsKwargs.reference_model_logits: reference_model_logits} + def compute_loss( self, logits: "torch.Tensor", - targets: Targets, + loss_mask: "torch.Tensor | None", grad_output: float | None = None, group: "ProcessGroup" = None, logits_scale_factor: float | None = None, @@ -172,14 +220,12 @@ def compute_loss( ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.cross_entropy import forward_kl_forward_backward - target = targets.reference_model_logits - if target is None: - raise ValueError("ForwardKLLoss requires distillation_target to be set in Targets") + target = kwargs.get(TargetsKwargs.reference_model_logits) return forward_kl_forward_backward( logits=logits.flatten(0, -2), target=target, - loss_mask=targets.loss_mask, + loss_mask=loss_mask, grad_output=grad_output, group=group, logits_scale_factor=logits_scale_factor, @@ -189,23 +235,16 @@ def compute_loss( @config_class(dynamic_type={LanguageModelLossConfig: "reverse_kl_distillation"}) -class ReverseKLLossConfig(LanguageModelLossConfig): +class ReverseKLLossConfig(ForwardKLLossConfig): """Reverse KL divergence KL(q||p) for distillation (mode-seeking).""" _name: typing.ClassVar[str] = "RevKL" _abstract: typing.ClassVar[bool] = False - teacher_softmax_temperature: float = Field( - default=1.0, - hint=FieldHint.optional, - desc="Temperature for teacher softmax.", - valid=check_field(Assert.gt, 0.0), - ) - def compute_loss( self, logits: "torch.Tensor", - targets: Targets, + loss_mask: "torch.Tensor | None", grad_output: float | None = None, group: "ProcessGroup" = None, logits_scale_factor: float | None = None, @@ -215,14 +254,12 @@ def compute_loss( from fast_llm.functional.cross_entropy import reverse_kl_forward_backward # Use distillation_target for KL losses - target = targets.reference_model_logits - if target is None: - raise ValueError("ReverseKLLoss requires distillation_target to be set in Targets") + target = kwargs.get(TargetsKwargs.reference_model_logits) return reverse_kl_forward_backward( logits=logits.flatten(0, -2), target=target, - loss_mask=targets.loss_mask, + loss_mask=loss_mask, grad_output=grad_output, group=group, logits_scale_factor=logits_scale_factor, @@ -245,10 +282,35 @@ class DPOLossConfig(LanguageModelLossConfig): valid=check_field(Assert.gt, 0.0), ) + def extract_targets_from_global_kwargs( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + head_config: "LanguageModelHeadConfig | None" = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + if kwargs is None: + kwargs = {} + + reference_model_logits = kwargs.get(f"{head_config.dpo_reference_model}_logits") + dpo_target = kwargs.get(LanguageModelKwargs.labels) + if reference_model_logits is not None: + reference_model_logits = reference_model_logits.flatten(0, -2) + if sequence_parallel_logits: + reference_model_logits = split_op(reference_model_logits, group, 0) + if dpo_target is not None: + dpo_target = split_op(dpo_target, group, 0) + return { + TargetsKwargs.dpo_reference_model_logits: reference_model_logits, + TargetsKwargs.dpo_target: dpo_target, + } + def compute_loss( self, logits: "torch.Tensor", - targets: Targets, + loss_mask: "torch.Tensor | None", grad_output: float | None = None, group: "ProcessGroup" = None, logits_scale_factor: float | None = None, @@ -256,15 +318,16 @@ def compute_loss( kwargs: dict | None = None, ) -> 
"tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.dpo import compute_dpo_loss - from fast_llm.layers.language_model.config import LanguageModelKwargs + dpo_target = kwargs.get(TargetsKwargs.dpo_target) + dpo_reference_model_logits = kwargs.get(TargetsKwargs.dpo_reference_model_logits) chosen_spans = kwargs.get(LanguageModelKwargs.chosen_spans) rejected_spans = kwargs.get(LanguageModelKwargs.rejected_spans) return compute_dpo_loss( logits=logits, - targets=targets.dpo_target, - reference_model_logits=targets.dpo_reference_model_logits, + targets=dpo_target, + reference_model_logits=dpo_reference_model_logits, chosen_spans=chosen_spans, rejected_spans=rejected_spans, beta=self.beta, diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 2f43d1e41..846c65646 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -12,7 +12,7 @@ from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockDimNames, BlockKwargs -from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs from fast_llm.layers.language_model.language_model import LanguageModel from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron diff --git a/fast_llm/models/multimodal/model.py b/fast_llm/models/multimodal/model.py index 890d5760e..88da79e65 100644 --- a/fast_llm/models/multimodal/model.py +++ b/fast_llm/models/multimodal/model.py @@ -10,7 +10,7 @@ from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockDimNames, BlockKwargs -from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs from fast_llm.layers.vision.config import VisionKwargs from fast_llm.layers.vision.vision_encoder import VisionMultiModalModel from fast_llm.models.gpt.config import GPTBatchConfig diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 7f9e55b79..ed639db93 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -7,8 +7,9 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import CrossEntropyImpl from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs +from fast_llm.layers.language_model.config import LanguageModelHeadConfig from fast_llm.layers.language_model.head import LanguageModelHead +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs from fast_llm.layers.language_model.lm_head_losses import LanguageModelLossConfig from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert From f25380a191fd53bdc0427bc3592c3a026ad3fd22 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 30 Dec 2025 20:39:22 +0000 Subject: [PATCH 20/29] fixes --- fast_llm/layers/language_model/head.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 27b090c1f..cb2312d75 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -195,6 
+195,7 @@ def _get_targets(self, kwargs: dict) -> dict | None: prediction_heads=self._prediction_heads, head_config=self._config, sequence_parallel_logits=self._sequence_parallel_logits, + group=self._parallel_dim.group, ) targets.update({k: v for k, v in loss_targets.items() if v is not None}) if len(targets) == 0: @@ -240,8 +241,14 @@ def _logits_cross_entropy_forward_backward_split( else: logit_input_grad = None + # Collect all tensors that need to be split to determine the split size + tensors_to_check = [logit_input] + if loss_mask is not None: + tensors_to_check.append(loss_mask) + tensors_to_check.extend(target for target in targets.values() if target is not None) + split_size = div( - get_unique(target.size(0) for target in targets.values() if target is not None), + get_unique(tensor.size(0) for tensor in tensors_to_check), self._config.cross_entropy_splits, ) tensors_split = [ From 8adb7ddb9da22eba3f9a4e8a3cbff0e86ca2f214 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 30 Dec 2025 20:51:52 +0000 Subject: [PATCH 21/29] imports --- .../layers/language_model/lm_head_losses.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index 088e55042..f6e69b4fa 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -3,7 +3,6 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.core.ops import split_op from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig @@ -137,6 +136,8 @@ def extract_targets_from_global_kwargs( else lm_target[:, lm_target_slice] ).flatten() if sequence_parallel_logits: + from fast_llm.core.ops import split_op + lm_target = split_op(lm_target, group, 0) return {TargetsKwargs.lm_target: lm_target} @@ -205,6 +206,8 @@ def extract_targets_from_global_kwargs( if reference_model_logits is not None: reference_model_logits = reference_model_logits.flatten(0, -2) if sequence_parallel_logits: + from fast_llm.core.ops import split_op + reference_model_logits = split_op(reference_model_logits, group, 0) return {TargetsKwargs.reference_model_logits: reference_model_logits} @@ -296,12 +299,15 @@ def extract_targets_from_global_kwargs( reference_model_logits = kwargs.get(f"{head_config.dpo_reference_model}_logits") dpo_target = kwargs.get(LanguageModelKwargs.labels) - if reference_model_logits is not None: - reference_model_logits = reference_model_logits.flatten(0, -2) - if sequence_parallel_logits: - reference_model_logits = split_op(reference_model_logits, group, 0) - if dpo_target is not None: - dpo_target = split_op(dpo_target, group, 0) + if reference_model_logits is not None or dpo_target is not None: + from fast_llm.core.ops import split_op + + if reference_model_logits is not None: + reference_model_logits = reference_model_logits.flatten(0, -2) + if sequence_parallel_logits: + reference_model_logits = split_op(reference_model_logits, group, 0) + if dpo_target is not None: + dpo_target = split_op(dpo_target, group, 0) return { TargetsKwargs.dpo_reference_model_logits: reference_model_logits, TargetsKwargs.dpo_target: dpo_target, From 1ce641d85ea418077865a080b4470ff9947fad85 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 6 Jan 2026 20:21:09 +0000 Subject: [PATCH 22/29] polish 
naming --- fast_llm/layers/language_model/head.py | 6 +++--- fast_llm/layers/language_model/lm_head_losses.py | 15 +++++++++------ 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index cb2312d75..f05da5534 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -102,10 +102,10 @@ def __init__( ) self._formatted_loss_names = {} - for loss_name, loss_config in self._config.losses.items(): + for registered_loss_name, loss_config in self._config.losses.items(): if loss_config.weight > 0.0: - self._formatted_loss_names[loss_name] = loss_config.get_formatted_name( - loss_name, self._prediction_distance + self._formatted_loss_names[registered_loss_name] = loss_config.get_formatted_name( + registered_loss_name, self._prediction_distance ) def forward( diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index f6e69b4fa..49dbb3ced 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -72,8 +72,11 @@ def _validate(self): Assert.geq(self.weight, 0.0) super()._validate() - def get_formatted_name(self, name=None, prediction_distance: int | None = None) -> str: - name = f"{self._name}({name})" + def get_formatted_name(self, registered_loss_name=None, prediction_distance: int | None = None) -> str: + """ + Retruns loss name for logging as '()', e.g. lm_loss(CE_loss), distillation(FwdKL_loss) + """ + name = f"{registered_loss_name}({self._name})" if prediction_distance is not None: name = f"{name}_{prediction_distance}" return name @@ -93,7 +96,7 @@ def extract_targets_from_global_kwargs( @config_class(dynamic_type={LanguageModelLossConfig: "cross_entropy"}) class CrossEntropyLMLossConfig(LanguageModelLossConfig): - _name: typing.ClassVar[str] = "CE" + _name: typing.ClassVar[str] = "CE_loss" _abstract: typing.ClassVar[bool] = False implementation: CrossEntropyImpl = Field( @@ -180,7 +183,7 @@ def compute_loss( class ForwardKLLossConfig(LanguageModelLossConfig): """Forward KL divergence KL(p||q) for distillation (mode-covering).""" - _name: typing.ClassVar[str] = "FwdKL" + _name: typing.ClassVar[str] = "FwdKL_loss" _abstract: typing.ClassVar[bool] = False teacher_softmax_temperature: float = Field( @@ -241,7 +244,7 @@ def compute_loss( class ReverseKLLossConfig(ForwardKLLossConfig): """Reverse KL divergence KL(q||p) for distillation (mode-seeking).""" - _name: typing.ClassVar[str] = "RevKL" + _name: typing.ClassVar[str] = "RevKL_loss" _abstract: typing.ClassVar[bool] = False def compute_loss( @@ -275,7 +278,7 @@ def compute_loss( class DPOLossConfig(LanguageModelLossConfig): """Direct Preference Optimization (DPO) loss for alignment.""" - _name: typing.ClassVar[str] = "DPO" + _name: typing.ClassVar[str] = "DPO_loss" _abstract: typing.ClassVar[bool] = False beta: float = Field( From 95f14afc76b4d3639d45dde7228951ba7de4c666 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 8 Jan 2026 18:44:01 +0000 Subject: [PATCH 23/29] addresseing comments --- fast_llm/functional/cross_entropy.py | 13 +++- fast_llm/layers/language_model/config.py | 78 +++++++++++++------ fast_llm/layers/language_model/head.py | 11 +-- .../layers/language_model/lm_head_losses.py | 54 +++++++++---- tests/layers/test_lm_head.py | 5 +- tests/test_config.py | 8 +- tests/utils/model_configs.py | 2 +- 7 files changed, 109 insertions(+), 62 deletions(-) diff --git 
a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 03f7a88ef..6b0a4e92f 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -98,7 +98,10 @@ def _fused_cross_entropy_forward_backward( logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) if target_format == TargetFormat.logits: - target = _fused_softmax(target, logits_scale_factor / teacher_softmax_temperature, group) + target_logits, exp_logits, sum_exp_target_logits = _fused_softmax_base( + target, logits_scale_factor / teacher_softmax_temperature, group + ) + target = exp_logits / sum_exp_target_logits if target_format == TargetFormat.labels: target = target.unsqueeze(-1) @@ -159,9 +162,11 @@ def _fused_cross_entropy_forward_backward( loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: all_reduce(loss, op=ReduceOp.AVG, group=group) - if return_target_entropy and target_format == TargetFormat.logits: - # Compute teacher entropy - teacher_log_prob = torch.log(target + 1e-20) + if return_target_entropy: + if target_format == TargetFormat.logits: + teacher_log_prob = target_logits - sum_exp_target_logits.log() + else: + teacher_log_prob = torch.log(target + 1e-20) target_entropy = -(target * teacher_log_prob).sum(dim=-1) if loss_mask is not None: target_entropy = target_entropy * loss_mask.squeeze(-1) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 9f6cbf4ca..a74489005 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,5 +1,7 @@ import abc import typing +import warnings +from functools import cached_property from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales @@ -9,7 +11,13 @@ from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig -from fast_llm.layers.language_model.lm_head_losses import LanguageModelLossConfig +from fast_llm.layers.language_model.lm_head_losses import ( + CrossEntropyLMLossConfig, + DPOLossConfig, + ForwardKLLossConfig, + LanguageModelLossConfig, + ReverseKLLossConfig, +) from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -151,17 +159,6 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - dpo_reference_model: str | None = Field( - default=None, - desc="Name of the reference model to use for dpo.", - hint=FieldHint.feature, - ) - distillation_model: str | None = Field( - default=None, - desc="Name of the reference model to use for knowledge distillation." - "If provided, replace the loss with a distillation loss.", - hint=FieldHint.feature, - ) def get_layer( self, @@ -193,23 +190,37 @@ def layer_class(self) -> "type[LanguageModelHead]": return LanguageModelHead + @classmethod + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: + removed_fields = ["distillation_loss_factor", "distillation_model", "language_model_loss_factor"] + for field in removed_fields: + if field in default: + warnings.warn( + f"Field `{field}` has been removed from {cls.__name__}. 
" + "Loss configuration should now be done via the `losses` field.", + DeprecationWarning, + ) + default.pop(field) + return super()._from_dict(default, strict=strict) + def _validate(self) -> None: with self._set_implicit_default(): if not self.losses: if "losses" not in self._explicit_fields: - self.losses = { - "lm_loss": LanguageModelLossConfig._from_dict( - { - "type": "cross_entropy", - "weight": 1.0, - } - ) - } - for loss_config in self.losses.values(): - if "distillation" in loss_config.type: - assert self.distillation_model is not None, "Distillation loss requires a distillation model." + self.losses = {"lm_loss": CrossEntropyLMLossConfig()} super()._validate() - assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both + if DPOLossConfig in self._loss_configs: + assert ForwardKLLossConfig not in self._loss_configs.keys() # currently don't support both + assert ReverseKLLossConfig not in self._loss_configs.keys() # currently don't support both + if ForwardKLLossConfig in self._loss_configs.keys() and ReverseKLLossConfig in self._loss_configs.keys(): + assert ( + self._loss_configs[ForwardKLLossConfig].distillation_model + == self._loss_configs[ReverseKLLossConfig].distillation_model + ), "Distillation losses must use the same teacher." + + @cached_property + def _loss_configs(self) -> dict[type, LanguageModelLossConfig]: + return {loss.__class__: loss for loss in self.losses.values()} @property def max_prediction_distance(self) -> int: @@ -217,7 +228,24 @@ def max_prediction_distance(self) -> int: @property def enable_dpo(self) -> bool: - return self.dpo_reference_model is not None + return DPOLossConfig in self._loss_configs.keys() + + @property + def enable_distillation(self) -> bool: + return ForwardKLLossConfig in self._loss_configs.keys() or ReverseKLLossConfig in self._loss_configs.keys() + + @property + def distillation_model(self) -> str | None: + for loss_type in [ForwardKLLossConfig, ReverseKLLossConfig]: + if loss_type in self._loss_configs: + return self._loss_configs[loss_type].distillation_model + return None + + @property + def dpo_reference_model(self) -> str | None: + if DPOLossConfig in self._loss_configs: + return self._loss_configs[DPOLossConfig].dpo_reference_model + return None @config_class(dynamic_type={LanguageModelHeadBaseConfig: "multi_token_prediction"}) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index f05da5534..465984e01 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -67,9 +67,7 @@ def __init__( lr_scale=lr_scale, peft=peft, ) - if prediction_distance > 0 and ( - self._config.distillation_model is not None or self._config.dpo_reference_model is not None - ): + if prediction_distance > 0 and (self._config.enable_dpo or self._config.enable_distillation): raise NotImplementedError("Multi-token prediction not supported with distillation or dpo.") Assert.in_range(prediction_distance, 0, prediction_heads) @@ -189,11 +187,10 @@ def _get_targets(self, kwargs: dict) -> dict | None: for loss_config in self._config.losses.values(): if loss_config.weight == 0.0: continue - loss_targets = loss_config.extract_targets_from_global_kwargs( + loss_targets = loss_config.get_targets( kwargs, prediction_distance=self._prediction_distance, prediction_heads=self._prediction_heads, - head_config=self._config, sequence_parallel_logits=self._sequence_parallel_logits, group=self._parallel_dim.group, ) @@ -339,7 +336,7 @@ def 
_logits_loss_forward_backward( if loss_config.weight == 0.0: continue # losses are returned unscaled but the grads are already scaled - loss_unscaled_, grad_ = loss_config.compute_loss( + loss_unscaled_, grad_ = loss_config.get_loss( logits, loss_mask, grad_output=( @@ -401,7 +398,7 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) ) for loss_name, loss_config in self._config.losses.items(): - loss_def: LossDef = loss_config.get_loss_def( + loss_def: LossDef = loss_config.get_loss_definitions( name=loss_name, count=count, prediction_distance=self._prediction_distance ) loss_defs.append(loss_def) diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index 49dbb3ced..e1004b5c8 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -13,7 +13,6 @@ import torch from fast_llm.core.distributed import ProcessGroup - from fast_llm.layers.language_model.config import LanguageModelHeadConfig logger = logging.getLogger(__name__) @@ -46,8 +45,15 @@ class LanguageModelLossConfig(Config): valid=check_field(Assert.geq, 0.0), ) + distillation_model: str | None = Field( + default=None, + desc="Name of the reference model to use for knowledge distillation." + "If provided, replace the loss with a distillation loss.", + hint=FieldHint.feature, + ) + @abc.abstractmethod - def compute_loss( + def get_loss( self, logits: "torch.Tensor", loss_mask: "torch.Tensor | None", @@ -59,7 +65,7 @@ def compute_loss( ) -> "tuple[torch.Tensor, torch.Tensor | None]": pass - def get_loss_def(self, name: str, count: int = 1, prediction_distance: int | None = None) -> LossDef: + def get_loss_definitions(self, name: str, count: int = 1, prediction_distance: int | None = None) -> LossDef: name = self.get_formatted_name(name, prediction_distance) return LossDef( name=name, @@ -82,12 +88,11 @@ def get_formatted_name(self, registered_loss_name=None, prediction_distance: int return name @abc.abstractmethod - def extract_targets_from_global_kwargs( + def get_targets( self, kwargs: dict | None = None, prediction_distance: int | None = None, prediction_heads: int | None = None, - head_config: "LanguageModelHeadConfig | None" = None, sequence_parallel_logits: bool | None = None, group: "ProcessGroup" = None, ) -> dict[str, "torch.Tensor"]: @@ -112,12 +117,11 @@ class CrossEntropyLMLossConfig(LanguageModelLossConfig): valid=check_field(Assert.gt, 0.0), ) - def extract_targets_from_global_kwargs( + def get_targets( self, kwargs: dict | None = None, prediction_distance: int | None = None, prediction_heads: int | None = None, - head_config: "LanguageModelHeadConfig | None" = None, sequence_parallel_logits: bool | None = None, group: "ProcessGroup" = None, ) -> dict[str, "torch.Tensor"]: @@ -144,7 +148,7 @@ def extract_targets_from_global_kwargs( lm_target = split_op(lm_target, group, 0) return {TargetsKwargs.lm_target: lm_target} - def compute_loss( + def get_loss( self, logits: "torch.Tensor", loss_mask: "torch.Tensor | None", @@ -193,19 +197,22 @@ class ForwardKLLossConfig(LanguageModelLossConfig): valid=check_field(Assert.gt, 0.0), ) - def extract_targets_from_global_kwargs( + def _validate(self): + assert self.distillation_model is not None, "Distillation loss required by ForwardKL Loss." 
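Note on the migration this patch performs: the head-level fields language_model_loss_factor, distillation_loss_factor and distillation_model are now rejected with a deprecation warning, and the same behaviour is expressed through the losses mapping. A rough before/after sketch, with key names taken from the removed-field list and the updated tests; the model name "teacher" and the particular weights are purely illustrative:

    # Pre-patch head configuration (fields now removed):
    old_head_config = {
        "language_model_loss_factor": 0.5,
        "distillation_loss_factor": 1.0,
        "distillation_model": "teacher",
    }

    # Post-patch equivalent: per-loss weights, with the teacher named on the loss itself.
    new_head_config = {
        "losses": {
            "lm_loss": {"type": "cross_entropy", "weight": 0.5},
            "dist_loss": {
                "type": "reverse_kl_distillation",
                "weight": 1.0,
                "distillation_model": "teacher",
            },
        },
    }

Whether the old factors map one-to-one onto the new weights is an assumption here; the patch itself only drops the fields and points users at the losses field.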
+ super()._validate() + + def get_targets( self, kwargs: dict | None = None, prediction_distance: int | None = None, prediction_heads: int | None = None, - head_config: "LanguageModelHeadConfig | None" = None, sequence_parallel_logits: bool | None = None, group: "ProcessGroup" = None, ) -> dict[str, "torch.Tensor"]: if kwargs is None: kwargs = {} - reference_model_logits = kwargs.get(f"{head_config.distillation_model}_logits") + reference_model_logits = kwargs.get(f"{self.distillation_model}_logits") if reference_model_logits is not None: reference_model_logits = reference_model_logits.flatten(0, -2) if sequence_parallel_logits: @@ -214,7 +221,7 @@ def extract_targets_from_global_kwargs( reference_model_logits = split_op(reference_model_logits, group, 0) return {TargetsKwargs.reference_model_logits: reference_model_logits} - def compute_loss( + def get_loss( self, logits: "torch.Tensor", loss_mask: "torch.Tensor | None", @@ -247,7 +254,11 @@ class ReverseKLLossConfig(ForwardKLLossConfig): _name: typing.ClassVar[str] = "RevKL_loss" _abstract: typing.ClassVar[bool] = False - def compute_loss( + def _validate(self): + assert self.distillation_model is not None, "Distillation loss required by Reverse KL Loss." + super()._validate() + + def get_loss( self, logits: "torch.Tensor", loss_mask: "torch.Tensor | None", @@ -288,19 +299,28 @@ class DPOLossConfig(LanguageModelLossConfig): valid=check_field(Assert.gt, 0.0), ) - def extract_targets_from_global_kwargs( + dpo_reference_model: str | None = Field( + default=None, + desc="Name of the reference model to use for dpo.", + hint=FieldHint.feature, + ) + + def _validate(self): + assert self.dpo_reference_model is not None, "DPO loss requires a reference model." + super()._validate() + + def get_targets( self, kwargs: dict | None = None, prediction_distance: int | None = None, prediction_heads: int | None = None, - head_config: "LanguageModelHeadConfig | None" = None, sequence_parallel_logits: bool | None = None, group: "ProcessGroup" = None, ) -> dict[str, "torch.Tensor"]: if kwargs is None: kwargs = {} - reference_model_logits = kwargs.get(f"{head_config.dpo_reference_model}_logits") + reference_model_logits = kwargs.get(f"{self.dpo_reference_model}_logits") dpo_target = kwargs.get(LanguageModelKwargs.labels) if reference_model_logits is not None or dpo_target is not None: from fast_llm.core.ops import split_op @@ -316,7 +336,7 @@ def extract_targets_from_global_kwargs( TargetsKwargs.dpo_target: dpo_target, } - def compute_loss( + def get_loss( self, logits: "torch.Tensor", loss_mask: "torch.Tensor | None", diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index ed639db93..f25aba1e7 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -155,7 +155,6 @@ def _lm_head( pytest.param( { "head": { - "distillation_model": "distillation", "losses": { "lm_loss": { "type": "cross_entropy", @@ -164,6 +163,7 @@ def _lm_head( "dist_loss": { "type": "reverse_kl_distillation", "weight": 1.0, + "distillation_model": "distillation", }, }, } @@ -176,7 +176,6 @@ def _lm_head( pytest.param( { "head": { - "distillation_model": "distillation", "losses": { "lm_loss": { "type": "cross_entropy", @@ -185,6 +184,7 @@ def _lm_head( "dist_loss": { "type": "reverse_kl_distillation", "weight": 0.0, + "distillation_model": "distillation", }, }, } @@ -209,6 +209,7 @@ def _lm_head( "dist_loss": { "type": "reverse_kl_distillation", "weight": 1.0, + "distillation_model": "distillation", }, }, } diff --git a/tests/test_config.py 
b/tests/test_config.py index 3c6a76a35..2e900cb14 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -148,7 +148,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "num_blocks": 12, }, - "head": {"losses": {"lm_loss": {"type": "cross_entropy", "weight": 1.0}}}, + "head": {"losses": {"lm_loss": {"type": "cross_entropy"}}}, "hidden_size": 512, "tied_embedding_weight": False, "peft": {"freeze_others": False}, @@ -156,7 +156,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): else: expected_config["base_model"] = base_model_update # added by default - expected_config["base_model"]["head"] = {"losses": {"lm_loss": {"type": "cross_entropy", "weight": 1.0}}} + expected_config["base_model"]["head"] = {"losses": {"lm_loss": {"type": "cross_entropy"}}} check_equal_nested(_trim_type(serialized_config), _trim_type(expected_config)) @@ -299,7 +299,3 @@ def test_distributed_global_ranks(bdp: int, sdp: int, tp: int, pp: int, pipeline Assert.eq(len({global_rank for global_ranks in global_ranks_set for global_rank in global_ranks}), world_size) Assert.eq(len(rank_breakdowns), world_size) - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index f3d4659cd..a9a2e65bf 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -247,7 +247,7 @@ def _update_and_add_testing_config( "head": { "output_weight": init_1, "losses": { - "lm_loss": {"type": "cross_entropy", "weight": 1.0}, + "lm_loss": {"type": "cross_entropy"}, }, }, "hidden_size": 256, From 5ad4c0c98ffc96a58f226376d16a93f77c4e61d2 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 8 Jan 2026 21:59:24 +0000 Subject: [PATCH 24/29] explicit z_loss grads --- fast_llm/layers/common/auxiliary_loss.py | 42 +++++++++----- fast_llm/layers/language_model/head.py | 40 ++++++------- .../layers/language_model/lm_head_losses.py | 36 ++++++++++++ tests/layers/test_lm_head.py | 57 ++++++++++++++----- 4 files changed, 125 insertions(+), 50 deletions(-) diff --git a/fast_llm/layers/common/auxiliary_loss.py b/fast_llm/layers/common/auxiliary_loss.py index 44c2d2088..335debb12 100644 --- a/fast_llm/layers/common/auxiliary_loss.py +++ b/fast_llm/layers/common/auxiliary_loss.py @@ -21,18 +21,34 @@ def calculate_z_loss(logits: torch.Tensor, logits_scale_factor: float = 1.0) -> def z_loss( logits: torch.Tensor, - z_loss_factor: float, - training: bool, grad_scale: float | None = None, - losses: dict | None = None, - loss_name: str | None = None, logits_scale_factor: float = 1.0, -) -> torch.Tensor: - if losses is not None or (training and grad_scale is not None): - loss = calculate_z_loss(logits, logits_scale_factor=logits_scale_factor) - if losses is not None and loss_name is not None: - losses[loss_name].append(loss.detach()) - if training and grad_scale is not None: - logits = AuxiliaryLoss.apply(logits, loss, z_loss_factor * grad_scale) - - return logits +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Compute z-loss and its gradient. + + Z-loss = mean(logsumexp(logits, dim=-1) ** 2) + + Returns: + loss: The z-loss value (unscaled) + grad: The gradient w.r.t. 
logits (scaled by grad_scale), or None if grad_scale is None + """ + if logits_scale_factor != 1.0: + scaled_logits = logits * logits_scale_factor + else: + scaled_logits = logits + + # Forward: z_loss = mean(logsumexp^2) + lse = torch.logsumexp(scaled_logits, dim=-1) # (N,) + loss = torch.mean(lse**2) + + # Backward: grad = (2/N) * lse * softmax(scaled_logits) + grad = None + if grad_scale is not None: + N = scaled_logits.shape[0] + softmax_logits = torch.softmax(scaled_logits, dim=-1) + grad = (2.0 / N) * lse.unsqueeze(-1) * softmax_logits * grad_scale + if logits_scale_factor != 1.0: + grad = grad * logits_scale_factor # Chain rule for logits_scale_factor + + return loss, grad diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 465984e01..f4c38abed 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -16,7 +16,7 @@ from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockDimNames -from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss +from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import ( LanguageModelEmbeddingsConfig, @@ -101,10 +101,9 @@ def __init__( self._formatted_loss_names = {} for registered_loss_name, loss_config in self._config.losses.items(): - if loss_config.weight > 0.0: - self._formatted_loss_names[registered_loss_name] = loss_config.get_formatted_name( - registered_loss_name, self._prediction_distance - ) + self._formatted_loss_names[registered_loss_name] = loss_config.get_formatted_name( + registered_loss_name, self._prediction_distance + ) def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None @@ -185,8 +184,6 @@ def _forward_backward( def _get_targets(self, kwargs: dict) -> dict | None: targets = {} for loss_config in self._config.losses.values(): - if loss_config.weight == 0.0: - continue loss_targets = loss_config.get_targets( kwargs, prediction_distance=self._prediction_distance, @@ -304,17 +301,17 @@ def _logits_loss_forward_backward( sequence_parallel=self._sequence_parallel and self._vocab_parallel, ) - # TODO: also move to lm_head_losses? - if self._config.logit_z_loss > 0.0: - logits = z_loss( - logits, - self._config.logit_z_loss, - self.training, - grad_output, - losses, - self._z_loss_name, - logits_scale_factor=self._config.logits_scale_factor, - ) + # # TODO: also move to lm_head_losses? 
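As a quick sanity check on the closed-form gradient introduced in the rewritten z_loss above, the following self-contained sketch compares it with autograd on random data. The shape is arbitrary and logits_scale_factor is left at its default, so only the core (2 / N) * logsumexp * softmax term is exercised:

    import torch

    logits = torch.randn(8, 32, dtype=torch.float64, requires_grad=True)

    # Forward pass as in the patch: mean of squared logsumexp.
    loss = torch.logsumexp(logits, dim=-1).pow(2).mean()
    loss.backward()

    # Closed form used in the patch: (2 / N) * lse * softmax(logits).
    with torch.no_grad():
        lse = torch.logsumexp(logits, dim=-1)
        expected = (2.0 / logits.shape[0]) * lse.unsqueeze(-1) * torch.softmax(logits, dim=-1)

    torch.testing.assert_close(logits.grad, expected)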
+ # if self._config.logit_z_loss > 0.0: + # logits = z_loss( + # logits, + # self._config.logit_z_loss, + # self.training, + # grad_output, + # losses, + # self._z_loss_name, + # logits_scale_factor=self._config.logits_scale_factor, + # ) sequence_dim = BlockDimNames.sequence_q_tp if self._sequence_parallel_logits else BlockDimNames.sequence_q if LanguageModelKwargs.hidden_dims in kwargs: @@ -333,8 +330,6 @@ def _logits_loss_forward_backward( total_loss, grad = None, None for loss_name, loss_config in self._config.losses.items(): - if loss_config.weight == 0.0: - continue # losses are returned unscaled but the grads are already scaled loss_unscaled_, grad_ = loss_config.get_loss( logits, @@ -349,6 +344,7 @@ def _logits_loss_forward_backward( vocab_parallel=self._vocab_parallel, kwargs={**kwargs, **targets}, ) + loss_ = loss_unscaled_ * loss_config.weight * self._loss_coefficient if losses is not None: @@ -393,10 +389,6 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: name=self._total_head_loss_name, formatted_name=_format_name(self._total_head_loss_name), count=count ) ] - if self._config.logit_z_loss > 0.0: - loss_defs.append( - LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) - ) for loss_name, loss_config in self._config.losses.items(): loss_def: LossDef = loss_config.get_loss_definitions( name=loss_name, count=count, prediction_distance=self._prediction_distance diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index e1004b5c8..327dee560 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -362,3 +362,39 @@ def get_loss( beta=self.beta, grad_output=grad_output, ) + + +@config_class(dynamic_type={LanguageModelLossConfig: "z_loss"}) +class ZLossConfig(LanguageModelLossConfig): + """Z-loss regularization to prevent overconfidence.""" + + _name: typing.ClassVar[str] = "Z_loss" + _abstract: typing.ClassVar[bool] = False + + def get_targets( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + return {} + + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + group: "ProcessGroup" = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + kwargs: dict | None = None, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + from fast_llm.layers.common.auxiliary_loss import z_loss + + return z_loss( + logits=logits.flatten(0, -2), + grad_scale=grad_output, + logits_scale_factor=logits_scale_factor, + ) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index f25aba1e7..9c81ba0a4 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -69,7 +69,6 @@ def _lm_head( logit_weight: torch.Tensor, grad_output: float = 1.0, logit_scale_factor: float = 1.0, - logit_z_loss=0.0, losses: dict[str, LanguageModelLossConfig], ): hidden = torch.rms_norm( @@ -102,12 +101,31 @@ def _lm_head( if logit_scale_factor != 1.0: logits *= logit_scale_factor - z_loss = torch.mean(torch.logsumexp(logits, dim=-1) ** 2) if logit_z_loss > 0 else None + + # Compute z_loss if configured + if "z_loss" in losses: + z_loss_unscaled = torch.mean(torch.logsumexp(logits, dim=-1) ** 2) + # Backward through z_loss (retain_graph 
since we need to also backward through ce_loss) + z_loss_unscaled.backward( + torch.full_like(z_loss_unscaled, grad_output * losses["z_loss"].weight), retain_graph=True + ) + z_loss_scaled = z_loss_unscaled * losses["z_loss"].weight + else: + z_loss_unscaled = None + z_loss_scaled = None + # Language model loss (cross-entropy with hard labels) - loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) - # Apply language_model_loss_factor - loss.backward(torch.full_like(loss, grad_output * losses["lm_loss"].weight)) - return loss * losses["lm_loss"].weight, z_loss + ce_loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) + # Backward through ce_loss + ce_loss.backward(torch.full_like(ce_loss, grad_output * losses["lm_loss"].weight)) + ce_loss_scaled = ce_loss * losses["lm_loss"].weight + + # Total loss = ce_loss + z_loss (both scaled) + total_loss = ce_loss_scaled + if z_loss_scaled is not None: + total_loss = total_loss + z_loss_scaled + + return total_loss, z_loss_unscaled SEQUENCE_LENGTH = 200 @@ -126,7 +144,21 @@ def _lm_head( ({}, {"compute_dtype": DataType.bfloat16}, False, 1), ({"embeddings": {"full_precision_residual": True}}, {"compute_dtype": DataType.bfloat16}, False, 1), ({"sequence_first": True}, {}, False, 1), - ({"head": {"logit_z_loss": 1e-3}}, {}, False, 1), + ( + { + "head": { + "losses": { + "z_loss": { + "type": "z_loss", + "weight": 1e-3, + }, + }, + } + }, + {}, + False, + 1, + ), ({"head": {"logits_scale_factor": 5.0}}, {}, False, 1), ({"tied_embedding_weight": True}, {}, False, 1), ({}, {}, False, 2), @@ -365,7 +397,6 @@ def test_lm_head( rms_weight=ref_rms_weight, logit_weight=ref_logit_weight, logit_scale_factor=head_config.logits_scale_factor, - logit_z_loss=head_config.logit_z_loss, losses=head_config.losses, ) @@ -386,8 +417,8 @@ def test_lm_head( formatted_name = loss_config.get_formatted_name(loss_name, prediction_distance) expected_loss_keys.add(formatted_name) - if ref_z_loss is not None: - expected_loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") + # if ref_z_loss is not None: + # expected_loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") Assert.eq( {loss_definition.name: loss_definition.count for loss_definition in head.get_loss_definitions()}, @@ -404,9 +435,9 @@ def test_lm_head( Assert.eq(losses.keys(), expected_loss_keys) Assert.eq(len(losses[lm_head_loss_name]), 1) - if ref_z_loss is not None: - Assert.eq(len(losses["z_loss"]), 1) - Assert.rms_close_relative(losses["z_loss"][0], ref_z_loss, threshold, min_threshold) + # if ref_z_loss is not None: + # Assert.eq(len(losses["z_loss"]), 1) + # Assert.rms_close_relative(losses["z_loss"][0], ref_z_loss, threshold, min_threshold) Assert.rms_close_relative(losses[lm_head_loss_name][0], ref_loss, threshold, min_threshold) From 0a66e145fe903f03ecf124e46ea70331a04cb8da Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 8 Jan 2026 22:03:07 +0000 Subject: [PATCH 25/29] removed z_loss as aux loss --- fast_llm/layers/language_model/head.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index f4c38abed..b3e0e47b6 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -301,18 +301,6 @@ def _logits_loss_forward_backward( sequence_parallel=self._sequence_parallel and self._vocab_parallel, ) - # # TODO: also move to lm_head_losses? 
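The reference path in the updated test backpropagates the z-loss and the cross-entropy term in two separate backward calls (the first with retain_graph=True) and relies on PyTorch accumulating into .grad. A minimal illustration of why that matches a single backward on the sum, using toy tensors rather than the test fixtures:

    import torch

    weight = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
    data = torch.randn(4, 3, dtype=torch.float64)

    def make_losses():
        logits = weight * data
        ce_like = logits.pow(2).mean()                            # stand-in for the CE term
        z_like = torch.logsumexp(logits, dim=-1).pow(2).mean()    # stand-in for the z-loss term
        return ce_like, z_like

    # Two backward calls accumulate gradients; retain_graph keeps the shared subgraph alive.
    ce_like, z_like = make_losses()
    z_like.backward(retain_graph=True)
    ce_like.backward()
    accumulated = weight.grad.clone()

    # Same result as a single backward on the sum.
    weight.grad = None
    ce_like, z_like = make_losses()
    (ce_like + z_like).backward()
    torch.testing.assert_close(accumulated, weight.grad)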
- # if self._config.logit_z_loss > 0.0: - # logits = z_loss( - # logits, - # self._config.logit_z_loss, - # self.training, - # grad_output, - # losses, - # self._z_loss_name, - # logits_scale_factor=self._config.logits_scale_factor, - # ) - sequence_dim = BlockDimNames.sequence_q_tp if self._sequence_parallel_logits else BlockDimNames.sequence_q if LanguageModelKwargs.hidden_dims in kwargs: batch_dim = kwargs[LanguageModelKwargs.hidden_dims][1 if kwargs[LanguageModelKwargs.sequence_first] else 0] From f8f70415b5a9c647359b8a9754aca5f13638a927 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 8 Jan 2026 22:14:50 +0000 Subject: [PATCH 26/29] move loss configs to the lm config --- fast_llm/layers/language_model/config.py | 392 ++++++++++++++++- fast_llm/layers/language_model/head.py | 2 +- .../layers/language_model/lm_head_losses.py | 400 ------------------ tests/layers/test_lm_head.py | 3 +- 4 files changed, 386 insertions(+), 411 deletions(-) delete mode 100644 fast_llm/layers/language_model/lm_head_losses.py diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index a74489005..adf8dd86e 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -3,30 +3,406 @@ import warnings from functools import cached_property -from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.base_model.config import LossDef +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig from fast_llm.layers.block.config import BlockConfig, BlockSequenceConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig -from fast_llm.layers.language_model.lm_head_losses import ( - CrossEntropyLMLossConfig, - DPOLossConfig, - ForwardKLLossConfig, - LanguageModelLossConfig, - ReverseKLLossConfig, -) +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs, TargetsKwargs from fast_llm.utils import Assert if typing.TYPE_CHECKING: + import torch + + from fast_llm.core.distributed import ProcessGroup from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead, LanguageModelHeadBase from fast_llm.layers.language_model.language_model import LanguageModel from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction +def _format_name(name: str) -> str: + return name.replace("_", " ") + + +@config_class(registry=True) +class LanguageModelLossConfig(Config): + """ + Losses can register themselves using @config_class(dynamic_type= {LanguageModelLossConfig: "loss_type_name"}). 
+ """ + + _name: typing.ClassVar[str] + _abstract: typing.ClassVar[bool] = True + + weight: float = Field( + default=1.0, + hint=FieldHint.core, + desc="Weight for this loss in the total loss computation.", + valid=check_field(Assert.geq, 0.0), + ) + + distillation_model: str | None = Field( + default=None, + desc="Name of the reference model to use for knowledge distillation." + "If provided, replace the loss with a distillation loss.", + hint=FieldHint.feature, + ) + + @abc.abstractmethod + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + group: "ProcessGroup" = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + kwargs: dict | None = None, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + pass + + def get_loss_definitions(self, name: str, count: int = 1, prediction_distance: int | None = None) -> LossDef: + name = self.get_formatted_name(name, prediction_distance) + return LossDef( + name=name, + formatted_name=_format_name(name), + count=count, + dtype=DataType.float32, + ) + + def _validate(self): + Assert.geq(self.weight, 0.0) + super()._validate() + + def get_formatted_name(self, registered_loss_name=None, prediction_distance: int | None = None) -> str: + """ + Returns loss name for logging as '()', + e.g. lm_loss(CE_loss), distillation(FwdKL_loss) + """ + name = f"{registered_loss_name}({self._name})" + if prediction_distance is not None: + name = f"{name}_{prediction_distance}" + return name + + @abc.abstractmethod + def get_targets( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + pass + + +@config_class(dynamic_type={LanguageModelLossConfig: "cross_entropy"}) +class CrossEntropyLMLossConfig(LanguageModelLossConfig): + _name: typing.ClassVar[str] = "CE_loss" + _abstract: typing.ClassVar[bool] = False + + implementation: CrossEntropyImpl = Field( + default=CrossEntropyImpl.auto, + desc="Implementation for the cross-entropy computation.", + hint=FieldHint.performance, + ) + + teacher_softmax_temperature: float = Field( + default=1.0, + hint=FieldHint.optional, + desc="Temperature for teacher softmax (used in distillation losses).", + valid=check_field(Assert.gt, 0.0), + ) + + def get_targets( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + if kwargs is None: + kwargs = {} + + lm_target = kwargs.get(LanguageModelKwargs.labels) + if lm_target is not None: + # MTP: Shift the labels + lm_target_sequence_length = ( + lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - prediction_heads + ) + if LanguageModelKwargs.sequence_q_dim in kwargs: + Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) + lm_target_slice = slice(prediction_distance, prediction_distance + lm_target_sequence_length) + lm_target = ( + lm_target[lm_target_slice] + if kwargs[LanguageModelKwargs.sequence_first] + else lm_target[:, lm_target_slice] + ).flatten() + if sequence_parallel_logits: + from fast_llm.core.ops import split_op + + lm_target = split_op(lm_target, group, 0) + return {TargetsKwargs.lm_target: lm_target} + + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor 
| None", + grad_output: float | None = None, + group: "ProcessGroup" = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + kwargs: dict | None = None, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + from fast_llm.functional.cross_entropy import cross_entropy_forward_backward + + target = kwargs.get(TargetsKwargs.lm_target) + implementation = self.implementation + if implementation == CrossEntropyImpl.auto: + if vocab_parallel: + implementation = CrossEntropyImpl.fused + elif TritonConfig.TRITON_ENABLED: + implementation = CrossEntropyImpl.triton + else: + implementation = CrossEntropyImpl.fused + + return cross_entropy_forward_backward( + logits=logits.flatten(0, -2), + target=target, + loss_mask=None, # Labels are already masked + grad_output=grad_output, + group=group, + implementation=implementation, + logits_scale_factor=logits_scale_factor, + teacher_softmax_temperature=self.teacher_softmax_temperature, + target_format=TargetFormat.labels, + ) + + +@config_class(dynamic_type={LanguageModelLossConfig: "forward_kl_distillation"}) +class ForwardKLLossConfig(LanguageModelLossConfig): + """Forward KL divergence KL(p||q) for distillation (mode-covering).""" + + _name: typing.ClassVar[str] = "FwdKL_loss" + _abstract: typing.ClassVar[bool] = False + + teacher_softmax_temperature: float = Field( + default=1.0, + hint=FieldHint.optional, + desc="Temperature for teacher softmax.", + valid=check_field(Assert.gt, 0.0), + ) + + def _validate(self): + assert self.distillation_model is not None, "Distillation loss required by ForwardKL Loss." + super()._validate() + + def get_targets( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + if kwargs is None: + kwargs = {} + + reference_model_logits = kwargs.get(f"{self.distillation_model}_logits") + if reference_model_logits is not None: + reference_model_logits = reference_model_logits.flatten(0, -2) + if sequence_parallel_logits: + from fast_llm.core.ops import split_op + + reference_model_logits = split_op(reference_model_logits, group, 0) + return {TargetsKwargs.reference_model_logits: reference_model_logits} + + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + group: "ProcessGroup" = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + kwargs: dict | None = None, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + from fast_llm.functional.cross_entropy import forward_kl_forward_backward + + target = kwargs.get(TargetsKwargs.reference_model_logits) + + return forward_kl_forward_backward( + logits=logits.flatten(0, -2), + target=target, + loss_mask=loss_mask, + grad_output=grad_output, + group=group, + logits_scale_factor=logits_scale_factor, + teacher_softmax_temperature=self.teacher_softmax_temperature, + target_format=TargetFormat.logits, + ) + + +@config_class(dynamic_type={LanguageModelLossConfig: "reverse_kl_distillation"}) +class ReverseKLLossConfig(ForwardKLLossConfig): + """Reverse KL divergence KL(q||p) for distillation (mode-seeking).""" + + _name: typing.ClassVar[str] = "RevKL_loss" + _abstract: typing.ClassVar[bool] = False + + def _validate(self): + assert self.distillation_model is not None, "Distillation loss required by Reverse KL Loss." 
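For orientation, a dense, unfused reference of the reverse KL objective that ReverseKLLossConfig delegates to reverse_kl_forward_backward. This sketch ignores the loss mask, tensor and sequence parallelism, logits_scale_factor and the fused/Triton code paths, and it assumes the teacher temperature only reshapes the teacher distribution:

    import torch

    def reverse_kl_reference(student_logits, teacher_logits, temperature=1.0):
        # KL(q || p) with q = student and p = teacher: mode-seeking, per the class docstring.
        log_q = torch.log_softmax(student_logits, dim=-1)
        log_p = torch.log_softmax(teacher_logits / temperature, dim=-1)
        return (log_q.exp() * (log_q - log_p)).sum(dim=-1).mean()

    student = torch.randn(5, 100)
    teacher = torch.randn(5, 100)
    print(reverse_kl_reference(student, teacher))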
+ super()._validate() + + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + group: "ProcessGroup" = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + kwargs: dict | None = None, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + from fast_llm.functional.cross_entropy import reverse_kl_forward_backward + + # Use distillation_target for KL losses + target = kwargs.get(TargetsKwargs.reference_model_logits) + + return reverse_kl_forward_backward( + logits=logits.flatten(0, -2), + target=target, + loss_mask=loss_mask, + grad_output=grad_output, + group=group, + logits_scale_factor=logits_scale_factor, + teacher_softmax_temperature=self.teacher_softmax_temperature, + target_format=TargetFormat.logits, + ) + + +@config_class(dynamic_type={LanguageModelLossConfig: "dpo"}) +class DPOLossConfig(LanguageModelLossConfig): + """Direct Preference Optimization (DPO) loss for alignment.""" + + _name: typing.ClassVar[str] = "DPO_loss" + _abstract: typing.ClassVar[bool] = False + + beta: float = Field( + default=1.0, + hint=FieldHint.core, + desc="Beta parameter for DPO loss (controls strength of preference optimization).", + valid=check_field(Assert.gt, 0.0), + ) + + dpo_reference_model: str | None = Field( + default=None, + desc="Name of the reference model to use for dpo.", + hint=FieldHint.feature, + ) + + def _validate(self): + assert self.dpo_reference_model is not None, "DPO loss requires a reference model." + super()._validate() + + def get_targets( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + if kwargs is None: + kwargs = {} + + reference_model_logits = kwargs.get(f"{self.dpo_reference_model}_logits") + dpo_target = kwargs.get(LanguageModelKwargs.labels) + if reference_model_logits is not None or dpo_target is not None: + from fast_llm.core.ops import split_op + + if reference_model_logits is not None: + reference_model_logits = reference_model_logits.flatten(0, -2) + if sequence_parallel_logits: + reference_model_logits = split_op(reference_model_logits, group, 0) + if dpo_target is not None: + dpo_target = split_op(dpo_target, group, 0) + return { + TargetsKwargs.dpo_reference_model_logits: reference_model_logits, + TargetsKwargs.dpo_target: dpo_target, + } + + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + group: "ProcessGroup" = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + kwargs: dict | None = None, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + from fast_llm.functional.dpo import compute_dpo_loss + + dpo_target = kwargs.get(TargetsKwargs.dpo_target) + dpo_reference_model_logits = kwargs.get(TargetsKwargs.dpo_reference_model_logits) + chosen_spans = kwargs.get(LanguageModelKwargs.chosen_spans) + rejected_spans = kwargs.get(LanguageModelKwargs.rejected_spans) + + return compute_dpo_loss( + logits=logits, + targets=dpo_target, + reference_model_logits=dpo_reference_model_logits, + chosen_spans=chosen_spans, + rejected_spans=rejected_spans, + beta=self.beta, + grad_output=grad_output, + ) + + +@config_class(dynamic_type={LanguageModelLossConfig: "z_loss"}) +class ZLossConfig(LanguageModelLossConfig): + """Z-loss regularization to prevent overconfidence.""" + + _name: 
typing.ClassVar[str] = "Z_loss" + _abstract: typing.ClassVar[bool] = False + + def get_targets( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + return {} + + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + group: "ProcessGroup" = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + kwargs: dict | None = None, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + from fast_llm.layers.common.auxiliary_loss import z_loss + + return z_loss( + logits=logits.flatten(0, -2), + grad_scale=grad_output, + logits_scale_factor=logits_scale_factor, + ) + + @config_class() class LanguageModelEmbeddingsConfig(BlockConfig): _abstract = False diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index b3e0e47b6..7f303684f 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -22,9 +22,9 @@ LanguageModelEmbeddingsConfig, LanguageModelHeadBaseConfig, LanguageModelHeadConfig, + _format_name, ) from fast_llm.layers.language_model.kwargs import LanguageModelKwargs -from fast_llm.layers.language_model.lm_head_losses import _format_name from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert, div, get_unique diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py deleted file mode 100644 index 327dee560..000000000 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ /dev/null @@ -1,400 +0,0 @@ -import abc -import logging -import typing - -from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.engine.base_model.config import LossDef -from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig -from fast_llm.layers.language_model.kwargs import LanguageModelKwargs, TargetsKwargs -from fast_llm.utils import Assert - -if typing.TYPE_CHECKING: - import torch - - from fast_llm.core.distributed import ProcessGroup - -logger = logging.getLogger(__name__) - -# -# CE loss on lm_targets for standard LM training. Here targets are already masked. -# CE loss for distillation: cross entropuy that uses reference_model_logits as soft targets, not implemented, TODO. -# Forward KL divergence loss on reference_model_logits for distillation (mode-covering). -# Reverse KL divergence loss on reference_model_logits for distillation (mode-seeking). -# DPO loss for alignment using chosen and rejected spans. -# - - -def _format_name(name: str) -> str: - return name.replace("_", " ") - - -@config_class(registry=True) -class LanguageModelLossConfig(Config): - """ - Losses can register themselves using @config_class(dynamic_type= {LanguageModelLossConfig: "loss_type_name"}). - """ - - _name: typing.ClassVar[str] - _abstract: typing.ClassVar[bool] = True - - weight: float = Field( - default=1.0, - hint=FieldHint.core, - desc="Weight for this loss in the total loss computation.", - valid=check_field(Assert.geq, 0.0), - ) - - distillation_model: str | None = Field( - default=None, - desc="Name of the reference model to use for knowledge distillation." 
- "If provided, replace the loss with a distillation loss.", - hint=FieldHint.feature, - ) - - @abc.abstractmethod - def get_loss( - self, - logits: "torch.Tensor", - loss_mask: "torch.Tensor | None", - grad_output: float | None = None, - group: "ProcessGroup" = None, - logits_scale_factor: float | None = None, - vocab_parallel: bool = False, - kwargs: dict | None = None, - ) -> "tuple[torch.Tensor, torch.Tensor | None]": - pass - - def get_loss_definitions(self, name: str, count: int = 1, prediction_distance: int | None = None) -> LossDef: - name = self.get_formatted_name(name, prediction_distance) - return LossDef( - name=name, - formatted_name=_format_name(name), - count=count, - dtype=DataType.float32, - ) - - def _validate(self): - Assert.geq(self.weight, 0.0) - super()._validate() - - def get_formatted_name(self, registered_loss_name=None, prediction_distance: int | None = None) -> str: - """ - Retruns loss name for logging as '()', e.g. lm_loss(CE_loss), distillation(FwdKL_loss) - """ - name = f"{registered_loss_name}({self._name})" - if prediction_distance is not None: - name = f"{name}_{prediction_distance}" - return name - - @abc.abstractmethod - def get_targets( - self, - kwargs: dict | None = None, - prediction_distance: int | None = None, - prediction_heads: int | None = None, - sequence_parallel_logits: bool | None = None, - group: "ProcessGroup" = None, - ) -> dict[str, "torch.Tensor"]: - pass - - -@config_class(dynamic_type={LanguageModelLossConfig: "cross_entropy"}) -class CrossEntropyLMLossConfig(LanguageModelLossConfig): - _name: typing.ClassVar[str] = "CE_loss" - _abstract: typing.ClassVar[bool] = False - - implementation: CrossEntropyImpl = Field( - default=CrossEntropyImpl.auto, - desc="Implementation for the cross-entropy computation.", - hint=FieldHint.performance, - ) - - teacher_softmax_temperature: float = Field( - default=1.0, - hint=FieldHint.optional, - desc="Temperature for teacher softmax (used in distillation losses).", - valid=check_field(Assert.gt, 0.0), - ) - - def get_targets( - self, - kwargs: dict | None = None, - prediction_distance: int | None = None, - prediction_heads: int | None = None, - sequence_parallel_logits: bool | None = None, - group: "ProcessGroup" = None, - ) -> dict[str, "torch.Tensor"]: - if kwargs is None: - kwargs = {} - - lm_target = kwargs.get(LanguageModelKwargs.labels) - if lm_target is not None: - # MTP: Shift the labels - lm_target_sequence_length = ( - lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - prediction_heads - ) - if LanguageModelKwargs.sequence_q_dim in kwargs: - Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) - lm_target_slice = slice(prediction_distance, prediction_distance + lm_target_sequence_length) - lm_target = ( - lm_target[lm_target_slice] - if kwargs[LanguageModelKwargs.sequence_first] - else lm_target[:, lm_target_slice] - ).flatten() - if sequence_parallel_logits: - from fast_llm.core.ops import split_op - - lm_target = split_op(lm_target, group, 0) - return {TargetsKwargs.lm_target: lm_target} - - def get_loss( - self, - logits: "torch.Tensor", - loss_mask: "torch.Tensor | None", - grad_output: float | None = None, - group: "ProcessGroup" = None, - logits_scale_factor: float | None = None, - vocab_parallel: bool = False, - kwargs: dict | None = None, - ) -> "tuple[torch.Tensor, torch.Tensor | None]": - from fast_llm.functional.cross_entropy import cross_entropy_forward_backward - - target = kwargs.get(TargetsKwargs.lm_target) - implementation = 
self.implementation - if implementation == CrossEntropyImpl.auto: - if vocab_parallel: - implementation = CrossEntropyImpl.fused - elif TritonConfig.TRITON_ENABLED: - implementation = CrossEntropyImpl.triton - else: - implementation = CrossEntropyImpl.fused - - return cross_entropy_forward_backward( - logits=logits.flatten(0, -2), - target=target, - loss_mask=None, # Labels are already masked - grad_output=grad_output, - group=group, - implementation=implementation, - logits_scale_factor=logits_scale_factor, - teacher_softmax_temperature=self.teacher_softmax_temperature, - target_format=TargetFormat.labels, - ) - - -@config_class(dynamic_type={LanguageModelLossConfig: "forward_kl_distillation"}) -class ForwardKLLossConfig(LanguageModelLossConfig): - """Forward KL divergence KL(p||q) for distillation (mode-covering).""" - - _name: typing.ClassVar[str] = "FwdKL_loss" - _abstract: typing.ClassVar[bool] = False - - teacher_softmax_temperature: float = Field( - default=1.0, - hint=FieldHint.optional, - desc="Temperature for teacher softmax.", - valid=check_field(Assert.gt, 0.0), - ) - - def _validate(self): - assert self.distillation_model is not None, "Distillation loss required by ForwardKL Loss." - super()._validate() - - def get_targets( - self, - kwargs: dict | None = None, - prediction_distance: int | None = None, - prediction_heads: int | None = None, - sequence_parallel_logits: bool | None = None, - group: "ProcessGroup" = None, - ) -> dict[str, "torch.Tensor"]: - if kwargs is None: - kwargs = {} - - reference_model_logits = kwargs.get(f"{self.distillation_model}_logits") - if reference_model_logits is not None: - reference_model_logits = reference_model_logits.flatten(0, -2) - if sequence_parallel_logits: - from fast_llm.core.ops import split_op - - reference_model_logits = split_op(reference_model_logits, group, 0) - return {TargetsKwargs.reference_model_logits: reference_model_logits} - - def get_loss( - self, - logits: "torch.Tensor", - loss_mask: "torch.Tensor | None", - grad_output: float | None = None, - group: "ProcessGroup" = None, - logits_scale_factor: float | None = None, - vocab_parallel: bool = False, - kwargs: dict | None = None, - ) -> "tuple[torch.Tensor, torch.Tensor | None]": - from fast_llm.functional.cross_entropy import forward_kl_forward_backward - - target = kwargs.get(TargetsKwargs.reference_model_logits) - - return forward_kl_forward_backward( - logits=logits.flatten(0, -2), - target=target, - loss_mask=loss_mask, - grad_output=grad_output, - group=group, - logits_scale_factor=logits_scale_factor, - teacher_softmax_temperature=self.teacher_softmax_temperature, - target_format=TargetFormat.logits, - ) - - -@config_class(dynamic_type={LanguageModelLossConfig: "reverse_kl_distillation"}) -class ReverseKLLossConfig(ForwardKLLossConfig): - """Reverse KL divergence KL(q||p) for distillation (mode-seeking).""" - - _name: typing.ClassVar[str] = "RevKL_loss" - _abstract: typing.ClassVar[bool] = False - - def _validate(self): - assert self.distillation_model is not None, "Distillation loss required by Reverse KL Loss." 
- super()._validate() - - def get_loss( - self, - logits: "torch.Tensor", - loss_mask: "torch.Tensor | None", - grad_output: float | None = None, - group: "ProcessGroup" = None, - logits_scale_factor: float | None = None, - vocab_parallel: bool = False, - kwargs: dict | None = None, - ) -> "tuple[torch.Tensor, torch.Tensor | None]": - from fast_llm.functional.cross_entropy import reverse_kl_forward_backward - - # Use distillation_target for KL losses - target = kwargs.get(TargetsKwargs.reference_model_logits) - - return reverse_kl_forward_backward( - logits=logits.flatten(0, -2), - target=target, - loss_mask=loss_mask, - grad_output=grad_output, - group=group, - logits_scale_factor=logits_scale_factor, - teacher_softmax_temperature=self.teacher_softmax_temperature, - target_format=TargetFormat.logits, - ) - - -@config_class(dynamic_type={LanguageModelLossConfig: "dpo"}) -class DPOLossConfig(LanguageModelLossConfig): - """Direct Preference Optimization (DPO) loss for alignment.""" - - _name: typing.ClassVar[str] = "DPO_loss" - _abstract: typing.ClassVar[bool] = False - - beta: float = Field( - default=1.0, - hint=FieldHint.core, - desc="Beta parameter for DPO loss (controls strength of preference optimization).", - valid=check_field(Assert.gt, 0.0), - ) - - dpo_reference_model: str | None = Field( - default=None, - desc="Name of the reference model to use for dpo.", - hint=FieldHint.feature, - ) - - def _validate(self): - assert self.dpo_reference_model is not None, "DPO loss requires a reference model." - super()._validate() - - def get_targets( - self, - kwargs: dict | None = None, - prediction_distance: int | None = None, - prediction_heads: int | None = None, - sequence_parallel_logits: bool | None = None, - group: "ProcessGroup" = None, - ) -> dict[str, "torch.Tensor"]: - if kwargs is None: - kwargs = {} - - reference_model_logits = kwargs.get(f"{self.dpo_reference_model}_logits") - dpo_target = kwargs.get(LanguageModelKwargs.labels) - if reference_model_logits is not None or dpo_target is not None: - from fast_llm.core.ops import split_op - - if reference_model_logits is not None: - reference_model_logits = reference_model_logits.flatten(0, -2) - if sequence_parallel_logits: - reference_model_logits = split_op(reference_model_logits, group, 0) - if dpo_target is not None: - dpo_target = split_op(dpo_target, group, 0) - return { - TargetsKwargs.dpo_reference_model_logits: reference_model_logits, - TargetsKwargs.dpo_target: dpo_target, - } - - def get_loss( - self, - logits: "torch.Tensor", - loss_mask: "torch.Tensor | None", - grad_output: float | None = None, - group: "ProcessGroup" = None, - logits_scale_factor: float | None = None, - vocab_parallel: bool = False, - kwargs: dict | None = None, - ) -> "tuple[torch.Tensor, torch.Tensor | None]": - from fast_llm.functional.dpo import compute_dpo_loss - - dpo_target = kwargs.get(TargetsKwargs.dpo_target) - dpo_reference_model_logits = kwargs.get(TargetsKwargs.dpo_reference_model_logits) - chosen_spans = kwargs.get(LanguageModelKwargs.chosen_spans) - rejected_spans = kwargs.get(LanguageModelKwargs.rejected_spans) - - return compute_dpo_loss( - logits=logits, - targets=dpo_target, - reference_model_logits=dpo_reference_model_logits, - chosen_spans=chosen_spans, - rejected_spans=rejected_spans, - beta=self.beta, - grad_output=grad_output, - ) - - -@config_class(dynamic_type={LanguageModelLossConfig: "z_loss"}) -class ZLossConfig(LanguageModelLossConfig): - """Z-loss regularization to prevent overconfidence.""" - - _name: 
typing.ClassVar[str] = "Z_loss" - _abstract: typing.ClassVar[bool] = False - - def get_targets( - self, - kwargs: dict | None = None, - prediction_distance: int | None = None, - prediction_heads: int | None = None, - sequence_parallel_logits: bool | None = None, - group: "ProcessGroup" = None, - ) -> dict[str, "torch.Tensor"]: - return {} - - def get_loss( - self, - logits: "torch.Tensor", - loss_mask: "torch.Tensor | None", - grad_output: float | None = None, - group: "ProcessGroup" = None, - logits_scale_factor: float | None = None, - vocab_parallel: bool = False, - kwargs: dict | None = None, - ) -> "tuple[torch.Tensor, torch.Tensor | None]": - from fast_llm.layers.common.auxiliary_loss import z_loss - - return z_loss( - logits=logits.flatten(0, -2), - grad_scale=grad_output, - logits_scale_factor=logits_scale_factor, - ) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 9c81ba0a4..aca378418 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -7,10 +7,9 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import CrossEntropyImpl from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.language_model.config import LanguageModelHeadConfig +from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelLossConfig from fast_llm.layers.language_model.head import LanguageModelHead from fast_llm.layers.language_model.kwargs import LanguageModelKwargs -from fast_llm.layers.language_model.lm_head_losses import LanguageModelLossConfig from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda From ab9c9176efae53d0c5d5c5db47b96804ffe1b4ba Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 8 Jan 2026 22:30:42 +0000 Subject: [PATCH 27/29] tests --- fast_llm/functional/cross_entropy.py | 4 ++-- tests/layers/test_lm_head.py | 23 ++++++++++++++++++++++- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 6b0a4e92f..6204ce316 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -98,10 +98,10 @@ def _fused_cross_entropy_forward_backward( logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) if target_format == TargetFormat.logits: - target_logits, exp_logits, sum_exp_target_logits = _fused_softmax_base( + target_logits, exp_logits_targets, sum_exp_target_logits = _fused_softmax_base( target, logits_scale_factor / teacher_softmax_temperature, group ) - target = exp_logits / sum_exp_target_logits + target = exp_logits_targets / sum_exp_target_logits if target_format == TargetFormat.labels: target = target.unsqueeze(-1) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index aca378418..6929784f5 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -204,6 +204,27 @@ def _lm_head( 1, id="track_lm_zero_factor", ), + pytest.param( + { + "head": { + "losses": { + "lm_loss": { + "type": "cross_entropy", + "weight": 0.0, + }, + "dist_loss": { + "type": "forward_kl_distillation", + "weight": 1.0, + "distillation_model": "distillation", + }, + }, + } + }, + {}, + False, + 1, + id="forward_kl_distillation", + ), pytest.param( { "head": { @@ -224,7 +245,7 @@ def _lm_head( False, 1, marks=pytest.mark.xfail( - reason="Cannot 
track both losses with zero factor", + reason="At least one loss has to have non-zero factor to track gradients", strict=True, ), id="track_both_zero_factors", From 6e54c93bace0c52837724c08d3f510118a31316b Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 12 Jan 2026 18:25:23 +0000 Subject: [PATCH 28/29] comments --- fast_llm/layers/language_model/config.py | 29 +++++++++++++++++++-- fast_llm/layers/language_model/embedding.py | 3 +-- fast_llm/layers/language_model/head.py | 2 +- fast_llm/layers/language_model/kwargs.py | 23 ---------------- fast_llm/models/gpt/model.py | 14 +++++----- fast_llm/models/multimodal/model.py | 2 +- tests/layers/test_lm_head.py | 3 +-- 7 files changed, 38 insertions(+), 38 deletions(-) delete mode 100644 fast_llm/layers/language_model/kwargs.py diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index adf8dd86e..ab8848d99 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -10,11 +10,10 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig -from fast_llm.layers.block.config import BlockConfig, BlockSequenceConfig +from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig -from fast_llm.layers.language_model.kwargs import LanguageModelKwargs, TargetsKwargs from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -27,6 +26,28 @@ from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction +class TargetsKwargs: + lm_target = "preprocessed_lm_target" + dpo_target = "preprocessed_dpo_target" + reference_model_logits = "reference_model_logits" + dpo_reference_model_logits = "dpo_reference_model_logits" + + +class LanguageModelKwargs(BlockKwargs): + token_ids = "token_ids" + position_ids = "position_ids" + token_map = "token_map" + sample_map = "sample_map" + embedding_map = "embedding_map" + # TODO: These are generic + labels = "labels" + phase = "phase" + chosen_spans = "chosen_spans" + rejected_spans = "rejected_spans" + loss_mask = "loss_mask" + mask_inputs = "mask_inputs" + + def _format_name(name: str) -> str: return name.replace("_", " ") @@ -610,6 +631,10 @@ def enable_dpo(self) -> bool: def enable_distillation(self) -> bool: return ForwardKLLossConfig in self._loss_configs.keys() or ReverseKLLossConfig in self._loss_configs.keys() + @property + def requires_loss_masks(self) -> bool: + return self.enable_distillation + @property def distillation_model(self) -> str | None: for loss_type in [ForwardKLLossConfig, ReverseKLLossConfig]: diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index fda5e3387..93850d24c 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -10,8 +10,7 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.layers.block.block import Block from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig -from fast_llm.layers.language_model.kwargs import LanguageModelKwargs +from 
fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, LanguageModelKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 7f303684f..2fa2dffe0 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -22,9 +22,9 @@ LanguageModelEmbeddingsConfig, LanguageModelHeadBaseConfig, LanguageModelHeadConfig, + LanguageModelKwargs, _format_name, ) -from fast_llm.layers.language_model.kwargs import LanguageModelKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert, div, get_unique diff --git a/fast_llm/layers/language_model/kwargs.py b/fast_llm/layers/language_model/kwargs.py deleted file mode 100644 index 4f6203881..000000000 --- a/fast_llm/layers/language_model/kwargs.py +++ /dev/null @@ -1,23 +0,0 @@ -from fast_llm.layers.block.config import BlockKwargs - - -class TargetsKwargs: - lm_target = "preprocessed_lm_target" - dpo_target = "preprocessed_dpo_target" - reference_model_logits = "reference_model_logits" - dpo_reference_model_logits = "dpo_reference_model_logits" - - -class LanguageModelKwargs(BlockKwargs): - token_ids = "token_ids" - position_ids = "position_ids" - token_map = "token_map" - sample_map = "sample_map" - embedding_map = "embedding_map" - # TODO: These are generic - labels = "labels" - phase = "phase" - chosen_spans = "chosen_spans" - rejected_spans = "rejected_spans" - loss_mask = "loss_mask" - mask_inputs = "mask_inputs" diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 846c65646..f83d12ca4 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -12,7 +12,7 @@ from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockDimNames, BlockKwargs -from fast_llm.layers.language_model.kwargs import LanguageModelKwargs +from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.language_model import LanguageModel from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron @@ -263,7 +263,6 @@ def preprocess_batch( if phase != PhaseType.inference: labels_begin = tokens_begin + 1 labels_end = tokens_end + self._config.head.max_prediction_distance - labels = batch.tokens.crop(labels_begin, labels_end).tokens if batch.loss_masking_spans is not None: @@ -272,13 +271,14 @@ def preprocess_batch( for sample_index, loss_masking_spans in enumerate(loss_masking_spans.ranges): for begin, end in loss_masking_spans: loss_mask[sample_index, begin:end] = False - if ( - self._config.head.distillation_model is not None - or self._config.decoder.block.distillation_model is not None - ): - kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) + if ( + self._config.head.requires_loss_masks is not None + ): # loss masks only used for distillation currently + # loss masks contain all three sources of masking: padding, user-defined spans, image placeholders + kwargs[LanguageModelKwargs.loss_mask] = labels >= 0 + kwargs[LanguageModelKwargs.labels] = ( labels.transpose(0, 1) if kwargs[AttentionKwargs.sequence_first] else labels ).contiguous() diff --git a/fast_llm/models/multimodal/model.py b/fast_llm/models/multimodal/model.py index 88da79e65..890d5760e 100644 
--- a/fast_llm/models/multimodal/model.py +++ b/fast_llm/models/multimodal/model.py @@ -10,7 +10,7 @@ from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockDimNames, BlockKwargs -from fast_llm.layers.language_model.kwargs import LanguageModelKwargs +from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.vision.config import VisionKwargs from fast_llm.layers.vision.vision_encoder import VisionMultiModalModel from fast_llm.models.gpt.config import GPTBatchConfig diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index c98c2780a..e01beb031 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -7,9 +7,8 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import CrossEntropyImpl from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelLossConfig +from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs, LanguageModelLossConfig from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.layers.language_model.kwargs import LanguageModelKwargs from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda From 3c8f3c265abc71bb216bdb5ce0b1004f36c888da Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 12 Jan 2026 22:26:38 -0500 Subject: [PATCH 29/29] misc --- fast_llm/functional/cross_entropy.py | 57 ++++++--------- fast_llm/layers/common/auxiliary_loss.py | 57 +++++++++------ .../layers/decoder/mlp/mixture_of_experts.py | 4 +- fast_llm/layers/language_model/config.py | 73 ++++++++++--------- tests/utils/model_configs.py | 7 +- 5 files changed, 97 insertions(+), 101 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 9c4b7fcfc..c21b49a6c 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -75,7 +75,7 @@ def _fused_softmax( return exp_logits / sum_exp_logits -# @torch.compile +@torch.compile def _fused_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, @@ -85,7 +85,7 @@ def _fused_cross_entropy_forward_backward( target_format: TargetFormat, group: ProcessGroup | None = None, teacher_softmax_temperature: float = 1.0, - return_target_entropy: bool = False, + return_kl_loss: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A fused implementation of cross-entropy with torch compile. @@ -108,14 +108,14 @@ def _fused_cross_entropy_forward_backward( loss_mask = target >= 0 if group is None: # Keep values within range for scatter and gather ops to work. - target = target * loss_mask + target_masked = target * loss_mask target_mask = None else: # Mask the target (fused) # TODO: Could mask earlier on cpu or overlap with reduce? vocab_start_index = logits.size(-1) * group.rank() target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1)) - target = (target - vocab_start_index) * target_mask + target_masked = (target - vocab_start_index) * target_mask else: # Target should be tensor-parallel already, no further manipulation needed. 
target_mask = None @@ -128,10 +128,10 @@ def _fused_cross_entropy_forward_backward( # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. if target_format == TargetFormat.labels: grad_base = exp_logits.scatter_add( - 1, target, -sum_exp_logits if target_mask is None else -(target_mask * sum_exp_logits) + 1, target_masked, -sum_exp_logits if target_mask is None else -(target_mask * sum_exp_logits) ) else: - grad_base = exp_logits - sum_exp_logits * target + grad_base = exp_logits - sum_exp_logits * target_masked grad = grad_base.mul((grad_output / logits.size(0)) / sum_exp_logits) if logits_scale_factor != 1.0: @@ -142,13 +142,13 @@ def _fused_cross_entropy_forward_backward( # loss = mean(log(sum_exp_logits) - sum(probabilities * logits)) if target_format == TargetFormat.labels: - predicted_logits = logits_norm.gather(1, target) + predicted_logits = logits_norm.gather(1, target_masked) if group is not None: predicted_logits = target_mask * predicted_logits all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) else: - predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True) + predicted_logits = (target_masked * logits_norm).sum(dim=-1, keepdim=True) if group is not None and target_format != TargetFormat.labels: # this is needed because on each rank we calculate log Z - sum_i t_i * z_i, where z_i is logit. # Then we average on line 160: 1/K sum_ranks (log Z - sum_i t_i * z_i) @@ -162,7 +162,7 @@ def _fused_cross_entropy_forward_backward( loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: all_reduce(loss, op=ReduceOp.AVG, group=group) - if return_target_entropy: + if return_kl_loss: if target_format == TargetFormat.logits: teacher_log_prob = target_logits - sum_exp_target_logits.log() else: @@ -173,7 +173,7 @@ def _fused_cross_entropy_forward_backward( target_entropy = target_entropy.mean() if group is not None: all_reduce(target_entropy, op=ReduceOp.SUM, group=group) - return loss, grad, target_entropy + loss -= target_entropy return loss, grad @@ -249,10 +249,7 @@ def _reverse_kl_forward_backward( target: torch.Tensor, loss_mask: torch.Tensor | None, grad_output: float | None, - target_format: TargetFormat, group: ProcessGroup | None = None, - logits_scale_factor: float = 1.0, - teacher_softmax_temperature: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Reverse KL using PyTorch's native kl_div function. @@ -264,13 +261,6 @@ def _reverse_kl_forward_backward( loss_mask: [BxS] or [B, S] or None ... """ - Assert.eq( - teacher_softmax_temperature, - 1, - msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel reverse KL", - ) - Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel reverse KL") - Assert.eq(target.shape, logits.shape) assert target.dtype.is_floating_point, target.dtype if loss_mask is not None: Assert.eq(loss_mask.shape, logits.shape[:-1]) @@ -326,7 +316,6 @@ def reverse_kl_forward_backward( logits_scale_factor: float = 1.0, teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, - sequence_parallel_logits: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher). @@ -349,12 +338,13 @@ def reverse_kl_forward_backward( loss: Reverse KL divergence loss grad: Gradients w.r.t. 
logits """ - - if sequence_parallel_logits: - # TODO: see hybrid dev branch where it is implemented - raise NotImplementedError("Sequence-parallel reverse KL is not implemented yet, set vocab_parallel true") - Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") + Assert.eq( + teacher_softmax_temperature, + 1, + msg="Teacher softmax temperature must be 1 for reverse KL", + ) + Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for reverse KL") Assert.eq(target.shape, logits.shape) assert target.dtype.is_floating_point, target.dtype if loss_mask is not None: @@ -366,9 +356,6 @@ def reverse_kl_forward_backward( target=target, loss_mask=loss_mask, grad_output=grad_output, - logits_scale_factor=logits_scale_factor, - target_format=target_format, - teacher_softmax_temperature=teacher_softmax_temperature, group=group, ) return distillation_loss, distillation_grad @@ -383,7 +370,6 @@ def forward_kl_forward_backward( logits_scale_factor: float = 1.0, teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, - sequence_parallel_logits: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute forward KL divergence: KL(p||q) where p is the target distribution (teacher) and q is the predicted (student). @@ -408,7 +394,11 @@ def forward_kl_forward_backward( """ assert target_format == TargetFormat.logits, "Forward KL only supports logits format" Assert.eq(target.shape, logits.shape) - distillation_loss, distillation_grad, teacher_entropy = _fused_cross_entropy_forward_backward( + assert target.dtype.is_floating_point, target.dtype + if loss_mask is not None: + Assert.eq(loss_mask.shape, logits.shape[:-1]) + + return _fused_cross_entropy_forward_backward( logits=logits, target=target, loss_mask=loss_mask, @@ -417,8 +407,5 @@ def forward_kl_forward_backward( target_format=target_format, group=group, teacher_softmax_temperature=teacher_softmax_temperature, - return_target_entropy=True, + return_kl_loss=True, ) - distillation_loss -= teacher_entropy - - return distillation_loss, distillation_grad diff --git a/fast_llm/layers/common/auxiliary_loss.py b/fast_llm/layers/common/auxiliary_loss.py index 335debb12..1c8fe1c73 100644 --- a/fast_llm/layers/common/auxiliary_loss.py +++ b/fast_llm/layers/common/auxiliary_loss.py @@ -3,9 +3,9 @@ class AuxiliaryLoss(torch.autograd.Function): @staticmethod - def forward(ctx, scores: torch.Tensor, aux_loss: torch.Tensor, grad: float) -> torch.Tensor: # noqa + def forward(ctx, input_: torch.Tensor, aux_loss: torch.Tensor, grad: float) -> torch.Tensor: # noqa ctx.grad = torch.full_like(aux_loss, grad) - return scores + return input_ @staticmethod def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, ...]: # noqa @@ -14,14 +14,33 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, ...]: @torch.compile def calculate_z_loss(logits: torch.Tensor, logits_scale_factor: float = 1.0) -> torch.Tensor: - if logits_scale_factor != 1.0: - logits *= logits_scale_factor - return torch.mean(torch.logsumexp(logits, dim=-1) ** 2) + return torch.mean( + torch.logsumexp(logits if logits_scale_factor == 1.0 else logits * logits_scale_factor, dim=-1) ** 2 + ) -def z_loss( +def auxiliary_z_loss( logits: torch.Tensor, + z_loss_factor: float, + training: bool, grad_scale: float | None = None, + losses: dict | None = None, + loss_name: str | None = None, + logits_scale_factor: float = 1.0, +) -> torch.Tensor: + if losses is not None or (training 
and grad_scale is not None): + loss = calculate_z_loss(logits, logits_scale_factor=logits_scale_factor) + if losses is not None and loss_name is not None: + losses[loss_name].append(loss.detach()) + if training and grad_scale is not None: + logits = AuxiliaryLoss.apply(logits, loss, z_loss_factor * grad_scale) + + return logits + + +def z_loss_forward_backward( + logits: torch.Tensor, + grad_output: float | None = None, logits_scale_factor: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ @@ -33,22 +52,14 @@ def z_loss( loss: The z-loss value (unscaled) grad: The gradient w.r.t. logits (scaled by grad_scale), or None if grad_scale is None """ - if logits_scale_factor != 1.0: - scaled_logits = logits * logits_scale_factor - else: - scaled_logits = logits - - # Forward: z_loss = mean(logsumexp^2) - lse = torch.logsumexp(scaled_logits, dim=-1) # (N,) - loss = torch.mean(lse**2) - - # Backward: grad = (2/N) * lse * softmax(scaled_logits) - grad = None - if grad_scale is not None: - N = scaled_logits.shape[0] - softmax_logits = torch.softmax(scaled_logits, dim=-1) - grad = (2.0 / N) * lse.unsqueeze(-1) * softmax_logits * grad_scale - if logits_scale_factor != 1.0: - grad = grad * logits_scale_factor # Chain rule for logits_scale_factor + + with torch.set_grad_enabled(grad_output is not None): + logits_ = logits.detach().requires_grad_(grad_output is not None) + loss = calculate_z_loss(logits, logits_scale_factor) + if grad_output is None: + grad = None + else: + loss.backward(torch.full_like(loss, grad_output)) + grad = logits_.grad.detach().to(logits.dtype) return loss, grad diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index 5cc351dac..fd3647389 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -13,7 +13,7 @@ from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockKwargs -from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss +from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, auxiliary_z_loss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.mlp.config import MLPLossNames, MoEMLPConfig, RoutingType from fast_llm.layers.decoder.mlp.mlp import MLPBase @@ -102,7 +102,7 @@ def _forward( # Apply z_loss if applicable if self._config.z_loss_coefficient > 0.0: - logits = z_loss( + logits = auxiliary_z_loss( logits, self._config.z_loss_coefficient, self.training, diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index ab8848d99..e3de9e9cb 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -81,7 +81,7 @@ def get_loss( logits: "torch.Tensor", loss_mask: "torch.Tensor | None", grad_output: float | None = None, - group: "ProcessGroup" = None, + group: "ProcessGroup|None" = None, logits_scale_factor: float | None = None, vocab_parallel: bool = False, kwargs: dict | None = None, @@ -118,13 +118,13 @@ def get_targets( prediction_distance: int | None = None, prediction_heads: int | None = None, sequence_parallel_logits: bool | None = None, - group: "ProcessGroup" = None, + group: "ProcessGroup|None" = None, ) -> dict[str, "torch.Tensor"]: pass @config_class(dynamic_type={LanguageModelLossConfig: "cross_entropy"}) -class CrossEntropyLMLossConfig(LanguageModelLossConfig): 
+class CrossEntropyLanguageModelLossConfig(LanguageModelLossConfig): _name: typing.ClassVar[str] = "CE_loss" _abstract: typing.ClassVar[bool] = False @@ -134,10 +134,10 @@ class CrossEntropyLMLossConfig(LanguageModelLossConfig): hint=FieldHint.performance, ) - teacher_softmax_temperature: float = Field( + temperature: float = Field( default=1.0, hint=FieldHint.optional, - desc="Temperature for teacher softmax (used in distillation losses).", + desc="Temperature for teacher softmax.", valid=check_field(Assert.gt, 0.0), ) @@ -147,7 +147,7 @@ def get_targets( prediction_distance: int | None = None, prediction_heads: int | None = None, sequence_parallel_logits: bool | None = None, - group: "ProcessGroup" = None, + group: "ProcessGroup|None" = None, ) -> dict[str, "torch.Tensor"]: if kwargs is None: kwargs = {} @@ -202,19 +202,19 @@ def get_loss( group=group, implementation=implementation, logits_scale_factor=logits_scale_factor, - teacher_softmax_temperature=self.teacher_softmax_temperature, + teacher_softmax_temperature=self.temperature, target_format=TargetFormat.labels, ) @config_class(dynamic_type={LanguageModelLossConfig: "forward_kl_distillation"}) -class ForwardKLLossConfig(LanguageModelLossConfig): +class ForwardKLDistillationLossConfig(LanguageModelLossConfig): """Forward KL divergence KL(p||q) for distillation (mode-covering).""" _name: typing.ClassVar[str] = "FwdKL_loss" _abstract: typing.ClassVar[bool] = False - teacher_softmax_temperature: float = Field( + temperature: float = Field( default=1.0, hint=FieldHint.optional, desc="Temperature for teacher softmax.", @@ -231,7 +231,7 @@ def get_targets( prediction_distance: int | None = None, prediction_heads: int | None = None, sequence_parallel_logits: bool | None = None, - group: "ProcessGroup" = None, + group: "ProcessGroup|None" = None, ) -> dict[str, "torch.Tensor"]: if kwargs is None: kwargs = {} @@ -250,7 +250,7 @@ def get_loss( logits: "torch.Tensor", loss_mask: "torch.Tensor | None", grad_output: float | None = None, - group: "ProcessGroup" = None, + group: "ProcessGroup|None" = None, logits_scale_factor: float | None = None, vocab_parallel: bool = False, kwargs: dict | None = None, @@ -266,13 +266,13 @@ def get_loss( grad_output=grad_output, group=group, logits_scale_factor=logits_scale_factor, - teacher_softmax_temperature=self.teacher_softmax_temperature, + teacher_softmax_temperature=self.temperature, target_format=TargetFormat.logits, ) @config_class(dynamic_type={LanguageModelLossConfig: "reverse_kl_distillation"}) -class ReverseKLLossConfig(ForwardKLLossConfig): +class ReverseKLLossConfig(ForwardKLDistillationLossConfig): """Reverse KL divergence KL(q||p) for distillation (mode-seeking).""" _name: typing.ClassVar[str] = "RevKL_loss" @@ -287,7 +287,7 @@ def get_loss( logits: "torch.Tensor", loss_mask: "torch.Tensor | None", grad_output: float | None = None, - group: "ProcessGroup" = None, + group: "ProcessGroup|None" = None, logits_scale_factor: float | None = None, vocab_parallel: bool = False, kwargs: dict | None = None, @@ -304,7 +304,7 @@ def get_loss( grad_output=grad_output, group=group, logits_scale_factor=logits_scale_factor, - teacher_softmax_temperature=self.teacher_softmax_temperature, + teacher_softmax_temperature=self.temperature, target_format=TargetFormat.logits, ) @@ -339,7 +339,7 @@ def get_targets( prediction_distance: int | None = None, prediction_heads: int | None = None, sequence_parallel_logits: bool | None = None, - group: "ProcessGroup" = None, + group: "ProcessGroup|None" = None, ) -> dict[str, 
"torch.Tensor"]: if kwargs is None: kwargs = {} @@ -365,7 +365,7 @@ def get_loss( logits: "torch.Tensor", loss_mask: "torch.Tensor | None", grad_output: float | None = None, - group: "ProcessGroup" = None, + group: "ProcessGroup|None" = None, logits_scale_factor: float | None = None, vocab_parallel: bool = False, kwargs: dict | None = None, @@ -401,7 +401,7 @@ def get_targets( prediction_distance: int | None = None, prediction_heads: int | None = None, sequence_parallel_logits: bool | None = None, - group: "ProcessGroup" = None, + group: "ProcessGroup|None" = None, ) -> dict[str, "torch.Tensor"]: return {} @@ -410,16 +410,20 @@ def get_loss( logits: "torch.Tensor", loss_mask: "torch.Tensor | None", grad_output: float | None = None, - group: "ProcessGroup" = None, + group: "ProcessGroup|None" = None, logits_scale_factor: float | None = None, vocab_parallel: bool = False, kwargs: dict | None = None, ) -> "tuple[torch.Tensor, torch.Tensor | None]": - from fast_llm.layers.common.auxiliary_loss import z_loss + from fast_llm.layers.common.auxiliary_loss import z_loss_forward_backward + + # TODO: ====== Support loss mask, vocab_parallel ====== + assert loss_mask is None + assert group is None - return z_loss( + return z_loss_forward_backward( logits=logits.flatten(0, -2), - grad_scale=grad_output, + grad_output=grad_output, logits_scale_factor=logits_scale_factor, ) @@ -549,13 +553,6 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - logit_z_loss: float = Field( - default=0.0, - desc="Regularize the logits with Z-loss.", - doc="We recommend 1e-4 for stability, as used for training PaLM.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) def get_layer( self, @@ -604,14 +601,17 @@ def _validate(self) -> None: with self._set_implicit_default(): if not self.losses: if "losses" not in self._explicit_fields: - self.losses = {"lm_loss": CrossEntropyLMLossConfig()} + self.losses = {"lm_loss": CrossEntropyLanguageModelLossConfig()} super()._validate() if DPOLossConfig in self._loss_configs: - assert ForwardKLLossConfig not in self._loss_configs.keys() # currently don't support both + assert ForwardKLDistillationLossConfig not in self._loss_configs.keys() # currently don't support both assert ReverseKLLossConfig not in self._loss_configs.keys() # currently don't support both - if ForwardKLLossConfig in self._loss_configs.keys() and ReverseKLLossConfig in self._loss_configs.keys(): + if ( + ForwardKLDistillationLossConfig in self._loss_configs.keys() + and ReverseKLLossConfig in self._loss_configs.keys() + ): assert ( - self._loss_configs[ForwardKLLossConfig].distillation_model + self._loss_configs[ForwardKLDistillationLossConfig].distillation_model == self._loss_configs[ReverseKLLossConfig].distillation_model ), "Distillation losses must use the same teacher." 
@@ -629,7 +629,10 @@ def enable_dpo(self) -> bool: @property def enable_distillation(self) -> bool: - return ForwardKLLossConfig in self._loss_configs.keys() or ReverseKLLossConfig in self._loss_configs.keys() + return ( + ForwardKLDistillationLossConfig in self._loss_configs.keys() + or ReverseKLLossConfig in self._loss_configs.keys() + ) @property def requires_loss_masks(self) -> bool: @@ -637,7 +640,7 @@ def requires_loss_masks(self) -> bool: @property def distillation_model(self) -> str | None: - for loss_type in [ForwardKLLossConfig, ReverseKLLossConfig]: + for loss_type in [ForwardKLDistillationLossConfig, ReverseKLLossConfig]: if loss_type in self._loss_configs: return self._loss_configs[loss_type].distillation_model return None diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 8d705583d..d18ce934e 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -246,12 +246,7 @@ def update_and_add_testing_config( }, "num_blocks": 2, }, - "head": { - "output_weight": init_1, - "losses": { - "lm_loss": {"type": "cross_entropy"}, - }, - }, + "head": {"output_weight": init_1}, "hidden_size": 256, "tied_embedding_weight": True, },
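Appendix (editor's sketch, not part of the patches): the series above replaces the single distillation code path with registered loss configs ("cross_entropy", "forward_kl_distillation", "reverse_kl_distillation", "dpo", "z_loss"), each returning a fused (loss, grad) pair from fast_llm/functional/cross_entropy.py. As a plain reference for the two KL directions, forward KL(p||q) being mode-covering and reverse KL(q||p) mode-seeking, the snippet below recomputes both with torch.nn.functional.kl_div on dense log-probabilities. It is an unfused, single-GPU approximation for sanity-checking values only: the function names, the [tokens, vocab] shapes, and the omission of loss masking and the tensor-parallel group argument are all simplifications, and temperature is applied to the teacher only, matching the config fields; the fused reverse-KL path additionally asserts temperature == 1, so no temperature is exposed there.

import torch
import torch.nn.functional as F


def forward_kl(student_logits: torch.Tensor, teacher_logits: torch.Tensor,
               temperature: float = 1.0) -> torch.Tensor:
    """KL(p_teacher || q_student): mode-covering distillation objective."""
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1)
    # kl_div(input, target, log_target=True) computes KL(target || exp(input)),
    # so the teacher distribution goes in the target slot.
    return F.kl_div(student_log_probs, teacher_log_probs,
                    log_target=True, reduction="batchmean")


def reverse_kl(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    """KL(q_student || p_teacher): mode-seeking distillation objective.

    No temperature argument here, mirroring the assertion in the fused path
    that the teacher softmax temperature must be 1 for reverse KL.
    """
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
    # Swap the roles: the student distribution is now the KL "target".
    return F.kl_div(teacher_log_probs, student_log_probs,
                    log_target=True, reduction="batchmean")


if __name__ == "__main__":
    # Illustrative shapes: 8 tokens, vocabulary of 32.
    student = torch.randn(8, 32)
    teacher = torch.randn(8, 32)
    print(forward_kl(student, teacher).item(), reverse_kl(student, teacher).item())

Because reverse KL weights the divergence by the student distribution, the student is pushed to concentrate mass on teacher modes rather than cover the full teacher support, which is why the two directions are offered as separate registered losses rather than a single knob.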