diff --git a/fast_llm/layers/common/linear/convolution.py b/fast_llm/layers/common/linear/convolution.py index b88b7b2e..69018fd0 100644 --- a/fast_llm/layers/common/linear/convolution.py +++ b/fast_llm/layers/common/linear/convolution.py @@ -45,12 +45,13 @@ def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: )[..., : input_.size(1)] ) - def _forward_causal_conv1d(self, input_: torch.Tensor) -> torch.Tensor: + def _forward_causal_conv1d(self, input_: torch.Tensor, **kwargs) -> torch.Tensor: return _causal_conv1d_fn( input_, self.weight.squeeze(1), self.bias, activation=(None if self._activation == ActivationType.identity else self._activation.value), + **kwargs, ) def get_compute_usage(self, input_: TensorMeta, config: ResourceUsageConfig) -> int: diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py index a80a1928..4ecb7a3b 100644 --- a/fast_llm/layers/common/normalization/config.py +++ b/fast_llm/layers/common/normalization/config.py @@ -127,3 +127,14 @@ def module_class(self): from fast_llm.layers.common.normalization.normalization import RMSNormalization return RMSNormalization + + +@config_class(dynamic_type={NormalizationConfig: "gated_rms_norm"}) +class GatedRMSNormalizationConfig(RMSNormalizationConfig): + _abstract = False + + @property + def module_class(self): + from fast_llm.layers.common.normalization.normalization import GatedRMSNormalization + + return GatedRMSNormalization diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index d0a5ab15..ec8a52e2 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -1,6 +1,7 @@ import abc import torch +import torch.nn.functional as F from fast_llm.config import Configurable from fast_llm.engine.config_utils.initialization import init_ones_, init_zeros_ @@ -9,6 +10,7 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd from fast_llm.layers.common.normalization.config import ( + GatedRMSNormalizationConfig, LayerNormalizationConfig, NoNormalizationConfig, NormalizationConfig, @@ -33,6 +35,12 @@ _fast_normalization_available = False +try: + from fla.modules.fused_norm_gate import rms_norm_gated # noqa +except ImportError: + rms_norm_gated = None + + _PERSIST_LN_SIZES = ( 1024, 1536, @@ -292,3 +300,37 @@ def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor: def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: return torch.rms_norm(input_.to(self.weight.dtype), self._normalized_shape, self.weight, self._config.epsilon) + + +class GatedRMSNormalization[ConfigType: GatedRMSNormalizationConfig](RMSNormalization[ConfigType], torch.nn.Module): + """ + A gated RMS normalization layer. 
+ """ + + def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | None = None): + super().__init__(config, hidden_dim, lr_scale) + + if rms_norm_gated is not None: + self._forward = self._forward_fused + else: + self._forward = self._forward + + def forward(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + return self._forward(input_.view(-1, *self._normalized_shape), gate).view_as(input_) + + def _forward_fused(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + return rms_norm_gated( + input_, + gate, + self.weight, + None, + activation="silu", + eps=self._config.epsilon, + residual=None, + prenorm=False, + residual_in_fp32=False, + ) + + def _forward(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + normalized = self.rmsnorm(input_) + return normalized * F.silu(gate) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index e541341e..2fa90aff 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -4,7 +4,10 @@ from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, LambdaInitializer from fast_llm.engine.config_utils.parameter import ParameterConfig +from fast_llm.functional.config import ActivationType +from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig +from fast_llm.layers.common.normalization.config import GatedRMSNormalizationConfig from fast_llm.layers.decoder.config import MixerConfig from fast_llm.utils import Assert @@ -16,6 +19,196 @@ from fast_llm.tensor import ParameterMeta +class LinearAttentionKwargs(BlockKwargs): + cu_seqlens = "cu_seqlens" + seq_idx = "seq_idx" + + +@config_class(dynamic_type={MixerConfig: "gdn"}) +class GatedDeltaNetConfig(MixerConfig): + """ + Configuration for the gated DeltaNet mixer used in Qwen3Next style linear attention blocks. 
+ """ + + _abstract = False + normalization: GatedRMSNormalizationConfig = Field( + desc="Configuration for the block normalization layers.", + hint=FieldHint.architecture, + ) + qkv_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces query, key, value and modulation vectors.", + hint=FieldHint.architecture, + ) + ba_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces the decay and beta terms.", + hint=FieldHint.architecture, + ) + convolution_layer: CausalConv1dConfig = Field( + desc="Depth-wise convolution applied to the concatenated QKV streams.", + hint=FieldHint.architecture, + ) + output_layer: AffineLinearConfig = Field( + desc="Output projection applied after the DeltaNet recurrence and gated RMS norm.", + hint=FieldHint.architecture, + ) + dt_bias_weight: ParameterConfig = Field( + desc="Parameter configuration for the DeltaNet time-step bias.", + hint=FieldHint.architecture, + ) + a_log_weight: ParameterConfig = Field( + desc="Parameter configuration for the DeltaNet decay rates.", + hint=FieldHint.architecture, + ) + + value_heads: int = Field( + default=16, + desc="Number of value heads.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + key_heads: int = Field( + default=8, + desc="Number of key heads.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + key_head_dim: int = Field( + default=64, + desc="Dimension of each key head.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + value_head_dim: int = Field( + default=64, + desc="Dimension of each value head.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + norm_epsilon: float = Field( + default=1e-6, + desc="Epsilon used by the gated RMS norm.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + activation: ActivationType = Field( + default=ActivationType.silu, + desc="Activation used after the convolution.", + hint=FieldHint.architecture, + ) + + def _validate(self) -> None: + super()._validate() + Assert.multiple(self.value_heads, self.key_heads) + + @property + def layer_class(self) -> "type": + from fast_llm.layers.ssm.gdn import GatedDeltaNet + + return GatedDeltaNet + + def _validate(self) -> None: + with self._set_implicit_default(): + if "epsilon" not in self.normalization._explicit_fields: + self.normalization.epsilon = 1.0e-5 + if "activation" not in self.convolution_layer._explicit_fields: + self.convolution_layer.activation = "silu" + if "kernel_size" not in self.convolution_layer._explicit_fields: + self.convolution_layer.kernel_size = 4 + + super()._validate() + + +@config_class(dynamic_type={MixerConfig: "kda"}) +class KimiDeltaAttentionConfig(MixerConfig): + """ + Configuration for the KimiDeltaAttention mixer inspired by the Kimi Linear models. 
+ """ + + _abstract = False + normalization: GatedRMSNormalizationConfig = Field( + desc="Configuration for the gated normalization applied to the KDA output.", + hint=FieldHint.architecture, + ) + q_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces query vectors.", + hint=FieldHint.architecture, + ) + k_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces key vectors.", + hint=FieldHint.architecture, + ) + v_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces value vectors.", + hint=FieldHint.architecture, + ) + f_a_projection_layer: AffineLinearConfig = Field( + desc="Projection used for the Delta gating pre-activation.", + hint=FieldHint.architecture, + ) + f_b_projection_layer: AffineLinearConfig = Field( + desc="Projection used for the Delta gating expansion.", + hint=FieldHint.architecture, + ) + g_a_projection_layer: AffineLinearConfig = Field( + desc="Projection used for the output gating pre-activation.", + hint=FieldHint.architecture, + ) + g_b_projection_layer: AffineLinearConfig = Field( + desc="Projection used for the output gating expansion.", + hint=FieldHint.architecture, + ) + beta_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces the Beta gate.", + hint=FieldHint.architecture, + ) + output_projection_layer: AffineLinearConfig = Field( + desc="Projection applied after the Delta recurrence and gated normalization.", + hint=FieldHint.architecture, + ) + convolution_layer: CausalConv1dConfig = Field( + desc="Depth-wise convolution applied independently on each Q, K and V stream.", + hint=FieldHint.architecture, + ) + dt_bias_weight: ParameterConfig = Field( + desc="Parameter configuration for the Delta gate bias.", + hint=FieldHint.architecture, + ) + a_log_weight: ParameterConfig = Field( + desc="Parameter configuration for the decay rates.", + hint=FieldHint.architecture, + ) + + heads: int = Field( + default=16, + desc="Number of attention heads.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + head_dim: int = Field( + default=64, + desc="Dimension of each head.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + + @property + def layer_class(self) -> "type": + from fast_llm.layers.ssm.kda import KimiDeltaAttention + + return KimiDeltaAttention + + def _validate(self) -> None: + with self._set_implicit_default(): + if "epsilon" not in self.normalization._explicit_fields: + self.normalization.epsilon = 1.0e-5 + if "activation" not in self.convolution_layer._explicit_fields: + self.convolution_layer.activation = "silu" + if "kernel_size" not in self.convolution_layer._explicit_fields: + self.convolution_layer.kernel_size = 4 + + super()._validate() + + @config_class() class SSMConfig(MixerConfig): # Layers diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py new file mode 100644 index 00000000..d07cb5e2 --- /dev/null +++ b/fast_llm/layers/ssm/gdn.py @@ -0,0 +1,337 @@ +import logging +import typing + +import torch +import torch.nn.functional as F + +from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.common.peft.config import 
PeftConfig +from fast_llm.layers.decoder.block import BlockWithBias +from fast_llm.layers.ssm.config import GatedDeltaNetConfig +from fast_llm.tensor import ParameterMeta, TensorMeta +from fast_llm.utils import div + +logger = logging.getLogger(__name__) + +try: + from fla.ops.gated_delta_rule import chunk_gated_delta_rule +except ImportError: + chunk_gated_delta_rule = None + + +is_fast_path_available = chunk_gated_delta_rule is not None + + +def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor: + return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) + + +def torch_chunk_gated_delta_rule( + query, + key, + value, + g, + beta, + chunk_size=64, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = _l2norm(query, dim=-1, eps=1e-6) + key = _l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = ( + x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) + ) + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_sequence_length = sequence_length + pad_size + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + # reshape to chunks + query, key, value, k_beta, v_beta = ( + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta) + ) + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0) + + # chunk decay + g = g.cumsum(dim=-1) + decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() + attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + last_recurrent_state = ( + torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) + if initial_state is None + else initial_state.to(value) + ) + core_attn_out = torch.zeros_like(value) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1) + + # for each chunk + for i in range(0, total_sequence_length // chunk_size): + q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] + attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + v_new = v_i - v_prime + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn @ v_new + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, None].exp() + + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new + ) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]) + 
core_attn_out = core_attn_out[:, :, :sequence_length] + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state + + +class GatedDeltaNet[ConfigType: GatedDeltaNetConfig](BlockWithBias[ConfigType]): + """ + Follows implementation here: https://github.com/huggingface/transformers/blob/a5c903f877fda21e739027eed133e03162eb7712/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L593 + - For tensor parallel implementtion (no sequnece prallel): we scatter teh heads accross ranks. + - Sequence Tensor parallel: in_proj_qkvz all reduces across sequence dim. --> each rank performs work on full sequence but only a subset of heads (standrd TP). + + """ + + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + return_bias: bool = True, + ): + super().__init__( + config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias + ) + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + self._value_heads_dim = TensorDim( + "gdn_value_heads", self._config.value_heads, self._parallel_dim if self._config.value_heads > 1 else None + ) + self._key_heads_dim = TensorDim( + "gdn_key_heads", self._config.key_heads, self._parallel_dim if self._config.key_heads > 1 else None + ) + self._value_head_dim = TensorDim("gdn_value_head_dim", self._config.value_head_dim) + self._key_head_dim = TensorDim("gdn_key_head_dim", self._config.key_head_dim) + self._local_value_heads = self._value_heads_dim.size + self._local_key_heads = self._key_heads_dim.size + self._value_heads_per_key = div(self._local_value_heads, max(self._local_key_heads, 1)) + + query_dim = CompositeTensorDim("gdn_query", (self._key_heads_dim, self._key_head_dim)) + key_dim = CompositeTensorDim("gdn_key", (self._key_heads_dim, self._key_head_dim)) + value_dim = CompositeTensorDim("gdn_value", (self._value_heads_dim, self._value_head_dim)) + z_dim = CompositeTensorDim("gdn_z", (self._value_heads_dim, self._value_head_dim)) + qkvz_dim = ConcatenatedTensorDim("gdn_qkvz", (query_dim, key_dim, value_dim, z_dim)) + ba_dim = ConcatenatedTensorDim( + "gdn_ba", + ( + CompositeTensorDim("gdn_beta", (self._value_heads_dim,)), + CompositeTensorDim("gdn_alpha", (self._value_heads_dim,)), + ), + ) + + qkv_channels_dim = ConcatenatedTensorDim("gdn_qkv", (query_dim, key_dim, value_dim)) + + self.in_proj_qkvz = self._config.qkv_projection_layer.get_layer( + hidden_dim, + qkvz_dim, + default_weight_initialization=init_normal_(std=self._hidden_size**-0.5), + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.in_proj_ba = self._config.ba_projection_layer.get_layer( + hidden_dim, + ba_dim, + default_weight_initialization=init_normal_(std=self._hidden_size**-0.5), + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.convolution = self._config.convolution_layer.get_layer( + qkv_channels_dim, + default_add_bias=False, + default_activation=self._config.activation, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.out_proj = self._config.output_layer.get_layer( + value_dim, + hidden_dim, + default_weight_initialization=init_normal_(std=self._hidden_size**-0.5), + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + 
lr_scale=self._lr_scale, + peft=self._peft, + ) + self.dt_bias: ParameterMeta = self._config.dt_bias_weight.get_parameter( + (self._value_heads_dim,), + default_initialization=init_ones_, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.A_log: ParameterMeta = self._config.a_log_weight.get_parameter( + (self._value_heads_dim,), + default_initialization=LambdaInitializer( + lambda _, tensor, generator: tensor.uniform_(0, 16, generator=generator).log_() + ), + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.norm = self._config.normalization.get_layer( + self._value_head_dim, lr_scale=self._lr_scale, peft=self._peft + ) + + self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule + + if not is_fast_path_available: + logger.warning( + "Fast paths for GatedDeltaNet are not available. Please ensure that 'causal_conv1d' and 'fla' are properly installed." + ) + + def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): + """ + Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`. + """ + + new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( + self._local_key_heads, + 2 * self._config.key_head_dim + + 2 * self._config.value_head_dim * self._local_value_heads // self._local_key_heads, + ) + new_tensor_shape_ba = mixed_ba.size()[:-1] + ( + self._local_key_heads, + 2 * self._local_value_heads // self._local_key_heads, + ) + + mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) + mixed_ba = mixed_ba.view(*new_tensor_shape_ba) + split_arg_list_qkvz = [ + self._config.key_head_dim, + self._config.key_head_dim, + (self._local_value_heads // self._local_key_heads * self._config.value_head_dim), + (self._local_value_heads // self._local_key_heads * self._config.value_head_dim), + ] + split_arg_list_ba = [ + self._local_value_heads // self._local_key_heads, + self._local_value_heads // self._local_key_heads, + ] + query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3) + b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3) + # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] + value = value.reshape(value.size(0), value.size(1), -1, self._config.value_head_dim) + z = z.reshape(z.size(0), z.size(1), -1, self._config.value_head_dim) + b = b.reshape(b.size(0), b.size(1), self._local_value_heads) + a = a.reshape(a.size(0), a.size(1), self._local_value_heads) + return query, key, value, z, b, a + + def _forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + sequence_first = kwargs[BlockKwargs.sequence_first] + # in sequence parallel TP the input here is already scattered across sequence dimension + # TODO: do we need masking of padding tokens? 
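+        # Data flow (mirroring the reference implementation linked above): project to the fused
+        # QKVZ and BA streams, run the depth-wise causal convolution over the concatenated QKV
+        # channels, apply the chunked gated delta rule, then gate the result with the RMS norm
+        # before the output projection.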
+ # TODO: make sure varlen is supported + hidden_states = input_ + + # batch_size, sequence_length, _ = hidden_states.shape + projected_states_qkvz = self.in_proj_qkvz(hidden_states) # bs x seq_len x (qkvz) + projected_states_ba = self.in_proj_ba(hidden_states) # bs x seq_len x (b a) + if sequence_first: + projected_states_qkvz = projected_states_qkvz.transpose(0, 1) + projected_states_ba = projected_states_ba.transpose(0, 1) + + query, key, value, z, beta, alpha = self.fix_query_key_value_ordering( + projected_states_qkvz, projected_states_ba + ) + query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)) + + mixed_qkv = torch.cat((query, key, value), dim=-1) + mixed_qkv = mixed_qkv.transpose(1, 2) + mixed_qkv = self.convolution(mixed_qkv) + mixed_qkv = mixed_qkv.transpose(1, 2) + query, key, value = torch.split( + mixed_qkv, + ( + self._local_key_heads * self._config.key_head_dim, + self._local_key_heads * self._config.key_head_dim, + self._local_value_heads * self._config.value_head_dim, + ), + dim=-1, + ) + query = query.reshape(query.shape[0], query.shape[1], -1, self._config.key_head_dim) + key = key.reshape(key.shape[0], key.shape[1], -1, self._config.key_head_dim) + value = value.reshape(value.shape[0], value.shape[1], -1, self._config.value_head_dim) + + beta = beta.sigmoid() + g = -torch.exp(self.A_log) * F.softplus(alpha + self.dt_bias) + + if self._value_heads_per_key > 1: + query = query.repeat_interleave(self._value_heads_per_key, dim=2) + key = key.repeat_interleave(self._value_heads_per_key, dim=2) + + core_attn_out, _ = self.chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=True, + ) + + z_shape_og = z.shape + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) + if sequence_first: + core_attn_out = core_attn_out.transpose(0, 1) + output = self.out_proj(core_attn_out) + + return output + + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + # return ( + # self.in_proj_qkvz.get_compute_usage(input_, config) + # + self.in_proj_ba.get_compute_usage(input_, config) + # + self.out_proj.get_compute_usage(input_, config) + # ) + raise NotImplementedError() diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py new file mode 100644 index 00000000..b14fd459 --- /dev/null +++ b/fast_llm/layers/ssm/kda.py @@ -0,0 +1,344 @@ +import logging +import typing + +import torch +from einops import rearrange, repeat + +from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.block import BlockWithBias +from fast_llm.layers.ssm.config import KimiDeltaAttentionConfig, LinearAttentionKwargs +from fast_llm.tensor import ParameterMeta, TensorMeta + +logger = logging.getLogger(__name__) + +try: + from fla.ops.kda import chunk_kda + from 
fla.ops.kda.gate import fused_kda_gate +except ImportError: + chunk_kda = None + fused_kda_gate = None + + +def index_first_axis(x, indices): + other_shape = x.shape[1:] + second_dim = other_shape.numel() + return torch.gather( + rearrange(x, "b ... -> b (...)"), + 0, + repeat(indices, "z -> z d", d=second_dim), + ).reshape(-1, *other_shape) + + +class KimiDeltaAttention[ConfigType: KimiDeltaAttentionConfig](BlockWithBias[ConfigType]): + """ + Implementation of the Kimi Delta Attention mixer. + Reference Implementation: https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct/blob/main/modeling_kimi.py + """ + + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + return_bias: bool = True, + ): + super().__init__( + config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias + ) + if chunk_kda is None or fused_kda_gate is None: + raise ImportError( + "KimiDeltaAttention requires the `fla-core` package. " + "Please install it with `pip install -U fla-core`." + ) + + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + self._heads_dim = TensorDim( + "kda_heads", self._config.heads, self._parallel_dim if self._config.heads > 1 else None + ) + self._head_dim = TensorDim("kda_head_dim", self._config.head_dim) + self._projection_dim = CompositeTensorDim("kda_projection", (self._heads_dim, self._head_dim)) + self._local_heads = self._heads_dim.size + self._projection_size = self._projection_dim.size + + init = init_normal_(std=self._hidden_size**-0.5) + self.q_proj = self._config.q_projection_layer.get_layer( + hidden_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.k_proj = self._config.k_projection_layer.get_layer( + hidden_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.v_proj = self._config.v_projection_layer.get_layer( + hidden_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + self.q_conv = self._config.convolution_layer.get_layer( + self._projection_dim, + default_add_bias=False, + default_activation=self._config.convolution_layer.activation, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.k_conv = self._config.convolution_layer.get_layer( + self._projection_dim, + default_add_bias=False, + default_activation=self._config.convolution_layer.activation, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.v_conv = self._config.convolution_layer.get_layer( + self._projection_dim, + default_add_bias=False, + default_activation=self._config.convolution_layer.activation, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + self.f_a_proj = self._config.f_a_projection_layer.get_layer( + hidden_dim, + self._head_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=False, # self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.f_b_proj = self._config.f_b_projection_layer.get_layer( + self._head_dim, + self._projection_dim, + default_weight_initialization=init, + 
default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.g_a_proj = self._config.g_a_projection_layer.get_layer( + hidden_dim, + self._head_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=False, # self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.g_b_proj = self._config.g_b_projection_layer.get_layer( + self._head_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + self.beta_proj = self._config.beta_projection_layer.get_layer( + hidden_dim, + self._heads_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.o_proj = self._config.output_projection_layer.get_layer( + self._projection_dim, + hidden_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + self.dt_bias: ParameterMeta = self._config.dt_bias_weight.get_parameter( + (self._projection_dim,), + default_initialization=init_ones_, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.A_log: ParameterMeta = self._config.a_log_weight.get_parameter( + (self._heads_dim,), + default_initialization=LambdaInitializer( + lambda _, tensor, generator: tensor.uniform_(1, 16, generator=generator).log_() + ), + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.norm = self._config.normalization.get_layer( + self._head_dim, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + def _apply_conv(self, tensor: torch.Tensor, conv: torch.nn.Module, seq_idx: torch.Tensor = None) -> torch.Tensor: + """ + Applies convolution. + Note, in the reference code they use Short Convolution from flash-linear-attention/fla/modules/convolution.py, but that one just uses causal_conv1d anyways. + Varlen: + - seq. idx are only suppored in channel last layout, i.e. no transpose + """ + tensor = rearrange(tensor, "b t d -> b d t") + # tensor = tensor.transpose(1, 2).contiguous() if seq_idx is None else tensor.transpose(1, 2) + tensor = conv(tensor, seq_idx=seq_idx) + return tensor.transpose(1, 2).contiguous() + + def _reshape_heads(self, tensor: torch.Tensor) -> torch.Tensor: + tensor = tensor.contiguous() + # since head_dim is the same vor k,q and v + # same as rearrange(v, '... (h d) -> ... h d', d=self.head_dim) + return tensor.view(tensor.shape[0], tensor.shape[1], self._local_heads, self._config.head_dim) + + def _forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + # TODO: make sure varlen is supported + # TODO: do we need to deal with padding tokens? 
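+        # Data flow: per-stream Q/K/V projections and causal convolutions, a fused delta gate built
+        # from f_a_proj/f_b_proj, the chunked KDA recurrence (chunk_kda), and a gated RMS norm driven
+        # by g_a_proj/g_b_proj before the output projection. Batches are flattened to batch size 1 so
+        # the varlen (cu_seqlens) path of the fla kernels can be used.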
+ sequence_first = kwargs[BlockKwargs.sequence_first] + hidden_states = input_ + + cu_seqlens = kwargs.get(LinearAttentionKwargs.cu_seqlens, None) + seq_idx = kwargs.get(LinearAttentionKwargs.seq_idx, None) + # TODO: can be made more efficeint by rearranging hidden states directly + residual_dtype = hidden_states.dtype + + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + if sequence_first: + # make bs first dim again + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + + batch_size, sequence_length, _ = q.size() + + # work with bs = 1 to make sure varlen works correctly, only needed if micro batch size is > 1 + # can this be applied once to hidden state only? pr + q = rearrange(q, "b s ... -> (b s) ...").unsqueeze(0) + k = rearrange(k, "b s ... -> (b s) ...").unsqueeze(0) + v = rearrange(v, "b s ... -> (b s) ...").unsqueeze(0) + + # because we use cu_seqlens, chunk_kda requires batch size to be 1 (flatten, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303) + # similarly to ShortConvolution from fla we already operate on flattened batches here (https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/modules/convolution.py#L914) + q = self._apply_conv(q, self.q_conv, seq_idx=seq_idx) + k = self._apply_conv(k, self.k_conv, seq_idx=seq_idx) + v = self._apply_conv(v, self.v_conv, seq_idx=seq_idx) + + g_kernel = self.f_b_proj(self.f_a_proj(hidden_states)) + if sequence_first: + g_kernel = g_kernel.transpose(0, 1) + g_kernel = rearrange(g_kernel, "b s ... -> (b s) ...").unsqueeze(0) + + g_kernel = fused_kda_gate(g_kernel, self.A_log, self._config.head_dim, g_bias=self.dt_bias) + + beta = torch.sigmoid(self.beta_proj(hidden_states).float()) + q = self._reshape_heads(q) + k = self._reshape_heads(k) + v = self._reshape_heads(v) + if sequence_first: + beta = beta.transpose(0, 1) + beta = rearrange(beta, "b s h -> (b s) h").unsqueeze(0) + + # need to install nightly triton for this to work on H100, see https://github.com/fla-org/flash-linear-attention/blob/main/FAQs.md + # cu_seqlens requires batch ssize to be 1, i.e. 
flattened bacthes + attn_out, _ = chunk_kda( + q=q, + k=k, + v=v, + g=g_kernel, + beta=beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, + ) + + attn_out = attn_out.to(residual_dtype) + + g_out = self.g_b_proj(self.g_a_proj(hidden_states)) # bs x seq x n_local_heads x head dim + g_out = self._reshape_heads(g_out) + if sequence_first: + g_out = g_out.transpose(0, 1) + + attn_out = rearrange(attn_out.squeeze(0), "(b s) h d -> b s h d", b=batch_size, s=sequence_length) + attn_out = self.norm(attn_out, g_out) + attn_out = rearrange(attn_out, "b s h d -> b s (h d)") + if sequence_first: + attn_out = attn_out.transpose(0, 1) + attn_out = self.o_proj(attn_out) + + return attn_out + + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + raise NotImplementedError() + + def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + sequence_lengths = kwargs[LinearAttentionKwargs.sequence_lengths] + if sequence_lengths is None: + raise ValueError("sequence_lengths must be provided in kwargs for variable-length sequences.") + sequence_lengths = kwargs[LinearAttentionKwargs.sequence_lengths] + seqlens = torch.cat(sequence_lengths) + cu_seqlens = torch.cat( + ( + torch.zeros(1, dtype=torch.long, device=batch.device), + torch.cumsum(seqlens, dim=0, dtype=torch.long).to(batch.device), + ) + ) + # this is supposed to be flattened, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303 + # also whenever cu_seqlens is used, batchs size must be forced to 1: see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L347 + kwargs[LinearAttentionKwargs.cu_seqlens] = cu_seqlens + # seq_idx has to be (bs, seqlen), but bs is forced to 1 + kwargs[LinearAttentionKwargs.seq_idx] = ( + ( + torch.cat( + [ + torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) + for n in (torch.diff(cu_seqlens).to(torch.int32)) + ], + dim=0, + ) + .eq(0) + .cumsum(0) + - 1 + ) + .to(torch.int32) + .unsqueeze(0) + ) + + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + if LinearAttentionKwargs.sequence_lengths in kwargs: + # TODO: packing is enabled by default, i.e. its always used? + # only get here when cross_document_attention is False + self._preprocess_for_varlen(batch, kwargs) diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py index 32293266..6682137f 100644 --- a/fast_llm/models/auto.py +++ b/fast_llm/models/auto.py @@ -2,6 +2,13 @@ Import these submodules to ensure classes are added to the dynamic class registry. 
""" -from fast_llm.layers.ssm.config import MambaConfig, Mamba2Config, DiscreteMamba2Config # isort: skip +from fast_llm.layers.ssm.config import ( # isort: skip + DiscreteMamba2Config, + GatedDeltaNetConfig, + Mamba2Config, + MambaConfig, +) + +from fast_llm.layers.attention.config import AttentionConfig # isort: skip from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig # isort: skip from fast_llm.engine.evaluation.evaluators import EvaluatorsConfig # isort: skip diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index 7550df04..21ed0ee0 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -9,7 +9,12 @@ from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig from fast_llm.layers.decoder.config import DecoderBlockConfig from fast_llm.layers.decoder.mlp.config import MLPConfig -from fast_llm.layers.ssm.config import DiscreteMamba2Config, Mamba2Config +from fast_llm.layers.ssm.config import ( + DiscreteMamba2Config, + GatedDeltaNetConfig, + KimiDeltaAttentionConfig, + Mamba2Config, +) from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat from fast_llm.models.gpt.conversion.llama import ( @@ -229,6 +234,128 @@ def get_converters( ] +class KimiDeltaAttentionConverter: + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "type": "kda", + "head_dim": config["linear_attn_config"]["head_dim"], + "heads": config["linear_attn_config"]["num_heads"], + "convolution_layer": { + "kernel_size": config["linear_attn_config"]["short_conv_kernel_size"], + }, + } + + @classmethod + def export_config(cls, config: KimiDeltaAttentionConfig) -> dict: + return { + "linear_attn_config": { + "head_dim": config.head_dim, + "num_heads": config.heads, + "short_conv_kernel_size": config.convolution_layer.kernel_size, + }, + } + + @classmethod + def get_converters( + cls, + config: KimiDeltaAttentionConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.q_proj", + f"{hf_prefix}.q_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.k_proj", + f"{hf_prefix}.k_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.v_proj", + f"{hf_prefix}.v_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.q_conv", + f"{hf_prefix}.q_conv", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.k_conv", + f"{hf_prefix}.k_conv", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.v_conv", + f"{hf_prefix}.v_conv", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.f_a_proj", + f"{hf_prefix}.f_a_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.f_b_proj", + f"{hf_prefix}.f_b_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.g_a_proj", + f"{hf_prefix}.g_a_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.g_b_proj", + f"{hf_prefix}.g_b_proj", + 
False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.beta_proj", + f"{hf_prefix}.beta_proj", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.o_proj", + f"{hf_prefix}.o_proj", + False, + drop_on_export=drop_on_export, + ), + get_parameter_converter( + f"{fast_llm_prefix}.A_log", + f"{hf_prefix}.A_log", + drop_on_export=drop_on_export, + ), + get_parameter_converter( + f"{fast_llm_prefix}.dt_bias", + f"{hf_prefix}.dt_bias", + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.norm", + f"{hf_prefix}.norm", + False, + drop_on_export=drop_on_export, + ), + ] + + class AprielMLPConverter(LlamaMLPConverter): @classmethod def import_config(cls, config: dict) -> dict: @@ -242,6 +369,84 @@ def export_config(cls, config: MLPConfig) -> dict: return out +class GatedDeltaNetConverter: + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "type": "gated_delta_net", + "value_heads": config["linear_attn_config"]["gdn_value_head_dim"], + "key_heads": config["linear_attn_config"]["gdn_num_key_heads"], + "key_head_dim": config["linear_attn_config"]["gdn_key_head_dim"], + "value_head_dim": config["linear_attn_config"]["value_head_dim"], + "convolution_layer": { + "kernel_size": config["linear_attn_config"]["gdn_linear_conv_kernel_size"], + }, + } + + @classmethod + def export_config(cls, config: GatedDeltaNetConfig) -> dict: + return { + "linear_attn_config": { + "gdn_num_value_heads": config.value_heads, + "gdn_num_key_heads": config.key_heads, + "gdn_key_head_dim": config.key_head_dim, + "gdn_value_head_dim": config.value_head_dim, + "gdn_linear_conv_kernel_size": config.convolution_layer.kernel_size, + }, + } + + @classmethod + def get_converters( + cls, + config: KimiDeltaAttentionConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.in_proj_qkvz", + f"{hf_prefix}.in_proj_qkvz", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.in_proj_ba", + f"{hf_prefix}.in_proj_ba", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.convolution", + f"{hf_prefix}.convolution", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.out_proj", + f"{hf_prefix}.out_proj", + False, + drop_on_export=drop_on_export, + ), + get_parameter_converter( + f"{fast_llm_prefix}.A_log", + f"{hf_prefix}.A_log", + drop_on_export=drop_on_export, + ), + get_parameter_converter( + f"{fast_llm_prefix}.dt_bias", + f"{hf_prefix}.dt_bias", + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.norm", + f"{hf_prefix}.norm", + False, + drop_on_export=drop_on_export, + ), + ] + + class AprielBlockConverterBase(MistralBlockConverter): mlp_converter_class: typing.ClassVar[type[AprielMLPConverter]] = AprielMLPConverter @@ -256,16 +461,30 @@ class AprielMamba2BlockConverter(AprielBlockConverterBase): hf_mixer_name: typing.ClassVar[str] = "mixer" +class AprielKimiDeltaAttentionBlockConverter(AprielBlockConverterBase): + mixer_converter_class: typing.ClassVar[type[KimiDeltaAttentionConverter]] = KimiDeltaAttentionConverter + hf_mixer_name: typing.ClassVar[str] = "mixer" + + +class AprielGatedDeltaNetBlockConverter(AprielBlockConverterBase): + mixer_converter_class: 
typing.ClassVar[type[GatedDeltaNetConverter]] = GatedDeltaNetConverter + hf_mixer_name: typing.ClassVar[str] = "mixer" + + class AprielBlockConverter: layout_names = { AttentionConfig: "t", Mamba2Config: "m2", DiscreteMamba2Config: "m2d", + KimiDeltaAttentionConfig: "kda", + GatedDeltaNetConfig: "gdn", } _converter_classes = { AttentionConfig: AprielBlockConverterBase, Mamba2Config: AprielMamba2BlockConverter, DiscreteMamba2Config: AprielDiscreteMamba2BlockConverter, + KimiDeltaAttentionConfig: AprielKimiDeltaAttentionBlockConverter, + GatedDeltaNetConfig: AprielGatedDeltaNetBlockConverter, } _config_classes = {value: key for key, value in layout_names.items()} diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index efa348ec..d34dc438 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -251,7 +251,7 @@ def preprocess_batch( loss_mask[start : end + 1, idx] = False else: loss_mask[idx, start : end + 1] = False - if self._config.output_layer.distillation_model is not None: + if self._config.head.distillation_model is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) kwargs[LanguageModelKwargs.labels] = labels diff --git a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py index 40c4cfa8..a80c031a 100644 --- a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py +++ b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py @@ -18,7 +18,7 @@ from transformers.modeling_utils import PreTrainedModel from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm from transformers.processing_utils import Unpack -from transformers.utils import LossKwargs, logging +from transformers.utils import TransformersKwargs, logging from transformers.utils.generic import ModelOutput from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig @@ -1252,9 +1252,6 @@ def forward( return output -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
- - class AprielHybridSSMPreTrainedModel(PreTrainedModel): config_class = AprielHybridSSMConfig base_model_prefix = "model" @@ -1383,7 +1380,7 @@ def forward( output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], + **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): diff --git a/tests/test_ssm_varlen.py b/tests/test_ssm_varlen.py new file mode 100644 index 00000000..9ca491e3 --- /dev/null +++ b/tests/test_ssm_varlen.py @@ -0,0 +1,259 @@ +import inspect +import itertools + +import pytest +import torch + +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.ssm import kda as kda_module +from fast_llm.layers.ssm.config import KimiDeltaAttentionConfig, LinearAttentionKwargs + +# from mamba2 import NemotronHMamba2 + + +_mamba_varlen = False +try: + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa + + _mamba_available = True + sig = inspect.signature(selective_scan_fn) + if "position_indices" in sig.parameters: + _mamba_varlen = True + else: + _mamba_varlen = False + # for training with packing install https://github.com/jxiw/varlen_mamba + # see https://github.com/jxiw/M1/blob/main/HYBRID_PACK.md + +except (ImportError, RuntimeError): + _mamba_available = False + + +@pytest.fixture +def distributed_config(): + return DistributedConfig( + tensor_parallel=1, + pipeline_parallel=1, + sequence_data_parallel=1, + local_world_size=1, + world_size=1, + ) + + +@pytest.fixture +def distributed(distributed_config): + return Distributed(config=distributed_config) + + +def materialize_meta_tensors(model, tensor_space): + # Materialize parameters that are on meta device + for name, param in model.named_parameters(): + if param.device.type == "meta": + # Check if the parameter is a custom tensor type + if hasattr(param, "tensor_name") and hasattr(param, "init_parameter"): + param_data = param.new_empty(param.shape, device="cuda") + # Initialize param_data + param.init_parameter(param_data, tensor_space.distributed) + # Replace the parameter in the module + module_path, param_name = name.rsplit(".", 1) if "." 
in name else (None, name) + module = model + if module_path is not None: + for part in module_path.split("."): + module = getattr(module, part) + param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) + # TODO: add param_grad_is_zero etc., grad_buffer, etc., see test_mlp_recomputation + param.grad = None + param.grad_buffer = torch.empty_like(param) + param.param_grad_is_zero = True + module._parameters[param_name] = param + return model + + +def unpack(packed_hidden_states, cu_seqlens): + batch_size = packed_hidden_states.shape[0] + package_num = cu_seqlens.shape[0] - 1 + seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + hidden_dim = packed_hidden_states.shape[2] + hidden_states = torch.zeros( + package_num * batch_size, + seq_len, + hidden_dim, + dtype=packed_hidden_states.dtype, + device=packed_hidden_states.device, + ) + for j in range(batch_size): + for i in range(package_num): + line = j * package_num + i + hidden_states[line, : cu_seqlens[i + 1] - cu_seqlens[i], :] = packed_hidden_states[ + j, cu_seqlens[i] : cu_seqlens[i + 1], : + ] + return hidden_states + + +def pack(hidden_states, cu_seqlens, batch_size): + package_num, seq_len, hidden_dim = hidden_states.shape + seq_len_list = cu_seqlens[1:] - cu_seqlens[:-1] + seq_len_list_3d = seq_len_list.unsqueeze(1).unsqueeze(2) + indices_3d = ( + torch.arange(seq_len, device=hidden_states.device).unsqueeze(0).unsqueeze(2).repeat(package_num, 1, hidden_dim) + ) + mask_3d = indices_3d < seq_len_list_3d.repeat(batch_size, 1, 1) + packed_hidden_states = hidden_states[mask_3d].view(batch_size, -1, hidden_dim) + return packed_hidden_states + + +def generate_random_cu_seqlens(seq_len, packages_num=2): + if packages_num < 1: + raise ValueError("packages_num must be at least 1") + + # base size of each chunk, and how many get an extra token + base, rem = divmod(seq_len, packages_num) + # lengths: e.g. for seq_len=10, packages=3 → [4,3,3] + lengths = [base + 1 if i < rem else base for i in range(packages_num)] + + # split points exclude the final cumulative (seq_len) + split_points = list(itertools.accumulate(lengths))[:-1] + + cu_seqlens = [0] + split_points + [seq_len] + # cu_seqlens = split_points # + [seq_len] + + # index: for each chunk, we emit 0,1,...,length-1 + index = [] + for length in lengths: + index.extend(range(length)) + + # sanity check + assert len(cu_seqlens) - 1 == packages_num + assert sum(lengths) == seq_len + assert len(index) == seq_len + + return cu_seqlens, index + + +def _materialize_kda_tensors(module: torch.nn.Module, distributed: Distributed, device: torch.device) -> None: + """ + Materialize meta parameters on the requested device for KDA mixer layers. + """ + for name, param in module.named_parameters(): + if param.device.type != "meta": + continue + param_data = torch.empty_like(param, device=device) + param.init_parameter(param_data, distributed) + module_path, param_name = name.rsplit(".", 1) if "." in name else (None, name) + target = module + if module_path is not None: + for part in module_path.split("."): + target = getattr(target, part) + new_param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) + new_param.grad = None + new_param.grad_buffer = torch.zeros_like(param_data) + new_param.param_grad_is_zero = True + target._parameters[param_name] = new_param + + +def _param_grad(param: torch.nn.Parameter) -> torch.Tensor | None: + # ParameterMeta stores grads in grad_buffer; fall back to .grad otherwise. 
+ return param.grad_buffer if hasattr(param, "grad_buffer") and param.grad_buffer is not None else param.grad + + +@pytest.mark.slow +@pytest.mark.skipif(not torch.cuda.is_available(), reason="KDA varlen needs CUDA") +@pytest.mark.skipif( + kda_module.chunk_kda is None or kda_module.fused_kda_gate is None, + reason="KDA fused kernels not available", +) +def test_kda_varlen_stacking_equivalence(distributed_config, distributed): + """ + Check that KDA forward/backward match with and without stacking using the real kernels. + """ + device = torch.device("cuda") + dtype = torch.float16 + heads, head_dim = 2, 16 + hidden_size = heads * head_dim + + config = KimiDeltaAttentionConfig(heads=heads, head_dim=head_dim) + hidden_dim = TensorDim("hidden", hidden_size) + kda_packed = config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) + kda_ref = config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) + kda_packed.setup(distributed) + kda_ref.setup(distributed) + _materialize_kda_tensors(kda_packed, distributed, device) + _materialize_kda_tensors(kda_ref, distributed, device) + kda_ref.load_state_dict(kda_packed.state_dict()) + kda_packed.to(device=device, dtype=dtype) + kda_ref.to(device=device, dtype=dtype) + + batch_size = 2 # cu_seqlens path requires flattened batch + seq_len = 15 + packages_num = torch.randint(2, 5, (1, batch_size))[0] # randomize packages num between 2 and 4 + lengths = [ + torch.tensor( + generate_random_cu_seqlens(seq_len, packages_num=packages_num[i].item())[0], + device=device, + dtype=torch.long, + ).diff() + for i in range(batch_size) + ] + + # lengths = torch.tensor(cu_seqlens, device=device, dtype=torch.long)#.diff() + # total_tokens = lengths.sum().item() + packed = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype, requires_grad=True) + + kwargs_packed = { + LinearAttentionKwargs.sequence_lengths: lengths, + BlockKwargs.sequence_first: False, + BlockKwargs.hidden_dims: (hidden_dim,), + # BlockKwargs.sequence_q_dim: TensorDim("sequence_q", lengths.sum().item()), + } + # Use the layer's preprocess to construct cu_seqlens/seq_idx the same way as the implementation. + kda_packed.preprocess(packed, kwargs_packed) + + kwargs_ref = { + BlockKwargs.sequence_first: False, + BlockKwargs.hidden_dims: (hidden_dim,), + } + + out_packed = kda_packed(packed, kwargs_packed) + # Run reference path separately per sequence without varlen packing, then concatenate. 
+ ref_outs = [] + for b in range(batch_size): + out_batch = [] + length = lengths[b] + ref_seqs = torch.split(packed[b].unsqueeze(0), length.tolist(), dim=1) + for seq in ref_seqs: + kwargs_ref_seq = { + **kwargs_ref, + BlockKwargs.sequence_q_dim: TensorDim("sequence_q", seq.shape[1]), + } + out_batch.append(kda_ref(seq, kwargs_ref_seq)) + ref_outs.append(torch.cat(out_batch, dim=1)) + out_ref = torch.cat(ref_outs, dim=0) + out_ref_packed = out_ref + + assert out_packed.shape == packed.shape + assert out_ref_packed.shape == out_packed.shape + assert torch.allclose(out_packed, out_ref_packed, atol=1e-3, rtol=1e-3) + + out_packed.sum().backward() + out_ref_packed.sum().backward() + + assert _param_grad(kda_packed.q_proj.weight) is not None + assert _param_grad(kda_ref.q_proj.weight) is not None + assert torch.allclose( + _param_grad(kda_packed.q_proj.weight), _param_grad(kda_ref.q_proj.weight), atol=1e-3, rtol=1e-3 + ) + assert torch.allclose( + _param_grad(kda_packed.k_proj.weight), _param_grad(kda_ref.k_proj.weight), atol=1e-3, rtol=1e-3 + ) + assert torch.allclose( + _param_grad(kda_packed.v_proj.weight), _param_grad(kda_ref.v_proj.weight), atol=1e-3, rtol=1e-3 + ) + assert torch.allclose( + _param_grad(kda_packed.o_proj.weight), _param_grad(kda_ref.o_proj.weight), atol=1e-3, rtol=1e-3 + ) + + +if __name__ == "__main__": + pytest.main([__file__])
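
For reference, a minimal sketch of the `cu_seqlens` / `seq_idx` bookkeeping that `_preprocess_for_varlen` builds for the fla varlen kernels (plain PyTorch; the document lengths below are illustrative, not taken from the test):

import torch

# Two packed documents of lengths 3 and 5, flattened into a single batch-size-1 sequence.
lengths = torch.tensor([3, 5])
cu_seqlens = torch.cat((torch.zeros(1, dtype=torch.long), lengths.cumsum(0)))
# cu_seqlens == tensor([0, 3, 8]): the kernels treat positions [0, 3) and [3, 8) as separate sequences.
seq_idx = torch.repeat_interleave(torch.arange(len(lengths), dtype=torch.int32), lengths).unsqueeze(0)
# seq_idx == tensor([[0, 0, 0, 1, 1, 1, 1, 1]], dtype=torch.int32), one sequence index per token, shape (1, total_tokens).
print(cu_seqlens, seq_idx)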