diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py
index 5df59d4cd..106bea217 100644
--- a/fast_llm/engine/base_model/base_model.py
+++ b/fast_llm/engine/base_model/base_model.py
@@ -145,6 +145,7 @@ def preprocess_batch(
         phase: PhaseType,
         iteration: int,
         metrics: dict | None = None,
+        setup_activation_storage: bool = False,
     ) -> list[tuple[torch.Tensor, dict]]:
         # TODO Move batch splitting elsewhere, align interface with LayerBase
         pass
diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py
index 531bc206b..867cca984 100644
--- a/fast_llm/engine/training/config.py
+++ b/fast_llm/engine/training/config.py
@@ -361,7 +361,6 @@ def _validate(self) -> None:
         # TODO: Add support.
         Assert.eq(self.model.distributed.pipeline_parallel, 1)
         # TODO: Check if these work.
-        Assert.eq(self.model.distributed.tensor_parallel, 1)
         Assert.eq(self.model.distributed.sequence_data_parallel, 1)
         if self.run.experiment_dir is None:
             assert not self.training.checkpoint.enabled()
diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py
index d56dce98d..c22319c17 100644
--- a/fast_llm/functional/cross_entropy.py
+++ b/fast_llm/functional/cross_entropy.py
@@ -151,7 +151,7 @@ def _fused_cross_entropy_forward_backward(
         loss = per_sample_loss.mean()

         if target_format != TargetFormat.labels and group is not None:
-            all_reduce(loss, op=ReduceOp.MEAN, group=group)
+            all_reduce(loss, op=ReduceOp.AVG, group=group)

    return loss, grad

@@ -277,7 +277,7 @@ def _torch_reverse_kl_forward_backward(
        loss = (loss_per_sample * loss_mask).mean()

    if group is not None and target_format != TargetFormat.labels:
-        all_reduce(loss, op=ReduceOp.MEAN, group=group)
+        all_reduce(loss, op=ReduceOp.AVG, group=group)

    if grad_output is not None:
        loss.backward(torch.full_like(loss, grad_output))
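A note on the `ReduceOp.MEAN` to `ReduceOp.AVG` change above: the loss is already a per-rank mean, so reducing it across the group has to average rather than sum. The sketch below illustrates the intended semantics with plain `torch.distributed`; it assumes Fast-LLM's `all_reduce`/`ReduceOp` wrappers map onto the PyTorch primitives, and `average_loss_across_group` is a hypothetical helper, not part of this PR.

```python
import torch
import torch.distributed as dist


def average_loss_across_group(loss: torch.Tensor, group: dist.ProcessGroup | None) -> torch.Tensor:
    """Average a scalar loss over a process group (e.g. the tensor-parallel group)."""
    if group is None:
        return loss
    if dist.get_backend(group) == "nccl":
        # NCCL supports a true average reduction.
        dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=group)
    else:
        # Equivalent fallback: sum, then divide by the group size.
        dist.all_reduce(loss, op=dist.ReduceOp.SUM, group=group)
        loss /= dist.get_world_size(group)
    return loss
```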
diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py
index f3e93edeb..dfc80a470 100644
--- a/fast_llm/layers/block/config.py
+++ b/fast_llm/layers/block/config.py
@@ -37,6 +37,9 @@ class BlockKwargs:
     sequence_lengths = "sequence_lengths"
     # TODO: Belongs elsewhere?
     grad_output = "grad_output"
+    activation_distillation_storage = "activation_distillation_storage"
+    activation_distillation_targets = "activation_distillation_targets"
+    activation_distillation_total = "activation_distillation_total"


 @config_class(registry=True)
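The three new `BlockKwargs` keys define the hand-off between the teacher (reference) model and the student: the teacher fills a storage dict keyed by module name, and the student consumes it as targets. (`activation_distillation_total` is added here but not yet used elsewhere in this diff.) Below is a toy sketch of that flow with plain dicts and a dummy tensor; the module name `"decoder.0"` and the variable names are illustrative only.

```python
import torch

# Teacher side: preprocess_batch seeds an empty storage dict
# (setup_activation_storage=True) and each block records its mixer output.
teacher_kwargs = {"activation_distillation_storage": {}}
teacher_mixer_output = torch.randn(2, 8, 16)  # (batch, sequence, hidden), dummy data
teacher_kwargs["activation_distillation_storage"]["decoder.0"] = teacher_mixer_output.detach()

# Student side: the stored activations are handed over as targets, keyed by module
# name, and each block pops its own entry to compute the activation-level loss.
student_kwargs = {
    "activation_distillation_targets": teacher_kwargs.pop("activation_distillation_storage")
}
target = student_kwargs["activation_distillation_targets"].pop("decoder.0")
assert target.shape == (2, 8, 16)
```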
diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py
index 8b19db66a..e9f27f404 100644
--- a/fast_llm/layers/decoder/block.py
+++ b/fast_llm/layers/decoder/block.py
@@ -11,9 +11,12 @@
 from fast_llm.engine.distributed.distributed import Distributed
 from fast_llm.layers.block.block import Block
 from fast_llm.layers.block.config import BlockKwargs
+from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss
 from fast_llm.layers.common.peft.config import PeftConfig
 from fast_llm.layers.decoder.config import BlockWithBiasConfig, DecoderBlockConfig
+from fast_llm.layers.language_model.head import _format_name
 from fast_llm.tensor import TensorMeta
+from fast_llm.utils import Assert

 logger = logging.getLogger(__name__)

@@ -136,6 +139,9 @@ def forward(
         if self._debug.enabled:
             self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs)
         hidden_states, bias = self.mixer(hidden_states, kwargs)
+
+        # hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs, losses)
+
         if self._debug.enabled:
             self._debug(
                 hidden_states if bias is None else hidden_states + bias,
@@ -150,6 +156,7 @@
             hidden_states = self.norm_2(input_)
         if self._debug.enabled:
             self._debug(hidden_states, "norm 2", kwargs[BlockKwargs.hidden_dims], kwargs)
+        hidden_states, _ = self.activation_distillation_loss(hidden_states, None, kwargs, losses)
         hidden_states, bias = self.mlp(hidden_states, kwargs, losses, metrics)
         if self._debug.enabled:
             self._debug(
@@ -166,6 +173,42 @@
             hidden_states = torch.stack((fw_input, hidden_states), dim=0)
         return hidden_states

+    def activation_distillation_loss(self, hidden_states, bias, kwargs, losses):
+        """
+        Maybe apply the activation distillation loss and set up backward hooks.
+        """
+        mixer_output = hidden_states if bias is None else hidden_states + bias
+        # Teacher populates mixer activations for distillation.
+        activation_storage = kwargs.get(BlockKwargs.activation_distillation_storage)
+        if activation_storage is not None:
+            activation_storage[self.module_name] = mixer_output.detach()
+        # Student gets teacher activations and computes the activation-level loss.
+        activation_targets = kwargs.get(BlockKwargs.activation_distillation_targets)
+        if (
+            activation_targets is not None
+            and self.training
+            and (teacher_output := activation_targets.pop(self.module_name, None)) is not None
+        ):
+            # Compare student mixer output with the teacher's stored activation and accumulate the loss.
+            teacher_tensor = teacher_output.detach().to(device=mixer_output.device, dtype=mixer_output.dtype)
+            Assert.eq(teacher_tensor.shape, mixer_output.shape)
+            # TODO: handle sequence-first?
+            # TODO: un-scaled loss for reporting? Average loss over layers?
+            # L2 loss
+            activation_loss_factor = self._config.activation_distillation_factor
+            # (batch, sequence, hidden). Take the norm over hidden dim.
+            # TODO: handle possible padding?
+            activation_loss = activation_loss_factor * torch.mean(
+                torch.norm(mixer_output - teacher_tensor, p=2, dim=2)
+            )
+            # Backward hooks
+            hidden_states = AuxiliaryLoss.apply(hidden_states, activation_loss, 1.0)
+            bias = AuxiliaryLoss.apply(bias, activation_loss, 1.0) if bias is not None else None
+            # Logging
+            if losses is not None and self._activation_distillation_loss_name in losses:
+                losses[self._activation_distillation_loss_name].append(activation_loss.detach())
+        return hidden_states, bias
+
     def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
         # TODO: Add marginal compute? (normalization, bias_dropout_add)
         return sum(
@@ -179,5 +222,21 @@ def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None
         self.mixer.preprocess(batch, kwargs)
         self.mlp.preprocess(batch, kwargs)

+    # TODO: add layer_index
+    _activation_distillation_loss_name = "activation_distillation_loss"
+
     def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
-        return self.mixer.get_loss_definitions(count=count) + self.mlp.get_loss_definitions(count=count)
+        loss_definitions = []
+        if self._config.activation_distillation_factor > 0.0 and self._config.distillation_model is not None:
+            loss_definitions.append(
+                LossDef(
+                    name=self._activation_distillation_loss_name,
+                    formatted_name=_format_name(self._activation_distillation_loss_name),
+                    count=count,
+                )
+            )
+        return (
+            loss_definitions
+            + self.mixer.get_loss_definitions(count=count)
+            + self.mlp.get_loss_definitions(count=count)
+        )
diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py
index 403b204c8..99331ee7a 100644
--- a/fast_llm/layers/decoder/config.py
+++ b/fast_llm/layers/decoder/config.py
@@ -88,6 +88,22 @@ class DecoderBlockConfig(BlockConfig):
         hint=FieldHint.feature,
         valid=check_field(Assert.geq, 0),
     )
+    distillation_model: str | None = Field(
+        default=None,
+        desc="Name of the reference model to use for activation-level distillation.",
+        hint=FieldHint.feature,
+    )
+    activation_distillation_factor: float = Field(
+        default=0.0,
+        desc="Factor to scale the activation-level distillation loss by.",
+        hint=FieldHint.feature,
+        valid=check_field(Assert.geq, 0),
+    )
+
+    def _validate(self) -> None:
+        super()._validate()
+        if self.activation_distillation_factor > 0.0 and self.distillation_model is None:
+            raise ValueError("Activation distillation requires a distillation_model.")

     @property
     def layer_class(self) -> "type[DecoderBlock]":
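For reference, the activation term added in `DecoderBlock.activation_distillation_loss` reduces to a scaled L2 distance between student and teacher mixer outputs, taken over the hidden dimension and averaged over batch and sequence. A standalone sketch, assuming a `(batch, sequence, hidden)` layout; `activation_l2_loss` is a hypothetical helper, not part of this PR, and `factor` plays the role of `activation_distillation_factor`.

```python
import torch


def activation_l2_loss(student: torch.Tensor, teacher: torch.Tensor, factor: float) -> torch.Tensor:
    """Scaled L2 distance over the hidden dimension, averaged over batch and sequence."""
    teacher = teacher.detach().to(device=student.device, dtype=student.dtype)
    assert teacher.shape == student.shape
    return factor * torch.linalg.vector_norm(student - teacher, ord=2, dim=-1).mean()


student = torch.randn(2, 8, 16, requires_grad=True)  # (batch, sequence, hidden)
teacher = torch.randn(2, 8, 16)
loss = activation_l2_loss(student, teacher, factor=0.1)
loss.backward()  # gradients flow only into the student activations
```

Note that, as in the diff, the norm is not normalized by the hidden size, so the raw loss magnitude grows roughly with the square root of the hidden dimension.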
diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py
index efa348ecb..17187d0b2 100644
--- a/fast_llm/models/gpt/model.py
+++ b/fast_llm/models/gpt/model.py
@@ -10,7 +10,7 @@
 from fast_llm.engine.inference.runner import InferenceRunner
 from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
 from fast_llm.layers.attention.config import AttentionKwargs
-from fast_llm.layers.block.config import BlockDimNames
+from fast_llm.layers.block.config import BlockDimNames, BlockKwargs
 from fast_llm.layers.language_model.config import LanguageModelKwargs
 from fast_llm.layers.language_model.language_model import LanguageModel
 from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig
@@ -157,6 +157,7 @@ def preprocess_batch(
         phase: PhaseType,
         iteration: int,
         metrics: dict | None = None,
+        setup_activation_storage: bool = False,
     ) -> list[tuple[torch.Tensor, dict]]:
         # TODO Move batch splitting elsewhere, align interface with LayerBase
         assert self._is_setup
@@ -175,6 +176,10 @@ def preprocess_batch(
                     non_blocking=True,
                 )

+        # TODO: decoder doesn't necessarily have a `block` attribute
+        distillation_model = self._config.decoder.block.distillation_model
+        activation_factor = self._config.decoder.block.activation_distillation_factor
+        reference_logits: list[dict[str, typing.Any]] | None = None
         reference_logits = [{} for _ in preprocessed_meta]
         for name, reference_model in self._reference_models.items():
             reference_preprocessed_meta = [
@@ -182,7 +187,11 @@
             ]

             reference_batch = reference_model.fast_llm_model.base_model.preprocess_batch(
-                batch, reference_preprocessed_meta, phase=PhaseType.inference, iteration=iteration
+                batch,
+                reference_preprocessed_meta,
+                phase=PhaseType.inference,
+                iteration=iteration,
+                setup_activation_storage=activation_factor > 0.0 and distillation_model == name,
             )

             # TODO: Do things work with >1?
@@ -190,6 +199,11 @@
             for i, (reference_tokens, reference_kwargs) in enumerate(reference_batch):
                 reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration)
                 reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"]
+                if BlockKwargs.activation_distillation_storage in reference_kwargs:
+                    reference_logits[i][f"{name}_activations"] = reference_kwargs[
+                        BlockKwargs.activation_distillation_storage
+                    ]
+                    del reference_kwargs[BlockKwargs.activation_distillation_storage]

         token_ids = batch.token_ids
         if sequence_first:
@@ -255,7 +269,13 @@
                 kwargs[LanguageModelKwargs.loss_mask] = loss_mask
                 labels = torch.where(loss_mask, labels, -100)
             kwargs[LanguageModelKwargs.labels] = labels
-            kwargs.update(reference_logits[i])
+            if reference_logits is not None:
+                reference_payload = reference_logits[i]
+                kwargs.update(reference_payload)
+                if distillation_model is not None and activation_factor > 0.0:
+                    teacher_key = f"{distillation_model}_activations"
+                    if teacher_key in reference_payload:
+                        kwargs[BlockKwargs.activation_distillation_targets] = reference_payload.pop(teacher_key)

             if batch.chosen_spans is not None:
                 chosen_valid_spans = []
@@ -288,6 +308,8 @@
                     rejected_valid_spans.append(valid_spans)
                 kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans

+            if setup_activation_storage:
+                kwargs.setdefault(BlockKwargs.activation_distillation_storage, {})
             self.preprocess(tokens, kwargs)
             preprocessed.append((tokens, kwargs))
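Finally, the gradient path: the block attaches the activation loss to `hidden_states` (and `bias`) through `AuxiliaryLoss.apply` instead of returning it, so the term reaches the optimizer only via backward. The toy autograd sketch below shows that pattern under the assumption that Fast-LLM's `AuxiliaryLoss` follows the usual auxiliary-loss trick (identity in forward, a constant gradient injected for the loss in backward, scaled by the third argument); it is not the library's actual implementation.

```python
import torch


class AuxiliaryLossSketch(torch.autograd.Function):
    """Identity in forward; in backward, inject the auxiliary loss gradient scaled by `scale`."""

    @staticmethod
    def forward(ctx, hidden_states: torch.Tensor, loss: torch.Tensor, scale: float) -> torch.Tensor:
        ctx.save_for_backward(loss)
        ctx.scale = scale
        return hidden_states

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (loss,) = ctx.saved_tensors
        # The main path passes its gradient through unchanged; the auxiliary loss
        # receives a constant gradient of `scale`, so its own graph is optimized too.
        return grad_output, torch.full_like(loss, ctx.scale), None


hidden = torch.randn(2, 4, requires_grad=True)
aux_loss = (hidden ** 2).mean()  # stand-in for the activation distillation term
out = AuxiliaryLossSketch.apply(hidden, aux_loss, 1.0)
out.sum().backward()  # hidden.grad now combines the main and auxiliary contributions
```

In the diff the scale argument is fixed at 1.0, with the weighting already folded into `activation_loss` via `activation_distillation_factor`.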