1 change: 1 addition & 0 deletions fast_llm/engine/base_model/base_model.py
@@ -145,6 +145,7 @@ def preprocess_batch(
phase: PhaseType,
iteration: int,
metrics: dict | None = None,
setup_activation_storage: bool = False,
) -> list[tuple[torch.Tensor, dict]]:
# TODO Move batch splitting elsewhere, align interface with LayerBase
pass
1 change: 0 additions & 1 deletion fast_llm/engine/training/config.py
@@ -361,7 +361,6 @@ def _validate(self) -> None:
# TODO: Add support.
Assert.eq(self.model.distributed.pipeline_parallel, 1)
# TODO: Check if these work.
Assert.eq(self.model.distributed.tensor_parallel, 1)
Assert.eq(self.model.distributed.sequence_data_parallel, 1)
if self.run.experiment_dir is None:
assert not self.training.checkpoint.enabled()
4 changes: 2 additions & 2 deletions fast_llm/functional/cross_entropy.py
@@ -151,7 +151,7 @@ def _fused_cross_entropy_forward_backward(

loss = per_sample_loss.mean()
if target_format != TargetFormat.labels and group is not None:
all_reduce(loss, op=ReduceOp.MEAN, group=group)
all_reduce(loss, op=ReduceOp.AVG, group=group)

return loss, grad

@@ -277,7 +277,7 @@ def _torch_reverse_kl_forward_backward(
loss = (loss_per_sample * loss_mask).mean()

if group is not None and target_format != TargetFormat.labels:
all_reduce(loss, op=ReduceOp.MEAN, group=group)
all_reduce(loss, op=ReduceOp.AVG, group=group)

if grad_output is not None:
loss.backward(torch.full_like(loss, grad_output))
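A note on the MEAN to AVG switch: if `ReduceOp` here is torch.distributed's enum, it exposes `AVG` (supported by backends such as NCCL) but has no `MEAN` member, so averaging a loss across a data-parallel group either uses `AVG` directly or falls back to `SUM` plus a division by the group size. A minimal sketch of the portable fallback, assuming an initialized process group; the helper name is illustrative and not part of Fast-LLM:

```python
import torch
import torch.distributed as dist

def average_over_group(loss: torch.Tensor, group: dist.ProcessGroup | None) -> torch.Tensor:
    # Portable equivalent of all_reduce(loss, op=ReduceOp.AVG, group=group):
    # AVG is backend-dependent, while SUM followed by a division works everywhere.
    if group is not None:
        dist.all_reduce(loss, op=dist.ReduceOp.SUM, group=group)
        loss = loss / dist.get_world_size(group)
    return loss
```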
3 changes: 3 additions & 0 deletions fast_llm/layers/block/config.py
@@ -37,6 +37,9 @@ class BlockKwargs:
sequence_lengths = "sequence_lengths"
# TODO: Belongs elsewhere?
grad_output = "grad_output"
activation_distillation_storage = "activation_distillation_storage"
activation_distillation_targets = "activation_distillation_targets"
activation_distillation_total = "activation_distillation_total"


@config_class(registry=True)
61 changes: 60 additions & 1 deletion fast_llm/layers/decoder/block.py
@@ -11,9 +11,12 @@
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.layers.block.block import Block
from fast_llm.layers.block.config import BlockKwargs
from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.layers.decoder.config import BlockWithBiasConfig, DecoderBlockConfig
from fast_llm.layers.language_model.head import _format_name
from fast_llm.tensor import TensorMeta
from fast_llm.utils import Assert

logger = logging.getLogger(__name__)

@@ -136,6 +139,9 @@ def forward(
if self._debug.enabled:
self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs)
hidden_states, bias = self.mixer(hidden_states, kwargs)

# hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs, losses)

if self._debug.enabled:
self._debug(
hidden_states if bias is None else hidden_states + bias,
@@ -150,6 +156,7 @@
hidden_states = self.norm_2(input_)
if self._debug.enabled:
self._debug(hidden_states, "norm 2", kwargs[BlockKwargs.hidden_dims], kwargs)
hidden_states, _ = self.activation_distillation_loss(hidden_states, None, kwargs, losses)
hidden_states, bias = self.mlp(hidden_states, kwargs, losses, metrics)
if self._debug.enabled:
self._debug(
@@ -166,6 +173,42 @@
hidden_states = torch.stack((fw_input, hidden_states), dim=0)
return hidden_states

def activation_distillation_loss(self, hidden_states, bias, kwargs, losses):
"""
        Optionally apply the activation distillation loss and set up backward hooks.
"""
mixer_output = hidden_states if bias is None else hidden_states + bias
# Teacher populates mixer activations for distillation.
activation_storage = kwargs.get(BlockKwargs.activation_distillation_storage)
if activation_storage is not None:
activation_storage[self.module_name] = mixer_output.detach()
# Student gets teacher activations and computes the activation-level loss.
activation_targets = kwargs.get(BlockKwargs.activation_distillation_targets)
if (
activation_targets is not None
and self.training
and (teacher_output := activation_targets.pop(self.module_name, None)) is not None
):
# Compare student mixer output with the teacher’s stored activation and accumulate the loss.
teacher_tensor = teacher_output.detach().to(device=mixer_output.device, dtype=mixer_output.dtype)
Assert.eq(teacher_tensor.shape, mixer_output.shape)
# TODO: handle sequence-first?
# TODO: un-scaled loss for reporting? Average loss over layers?
# L2 loss
activation_loss_factor = self._config.activation_distillation_factor
# (batch, sequence, hidden). Take the norm over hidden dim.
# TODO: handle possible padding?
activation_loss = activation_loss_factor * torch.mean(
torch.norm(mixer_output - teacher_tensor, p=2, dim=(2))
)
# Backward hooks
hidden_states = AuxiliaryLoss.apply(hidden_states, activation_loss, 1.0)
bias = AuxiliaryLoss.apply(bias, activation_loss, 1.0) if bias is not None else None
# Logging
if losses is not None and self._activation_distillation_loss_name in losses:
losses[self._activation_distillation_loss_name].append(activation_loss.detach())
return hidden_states, bias

def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
# TODO: Add marginal compute? (normalization, bias_dropout_add)
return sum(
@@ -179,5 +222,21 @@ def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None
self.mixer.preprocess(batch, kwargs)
self.mlp.preprocess(batch, kwargs)

# TODO: add layer_index
_activation_distillation_loss_name = "activation_distillation_loss"

def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
return self.mixer.get_loss_definitions(count=count) + self.mlp.get_loss_definitions(count=count)
loss_definitions = []
if self._config.activation_distillation_factor > 0.0 and self._config.distillation_model is not None:
loss_definitions.append(
LossDef(
name=self._activation_distillation_loss_name,
formatted_name=_format_name(self._activation_distillation_loss_name),
count=count,
)
)
return (
loss_definitions
+ self.mixer.get_loss_definitions(count=count)
+ self.mlp.get_loss_definitions(count=count)
)
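The distillation term added above is a plain L2 penalty between student and teacher activations, scaled by `activation_distillation_factor`; `AuxiliaryLoss.apply` then attaches it to the autograd graph without changing the forward value. A minimal, framework-free sketch of the same loss math, assuming (batch, sequence, hidden) shapes as in the comments above; the function name is illustrative:

```python
import torch

def activation_distillation_l2(
    student: torch.Tensor,  # (batch, sequence, hidden) output of the student block
    teacher: torch.Tensor,  # same shape, stored teacher activation
    factor: float,
) -> torch.Tensor:
    # L2 norm over the hidden dimension, then a mean over batch and sequence positions,
    # mirroring the torch.mean(torch.norm(..., p=2, dim=2)) expression in the block.
    per_position = torch.linalg.vector_norm(student - teacher.detach(), ord=2, dim=-1)
    return factor * per_position.mean()
```

With the factor set to 0 the term vanishes, which matches the config default introduced below.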
16 changes: 16 additions & 0 deletions fast_llm/layers/decoder/config.py
@@ -88,6 +88,22 @@ class DecoderBlockConfig(BlockConfig):
hint=FieldHint.feature,
valid=check_field(Assert.geq, 0),
)
distillation_model: str | None = Field(
default=None,
desc="Name of the reference model to use for activation-level distillation.",
hint=FieldHint.feature,
)
activation_distillation_factor: float = Field(
default=0.0,
desc="Factor to scale the activation-level distillation loss by.",
hint=FieldHint.feature,
valid=check_field(Assert.geq, 0),
)

def _validate(self) -> None:
super()._validate()
if self.activation_distillation_factor > 0.0 and self.distillation_model is None:
raise ValueError("Activation distillation requires a distillation_model.")

@property
def layer_class(self) -> "type[DecoderBlock]":
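A self-contained sketch of the new validation rule, using a plain dataclass rather than Fast-LLM's `config_class` machinery; the field names mirror the diff, but the class itself is illustrative only:

```python
from dataclasses import dataclass

@dataclass
class ActivationDistillationSettings:
    # Stand-ins for the two new DecoderBlockConfig fields.
    distillation_model: str | None = None
    activation_distillation_factor: float = 0.0

    def validate(self) -> None:
        if self.activation_distillation_factor < 0.0:
            raise ValueError("activation_distillation_factor must be non-negative.")
        if self.activation_distillation_factor > 0.0 and self.distillation_model is None:
            raise ValueError("Activation distillation requires a distillation_model.")

# ActivationDistillationSettings(activation_distillation_factor=0.1).validate()  # raises ValueError
```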
28 changes: 25 additions & 3 deletions fast_llm/models/gpt/model.py
@@ -10,7 +10,7 @@
from fast_llm.engine.inference.runner import InferenceRunner
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
from fast_llm.layers.attention.config import AttentionKwargs
from fast_llm.layers.block.config import BlockDimNames
from fast_llm.layers.block.config import BlockDimNames, BlockKwargs
from fast_llm.layers.language_model.config import LanguageModelKwargs
from fast_llm.layers.language_model.language_model import LanguageModel
from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig
@@ -157,6 +157,7 @@ def preprocess_batch(
phase: PhaseType,
iteration: int,
metrics: dict | None = None,
setup_activation_storage: bool = False,
) -> list[tuple[torch.Tensor, dict]]:
# TODO Move batch splitting elsewhere, align interface with LayerBase
assert self._is_setup
@@ -175,21 +176,34 @@
non_blocking=True,
)

# TODO: decoder doesn't necessarily have a `block` attribute
distillation_model = self._config.decoder.block.distillation_model
activation_factor = self._config.decoder.block.activation_distillation_factor
reference_logits: list[dict[str, typing.Any]] | None = None
reference_logits = [{} for _ in preprocessed_meta]
for name, reference_model in self._reference_models.items():
reference_preprocessed_meta = [
(tokens_meta, kwargs_meta["reference_models"][name]) for tokens_meta, kwargs_meta in preprocessed_meta
]

reference_batch = reference_model.fast_llm_model.base_model.preprocess_batch(
batch, reference_preprocessed_meta, phase=PhaseType.inference, iteration=iteration
batch,
reference_preprocessed_meta,
phase=PhaseType.inference,
iteration=iteration,
setup_activation_storage=activation_factor > 0.0 and distillation_model == name,
)

# TODO: Do things work with >1?
Assert.eq(len(reference_batch), len(preprocessed_meta), 1)
for i, (reference_tokens, reference_kwargs) in enumerate(reference_batch):
reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration)
reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"]
if BlockKwargs.activation_distillation_storage in reference_kwargs:
reference_logits[i][f"{name}_activations"] = reference_kwargs[
BlockKwargs.activation_distillation_storage
]
del reference_kwargs[BlockKwargs.activation_distillation_storage]

token_ids = batch.token_ids
if sequence_first:
@@ -255,7 +269,13 @@ def preprocess_batch(
kwargs[LanguageModelKwargs.loss_mask] = loss_mask
labels = torch.where(loss_mask, labels, -100)
kwargs[LanguageModelKwargs.labels] = labels
kwargs.update(reference_logits[i])
if reference_logits is not None:
reference_payload = reference_logits[i]
kwargs.update(reference_payload)
if distillation_model is not None and activation_factor > 0.0:
teacher_key = f"{distillation_model}_activations"
if teacher_key in reference_payload:
kwargs[BlockKwargs.activation_distillation_targets] = reference_payload.pop(teacher_key)

if batch.chosen_spans is not None:
chosen_valid_spans = []
Expand Down Expand Up @@ -288,6 +308,8 @@ def preprocess_batch(
rejected_valid_spans.append(valid_spans)
kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans

if setup_activation_storage:
kwargs.setdefault(BlockKwargs.activation_distillation_storage, {})
self.preprocess(tokens, kwargs)
preprocessed.append((tokens, kwargs))

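Taken together, the preprocessing changes above implement a three-step handoff: the student model asks the named reference (teacher) model to preprocess with `setup_activation_storage` enabled, the teacher's forward pass fills the storage dict keyed by module name, and the stored tensors are passed to the student blocks as `activation_distillation_targets`. A condensed sketch of that flow with plain dicts, using the key names from the diff; variable names and the module name are illustrative:

```python
import torch

# 1. Teacher preprocessing: an empty storage dict is installed in the teacher's kwargs
#    (kwargs.setdefault(BlockKwargs.activation_distillation_storage, {})).
teacher_kwargs: dict = {"activation_distillation_storage": {}}

# 2. Teacher forward: each DecoderBlock writes its detached activation into the dict,
#    e.g. storage[module_name] = mixer_output.detach().
teacher_kwargs["activation_distillation_storage"]["decoder.3"] = torch.randn(2, 8, 16)

# 3. Student preprocessing: the populated dict is moved out of the teacher kwargs and
#    exposed to the student blocks, which pop their own entry and add the L2 penalty.
student_kwargs: dict = {
    "activation_distillation_targets": teacher_kwargs.pop("activation_distillation_storage")
}
```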