3 changes: 2 additions & 1 deletion fast_llm/layers/common/linear/convolution.py
@@ -45,12 +45,13 @@ def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor:
             )[..., : input_.size(1)]
         )
 
-    def _forward_causal_conv1d(self, input_: torch.Tensor) -> torch.Tensor:
+    def _forward_causal_conv1d(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
         return _causal_conv1d_fn(
             input_,
             self.weight.squeeze(1),
             self.bias,
             activation=(None if self._activation == ActivationType.identity else self._activation.value),
+            **kwargs,
         )
 
     def get_compute_usage(self, input_: TensorMeta, config: ResourceUsageConfig) -> int:
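The `**kwargs` pass-through lets callers forward optional arguments of the wrapped `causal_conv1d_fn` kernel, such as `seq_idx` for packed (document-stacked) batches, without widening the wrapper's signature. A minimal sketch of the kind of call this enables, assuming a CUDA device and a `causal_conv1d` build that provides the optional `seq_idx` argument; all shapes are illustrative:

    import torch
    from causal_conv1d import causal_conv1d_fn  # the kernel wrapped by _forward_causal_conv1d

    channels, kernel_size = 256, 4  # illustrative sizes
    x = torch.randn(2, channels, 8, device="cuda", dtype=torch.bfloat16)
    weight = torch.randn(channels, kernel_size, device="cuda", dtype=torch.bfloat16)
    # seq_idx marks which packed document each position belongs to (one int32 id per position),
    # so the convolution does not mix positions across document boundaries.
    seq_idx = torch.tensor(
        [[0, 0, 0, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 1, 1, 1]], device="cuda", dtype=torch.int32
    )
    # With the change above, the same extra keyword reaches the kernel through the layer, e.g.
    # conv._forward_causal_conv1d(x, seq_idx=seq_idx); calling the kernel directly looks like:
    out = causal_conv1d_fn(x, weight, bias=None, seq_idx=seq_idx, activation="silu")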
11 changes: 11 additions & 0 deletions fast_llm/layers/common/normalization/config.py
@@ -127,3 +127,14 @@ def module_class(self):
         from fast_llm.layers.common.normalization.normalization import RMSNormalization
 
         return RMSNormalization
+
+
+@config_class(dynamic_type={NormalizationConfig: "gated_rms_norm"})
+class GatedRMSNormalizationConfig(RMSNormalizationConfig):
+    _abstract = False
+
+    @property
+    def module_class(self):
+        from fast_llm.layers.common.normalization.normalization import GatedRMSNormalization
+
+        return GatedRMSNormalization
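Because the class is registered under `dynamic_type={NormalizationConfig: "gated_rms_norm"}`, the gated variant can be selected anywhere a `NormalizationConfig` is accepted by giving its type key. A rough sketch of such a selection; the enclosing structure is illustrative, only the "type" key and the inherited `epsilon` field come from the code above:

    # Illustrative normalization section of a larger config.
    normalization_section = {
        "type": "gated_rms_norm",  # dynamic_type key registered above
        "epsilon": 1.0e-6,         # field inherited from RMSNormalizationConfig
    }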
42 changes: 42 additions & 0 deletions fast_llm/layers/common/normalization/normalization.py
@@ -1,6 +1,7 @@
 import abc
 
 import torch
+import torch.nn.functional as F
 
 from fast_llm.config import Configurable
 from fast_llm.engine.config_utils.initialization import init_ones_, init_zeros_
@@ -9,6 +10,7 @@
 from fast_llm.functional.config import TritonConfig
 from fast_llm.functional.triton.normalization import triton_normalization_autograd
 from fast_llm.layers.common.normalization.config import (
+    GatedRMSNormalizationConfig,
     LayerNormalizationConfig,
     NoNormalizationConfig,
     NormalizationConfig,
@@ -33,6 +35,12 @@
 _fast_normalization_available = False
 
 
+try:
+    from fla.modules.fused_norm_gate import rms_norm_gated  # noqa
+except ImportError:
+    rms_norm_gated = None
+
+
 _PERSIST_LN_SIZES = (
     1024,
     1536,
@@ -292,3 +300,37 @@ def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor:
 
     def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor:
         return torch.rms_norm(input_.to(self.weight.dtype), self._normalized_shape, self.weight, self._config.epsilon)
+
+
+class GatedRMSNormalization[ConfigType: GatedRMSNormalizationConfig](RMSNormalization[ConfigType], torch.nn.Module):
+    """
+    A gated RMS normalization layer.
+    """
+
+    def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | None = None):
+        super().__init__(config, hidden_dim, lr_scale)
+
+        if rms_norm_gated is not None:  # use the fused fla kernel when available
+            self._forward = self._forward_fused
+        else:
+            self._forward = self._forward_torch
+
+    def forward(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
+        return self._forward(input_.view(-1, *self._normalized_shape), gate).view_as(input_)
+
+    def _forward_fused(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
+        return rms_norm_gated(
+            input_,
+            gate,
+            self.weight,
+            None,
+            activation="silu",
+            eps=self._config.epsilon,
+            residual=None,
+            prenorm=False,
+            residual_in_fp32=False,
+        )
+
+    def _forward_torch(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
+        # Non-fused fallback: RMS-normalize the input, then gate with silu.
+        return super()._forward_torch(input_) * F.silu(gate)
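For reference, both code paths are intended to compute the same quantity: RMS-normalize the input over its last dimension, apply the learned scale, then multiply by a SiLU-activated gate. A standalone sketch of that math, mirroring the non-fused fallback above rather than the internals of the `fla` kernel:

    import torch
    import torch.nn.functional as F

    def gated_rms_norm_reference(
        x: torch.Tensor, gate: torch.Tensor, weight: torch.Tensor, eps: float = 1.0e-5
    ) -> torch.Tensor:
        # RMS-normalize over the last dimension and apply the learned scale ...
        normalized = torch.rms_norm(x, (x.size(-1),), weight, eps)
        # ... then gate the result with silu(gate).
        return normalized * F.silu(gate)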
193 changes: 193 additions & 0 deletions fast_llm/layers/ssm/config.py
@@ -4,7 +4,10 @@
 from fast_llm.config import Field, FieldHint, check_field, config_class
 from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, LambdaInitializer
 from fast_llm.engine.config_utils.parameter import ParameterConfig
+from fast_llm.functional.config import ActivationType
+from fast_llm.layers.block.config import BlockKwargs
 from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig
+from fast_llm.layers.common.normalization.config import GatedRMSNormalizationConfig
 from fast_llm.layers.decoder.config import MixerConfig
 from fast_llm.utils import Assert
 
@@ -16,6 +19,196 @@
     from fast_llm.tensor import ParameterMeta
 
 
+class LinearAttentionKwargs(BlockKwargs):
+    cu_seqlens = "cu_seqlens"
+    seq_idx = "seq_idx"
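`LinearAttentionKwargs` names the two tensors that packed (document-stacked) batches carry through these blocks: `cu_seqlens`, the cumulative sequence-length offsets, and `seq_idx`, a per-position document index. A small sketch of how such tensors are conventionally built from per-document lengths; the derivation is illustrative and not part of this PR:

    import torch

    # Three packed documents of lengths 3, 5 and 2.
    lengths = torch.tensor([3, 5, 2])
    # cu_seqlens: cumulative offsets [0, 3, 8, 10], the usual varlen convention (int32).
    cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), lengths.cumsum(0)]).to(torch.int32)
    # seq_idx: [0, 0, 0, 1, 1, 1, 1, 1, 2, 2], one document id per position (int32).
    seq_idx = torch.repeat_interleave(torch.arange(len(lengths)), lengths).to(torch.int32)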


+@config_class(dynamic_type={MixerConfig: "gdn"})
+class GatedDeltaNetConfig(MixerConfig):
+    """
+    Configuration for the Gated DeltaNet mixer used in Qwen3Next-style linear attention blocks.
+    """
+
+    _abstract = False
+    normalization: GatedRMSNormalizationConfig = Field(
+        desc="Configuration for the block normalization layers.",
+        hint=FieldHint.architecture,
+    )
+    qkv_projection_layer: AffineLinearConfig = Field(
+        desc="Projection that produces query, key, value and modulation vectors.",
+        hint=FieldHint.architecture,
+    )
+    ba_projection_layer: AffineLinearConfig = Field(
+        desc="Projection that produces the decay and beta terms.",
+        hint=FieldHint.architecture,
+    )
+    convolution_layer: CausalConv1dConfig = Field(
+        desc="Depth-wise convolution applied to the concatenated QKV streams.",
+        hint=FieldHint.architecture,
+    )
+    output_layer: AffineLinearConfig = Field(
+        desc="Output projection applied after the DeltaNet recurrence and gated RMS norm.",
+        hint=FieldHint.architecture,
+    )
+    dt_bias_weight: ParameterConfig = Field(
+        desc="Parameter configuration for the DeltaNet time-step bias.",
+        hint=FieldHint.architecture,
+    )
+    a_log_weight: ParameterConfig = Field(
+        desc="Parameter configuration for the DeltaNet decay rates.",
+        hint=FieldHint.architecture,
+    )
+
+    value_heads: int = Field(
+        default=16,
+        desc="Number of value heads.",
+        hint=FieldHint.architecture,
+        valid=check_field(Assert.gt, 0),
+    )
+    key_heads: int = Field(
+        default=8,
+        desc="Number of key heads.",
+        hint=FieldHint.architecture,
+        valid=check_field(Assert.gt, 0),
+    )
+    key_head_dim: int = Field(
+        default=64,
+        desc="Dimension of each key head.",
+        hint=FieldHint.architecture,
+        valid=check_field(Assert.gt, 0),
+    )
+    value_head_dim: int = Field(
+        default=64,
+        desc="Dimension of each value head.",
+        hint=FieldHint.architecture,
+        valid=check_field(Assert.gt, 0),
+    )
+    norm_epsilon: float = Field(
+        default=1e-6,
+        desc="Epsilon used by the gated RMS norm.",
+        hint=FieldHint.architecture,
+        valid=check_field(Assert.gt, 0),
+    )
+    activation: ActivationType = Field(
+        default=ActivationType.silu,
+        desc="Activation used after the convolution.",
+        hint=FieldHint.architecture,
+    )
+
+    @property
+    def layer_class(self) -> "type":
+        from fast_llm.layers.ssm.gdn import GatedDeltaNet
+
+        return GatedDeltaNet
+
+    def _validate(self) -> None:
+        with self._set_implicit_default():
+            if "epsilon" not in self.normalization._explicit_fields:
+                self.normalization.epsilon = 1.0e-5
+            if "activation" not in self.convolution_layer._explicit_fields:
+                self.convolution_layer.activation = "silu"
+            if "kernel_size" not in self.convolution_layer._explicit_fields:
+                self.convolution_layer.kernel_size = 4
+
+        super()._validate()
+        # The value heads must tile evenly over the key heads (grouped layout).
+        Assert.multiple(self.value_heads, self.key_heads)
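Together, the defaults and the validation above imply a grouped head layout: 16 value heads tiled over 8 key heads, i.e. two value heads per key head, which is what `Assert.multiple` enforces. A hedged sketch of a minimal mixer section that leans on those defaults; the exact nesting inside the enclosing block config is an assumption:

    # Illustrative mixer section of a larger config.
    gdn_mixer = {
        "type": "gdn",          # dynamic_type key registered above
        "value_heads": 16,      # must be a multiple of key_heads
        "key_heads": 8,
        "key_head_dim": 64,
        "value_head_dim": 64,
        # Left implicit, so _validate fills them in:
        #   normalization.epsilon         -> 1.0e-5
        #   convolution_layer.activation  -> "silu"
        #   convolution_layer.kernel_size -> 4
    }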


+@config_class(dynamic_type={MixerConfig: "kda"})
+class KimiDeltaAttentionConfig(MixerConfig):
+    """
+    Configuration for the KimiDeltaAttention mixer inspired by the Kimi Linear models.
+    """
+
+    _abstract = False
+    normalization: GatedRMSNormalizationConfig = Field(
+        desc="Configuration for the gated normalization applied to the KDA output.",
+        hint=FieldHint.architecture,
+    )
+    q_projection_layer: AffineLinearConfig = Field(
+        desc="Projection that produces query vectors.",
+        hint=FieldHint.architecture,
+    )
+    k_projection_layer: AffineLinearConfig = Field(
+        desc="Projection that produces key vectors.",
+        hint=FieldHint.architecture,
+    )
+    v_projection_layer: AffineLinearConfig = Field(
+        desc="Projection that produces value vectors.",
+        hint=FieldHint.architecture,
+    )
+    f_a_projection_layer: AffineLinearConfig = Field(
+        desc="Projection used for the Delta gating pre-activation.",
+        hint=FieldHint.architecture,
+    )
+    f_b_projection_layer: AffineLinearConfig = Field(
+        desc="Projection used for the Delta gating expansion.",
+        hint=FieldHint.architecture,
+    )
+    g_a_projection_layer: AffineLinearConfig = Field(
+        desc="Projection used for the output gating pre-activation.",
+        hint=FieldHint.architecture,
+    )
+    g_b_projection_layer: AffineLinearConfig = Field(
+        desc="Projection used for the output gating expansion.",
+        hint=FieldHint.architecture,
+    )
+    beta_projection_layer: AffineLinearConfig = Field(
+        desc="Projection that produces the Beta gate.",
+        hint=FieldHint.architecture,
+    )
+    output_projection_layer: AffineLinearConfig = Field(
+        desc="Projection applied after the Delta recurrence and gated normalization.",
+        hint=FieldHint.architecture,
+    )
+    convolution_layer: CausalConv1dConfig = Field(
+        desc="Depth-wise convolution applied independently on each Q, K and V stream.",
+        hint=FieldHint.architecture,
+    )
+    dt_bias_weight: ParameterConfig = Field(
+        desc="Parameter configuration for the Delta gate bias.",
+        hint=FieldHint.architecture,
+    )
+    a_log_weight: ParameterConfig = Field(
+        desc="Parameter configuration for the decay rates.",
+        hint=FieldHint.architecture,
+    )
+
+    heads: int = Field(
+        default=16,
+        desc="Number of attention heads.",
+        hint=FieldHint.architecture,
+        valid=check_field(Assert.gt, 0),
+    )
+    head_dim: int = Field(
+        default=64,
+        desc="Dimension of each head.",
+        hint=FieldHint.architecture,
+        valid=check_field(Assert.gt, 0),
+    )
+
+    @property
+    def layer_class(self) -> "type":
+        from fast_llm.layers.ssm.kda import KimiDeltaAttention
+
+        return KimiDeltaAttention
+
+    def _validate(self) -> None:
+        with self._set_implicit_default():
+            if "epsilon" not in self.normalization._explicit_fields:
+                self.normalization.epsilon = 1.0e-5
+            if "activation" not in self.convolution_layer._explicit_fields:
+                self.convolution_layer.activation = "silu"
+            if "kernel_size" not in self.convolution_layer._explicit_fields:
+                self.convolution_layer.kernel_size = 4
+
+        super()._validate()
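The paired `f_a`/`f_b` and `g_a`/`g_b` projections describe two-stage gating paths: a "pre-activation" projection followed by an "expansion" back to the per-head width, which suggests a low-rank bottleneck. A rough sketch of that shape only; every dimension here is an assumption for illustration, and the placement of any nonlinearity between the two projections is left to the layer implementation, which is not shown in this file:

    import torch

    hidden, bottleneck, heads, head_dim = 2048, 128, 16, 64  # illustrative sizes only

    # Delta-gate path: narrow pre-activation (f_a) then expansion (f_b) to heads * head_dim.
    f_a = torch.nn.Linear(hidden, bottleneck, bias=False)
    f_b = torch.nn.Linear(bottleneck, heads * head_dim, bias=False)

    x = torch.randn(4, hidden)
    delta_gate_preactivation = f_b(f_a(x)).view(4, heads, head_dim)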


 @config_class()
 class SSMConfig(MixerConfig):
     # Layers