Refactor lm_head losses #425
Changes from 29 commits
Diff in the first changed file (the fused cross-entropy and distillation loss functions):
@@ -85,6 +85,7 @@ def _fused_cross_entropy_forward_backward(
    target_format: TargetFormat,
    group: ProcessGroup | None = None,
    teacher_softmax_temperature: float = 1.0,
    return_target_entropy: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """
    A fused implementation of cross-entropy with torch compile.

@@ -158,6 +159,16 @@ def _fused_cross_entropy_forward_backward(
    loss = per_sample_loss.mean()
    if target_format != TargetFormat.labels and group is not None:
        all_reduce(loss, op=ReduceOp.AVG, group=group)
    if return_target_entropy and target_format == TargetFormat.logits:
        # Compute teacher entropy
        teacher_log_prob = torch.log(target + 1e-20)
        target_entropy = -(target * teacher_log_prob).sum(dim=-1)
        if loss_mask is not None:
            target_entropy = target_entropy * loss_mask.squeeze(-1)
        target_entropy = target_entropy.mean()
        if group is not None:
            all_reduce(target_entropy, op=ReduceOp.SUM, group=group)
        return loss, grad, target_entropy

    return loss, grad
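For reference, the identity behind the new branch (and which `forward_kl_forward_backward` below relies on when it subtracts `teacher_entropy` from the cross-entropy loss): for a teacher distribution $p$ and a student distribution $q$ over the vocabulary,

$$
\mathrm{KL}(p \,\|\, q) = \sum_v p_v \log \frac{p_v}{q_v}
= \underbrace{-\sum_v p_v \log q_v}_{\text{cross-entropy } \mathrm{CE}(p,\,q)} - \underbrace{\left(-\sum_v p_v \log p_v\right)}_{\text{teacher entropy } H(p)} .
$$

Since $H(p)$ does not depend on the student logits, subtracting it shifts only the reported loss value, not the gradient, which is why the existing `grad` can be returned unchanged.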
@@ -236,7 +247,6 @@ def _reverse_kl_forward_backward(
    group: ProcessGroup | None = None,
    logits_scale_factor: float = 1.0,
    teacher_softmax_temperature: float = 1.0,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """
    Reverse KL using PyTorch's native kl_div function.

@@ -359,3 +369,53 @@ def reverse_kl_forward_backward(
        group=group,
    )
    return distillation_loss, distillation_grad


def forward_kl_forward_backward(
    logits: torch.Tensor,
    target: torch.Tensor,
    loss_mask: torch.Tensor | None,
    grad_output: float | None,
    group: ProcessGroup | None = None,
    logits_scale_factor: float = 1.0,
    teacher_softmax_temperature: float = 1.0,
    target_format: TargetFormat = TargetFormat.labels,
    sequence_parallel_logits: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """
    Compute forward KL divergence: KL(p||q) where p is the target distribution (teacher) and q is the predicted (student).
    This is mode-covering (vs. mode-seeking for reverse KL) and useful for:
    - Encouraging the model to cover all modes of the target distribution
    - Spreading probability mass broadly across the target support
    - Standard distillation scenarios where you want to match the full teacher distribution

    Key differences from reverse KL:
    - Forward KL: KL(p||q) = mode-covering (spreads mass broadly)
    - Reverse KL: KL(q||p) = mode-seeking (focuses on target modes)

    Takes:
        logits: [BxS, V] or [B, S, V], where V is local vocab size
        target: [BxS, V] or [B, S, V] (logits format)
        loss_mask: [BxS] or [B, S] or None
        ...

    Returns:
        loss: Forward KL divergence loss
        grad: Gradients w.r.t. logits
    """
    assert target_format == TargetFormat.logits, "Forward KL only supports logits format"
    Assert.eq(target.shape, logits.shape)
    distillation_loss, distillation_grad, teacher_entropy = _fused_cross_entropy_forward_backward(
        logits=logits,
        target=target,
        loss_mask=loss_mask,
        grad_output=grad_output,
        logits_scale_factor=logits_scale_factor,
        target_format=target_format,
        group=group,
        teacher_softmax_temperature=teacher_softmax_temperature,
        return_target_entropy=True,
    )
    distillation_loss -= teacher_entropy

    return distillation_loss, distillation_grad
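A quick standalone sanity check of this decomposition in plain PyTorch (illustrative only, not the fused Fast-LLM kernel; the toy shapes are assumptions, and the `1e-20` epsilon mirrors the hunk above):

```python
import torch
import torch.nn.functional as F

# Toy shapes: batch 2, sequence length 3, vocab 5 (illustrative only).
student_logits = torch.randn(2, 3, 5)
teacher_logits = torch.randn(2, 3, 5)

p = torch.softmax(teacher_logits, dim=-1)          # teacher distribution
log_q = torch.log_softmax(student_logits, dim=-1)  # student log-probabilities

# Soft-target cross-entropy, as the fused kernel computes it.
ce = -(p * log_q).sum(dim=-1).mean()
# Teacher entropy, the extra quantity exposed by return_target_entropy.
h_p = -(p * torch.log(p + 1e-20)).sum(dim=-1).mean()

# Forward KL(p || q) computed directly; F.kl_div expects log-probabilities as input.
kl = F.kl_div(log_q, p, reduction="none").sum(dim=-1).mean()

# CE(p, q) - H(p) == KL(p || q), matching `distillation_loss -= teacher_entropy` above.
torch.testing.assert_close(ce - h_p, kl, rtol=1e-4, atol=1e-5)
```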
Diff in the second changed file (the language-model head configuration):
@@ -5,11 +5,11 @@
from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales
from fast_llm.engine.config_utils.tensor_dim import TensorDim
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl
from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig
from fast_llm.layers.block.config import BlockConfig, BlockSequenceConfig
from fast_llm.layers.common.normalization.config import NormalizationConfig
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.layers.decoder.config import DecoderBlockConfig
from fast_llm.layers.language_model.lm_head_losses import LanguageModelLossConfig
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:

@@ -19,21 +19,6 @@
    from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction


class LanguageModelKwargs(BlockKwargs):
    token_ids = "token_ids"
    position_ids = "position_ids"
    token_map = "token_map"
    sample_map = "sample_map"
    embedding_map = "embedding_map"
    # TODO: These are generic
    labels = "labels"
    phase = "phase"
    chosen_spans = "chosen_spans"
    rejected_spans = "rejected_spans"
    loss_mask = "loss_mask"
    mask_inputs = "mask_inputs"


@config_class()
class LanguageModelEmbeddingsConfig(BlockConfig):
    _abstract = False

@@ -135,44 +120,22 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig):
        desc="Configuration for the final normalization layer.",
        hint=FieldHint.architecture,
    )
    losses: dict[str, LanguageModelLossConfig] = Field(
        default_factory=dict,
        desc="A dictionary of loss names and their configurations.",
        hint=FieldHint.core,
    )
    # TODO: Cleanup
    output_weight: ParameterConfig = Field(
        desc="Configuration for the LM output layer (weight). Ignored for tied embeddings",
        hint=FieldHint.architecture,
    )
    cross_entropy_implementation: CrossEntropyImpl = Field(
Collaborator
These removals are likely to cause backward compatibility issues when loading existing models. Please make sure it doesn't disrupt ongoing work, and if needed add backward compatibility in …

Contributor (Author)
I tested training with checkpoints created on the main branch in both …
        default=CrossEntropyImpl.auto,
        desc="Implementation for the cross-entropy computation.",
        hint=FieldHint.performance,
    )
    distillation_loss_implementation: DistillationLossImpl = Field(
        default=DistillationLossImpl.cross_entropy,
        desc="Implementation for the distillation cross-entropy computation.",
        hint=FieldHint.performance,
    )
    cross_entropy_splits: int | None = Field(
        default=None,
        desc="Split the logit and cross-entropy computation into this many fragment, to reduce memory usage.",
        hint=FieldHint.feature,
        valid=skip_valid_if_none(check_field(Assert.gt, 0)),
    )
    logit_z_loss: float = Field(
        default=0.0,
        desc="Regularize the logits with Z-loss.",
        doc="We recommend 1e-4 for stability, as used for training PaLM.",
        hint=FieldHint.feature,
        valid=check_field(Assert.geq, 0),
    )
    language_model_loss_factor: float = Field(
        default=None,
        desc="Factor to scale the language modeling loss by when using distillation.",
        hint=FieldHint.feature,
    )
    distillation_loss_factor: float = Field(
        default=1.0,
        desc="Factor to scale the distillation loss by when using distillation.",
        hint=FieldHint.feature,
    )
    logits_scale_factor: float = Field(
        default=1.0,
        desc="Multiply output logits by scale factor.",

@@ -181,10 +144,10 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig):
        hint=FieldHint.feature,
        valid=check_field(Assert.geq, 0),
    )
    teacher_softmax_temperature: float = Field(
        default=1.0,
        desc="Divides distillation target logits by this factor.",
        doc="Divides distillation target logits by this factor.",
    logit_z_loss: float = Field(
        default=0.0,
        desc="Regularize the logits with Z-loss.",
        doc="We recommend 1e-4 for stability, as used for training PaLM.",
        hint=FieldHint.feature,
        valid=check_field(Assert.geq, 0),
    )

@@ -193,11 +156,6 @@
        desc="Name of the reference model to use for dpo.",
        hint=FieldHint.feature,
    )
    dpo_beta: float | None = Field(
        default=1.0,
        desc="Beta value for DPO loss.",
        hint=FieldHint.feature,
    )
    distillation_model: str | None = Field(
        default=None,
        desc="Name of the reference model to use for knowledge distillation."

@@ -237,11 +195,19 @@ def layer_class(self) -> "type[LanguageModelHead]":

    def _validate(self) -> None:
        with self._set_implicit_default():
            if self.language_model_loss_factor is None:
                if self.distillation_model is None:
                    self.language_model_loss_factor = 1.0
                else:
                    self.language_model_loss_factor = 0.0
            if not self.losses:
            if "losses" not in self._explicit_fields:
Collaborator
Not sure it's needed, it doesn't make sense to have a head without loss. Can simplify to …
                self.losses = {
                    "lm_loss": LanguageModelLossConfig._from_dict(
                        {
                            "type": "cross_entropy",
                            "weight": 1.0,
                        }
                    )
                }
        for loss_config in self.losses.values():
            if "distillation" in loss_config.type:
                assert self.distillation_model is not None, "Distillation loss requires a distillation model."

        super()._validate()
        assert self.dpo_reference_model is None or self.distillation_model is None  # currently don't support both
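Under the new schema, the per-head losses live in a single `losses` dictionary instead of the removed scalar factors. A hypothetical config fragment for combined language-modeling and distillation losses (only the `cross_entropy` type name and the `weight` key are confirmed by the default above; the distillation type name and the field values are guesses):

```python
# Hypothetical sketch of a new-style head configuration (not taken from the PR).
head_config = {
    "losses": {
        # Plain language-modeling loss; mirrors the default injected in _validate above.
        "lm_loss": {"type": "cross_entropy", "weight": 1.0},
        # Guessed type name: any loss whose type contains "distillation" requires
        # distillation_model to be set, per the assertion in _validate.
        "distillation_loss": {"type": "distillation_cross_entropy", "weight": 1.0},
    },
    "distillation_model": "teacher",
}
```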
Review thread on the new `return_target_entropy` flag:

`return_kl_loss` instead?

`_fused_cross_entropy_forward_backward`, as the name implies, should not return a KL loss. I find this more explicit.

KL and cross-entropy are basically the same thing, and this is a private method anyway so the name is not that important. I'm more worried about the inconsistent return type and extra complexity this is creating.