diff --git a/fast_llm/models/ssm/external/15B_hybrid.ipynb b/fast_llm/models/ssm/external/15B_hybrid.ipynb
index a8f0c33b7..8761de425 100644
--- a/fast_llm/models/ssm/external/15B_hybrid.ipynb
+++ b/fast_llm/models/ssm/external/15B_hybrid.ipynb
@@ -1,5 +1,116 @@
 {
  "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Add MIL-initialized SSM layers to an existing SSM checkpoint"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/toolkit/.local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n"
+     ]
+    }
+   ],
+   "source": [
+    "from transformers import AutoConfig, MistralForCausalLM\n",
+    "\n",
+    "from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig\n",
+    "from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridForCausalLM, AprielSSMM2DecoderLayer, AprielSSMDecoderLayer\n",
+    "from transformers.models.mistral.modeling_mistral import MistralDecoderLayer\n",
+    "\n",
+    "# enable autoreload of edited modules\n",
+    "%load_ext autoreload\n",
+    "%autoreload 2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "path_thinker = \"/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker\"\n",
+    "n_ssm = 25\n",
+    "new_ssm_layers = [3]\n",
+    "\n",
+    "config_thinker = AutoConfig.from_pretrained(path_thinker)\n",
+    "# config_thinker.num_hidden_layers = 5\n",
+    "hybrid_block_layout = [\"t\"] * config_thinker.num_hidden_layers\n",
+    "hybrid_block_layout[3] = \"m2\"\n",
+    "\n",
+    "\n",
+    "dstate = 16\n",
+    "expand = 1\n",
+    "# Calculate derived dimensions for the Mamba1 configuration\n",
+    "d_model = config_thinker.hidden_size\n",
+    "d_inner = 4096  # hard-coded to match the Thinker checkpoint (expand * d_model)\n",
+    "d_xb = 1024  # hard-coded to match the Thinker checkpoint (num_key_value_heads * (hidden_size // num_attention_heads))\n",
+    "\n",
+    "config_hybrid = AprielSSMHybridConfig(\n",
+    "    **config_thinker.to_dict(),\n",
+    "    hybrid_block_layout=hybrid_block_layout,\n",
+    "    # discrete mamba2\n",
+    "    # ssm_cfg = {\n",
+    "    #     \"d_state\": dstate,\n",
+    "    #     \"n_v_heads\": 32,\n",
+    "    #     \"n_qk_heads\": 32,\n",
+    "    #     \"expand\": 1,\n",
+    "    #     \"chunk_size\": 128,\n",
+    "    #     \"activation\": \"identity\",\n",
+    "    #     \"bias\": False,\n",
+    "    #     \"d_conv\": 4,\n",
+    "    #     \"d_inner\": 32 * 128,\n",
+    "    # }\n",
+    "    # mamba 2: uses expansion internally\n",
+    "    # https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_config.py\n",
+    "    \n",
+    "    ssm_cfg = {\n",
+    "        \"d_state\": dstate,\n",
+    "        \"d_xb\": d_xb,\n",
+    "        # \"d_model\": d_model,  # will be set automatically\n",
+    "        \"expand\": expand,\n",
+    "        \"d_conv\": 4,\n",
+    "        \"d_inner\": d_inner,  # same as d_model * expand\n",
+    "        \"conv_bias\": True,\n",
+    "        \"bias\": False,\n",
+    "    }\n",
+    ")\n",
+    "# model_hybrid = AprielThinkerSSMHybridForCausalLM(config_hybrid)\n",
+    "# transformer = AutoModelForCausalLM.from_pretrained(path_thinker)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
"source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": 1, diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py index 98d2fc28d..f95fe7368 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py @@ -26,15 +26,56 @@ } +class AprielGDNConfig: + def __init__( + self, + linear_num_key_heads=16, + linear_num_value_heads=32, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_conv_kernel_dim=4, + kl_short_conv_kernel_size=4, + kl_num_heads=32, + kl_head_dim=128, + ): + self.linear_num_key_heads = linear_num_key_heads + self.linear_num_value_heads = linear_num_value_heads + self.linear_key_head_dim = linear_key_head_dim + self.linear_value_head_dim = linear_value_head_dim + self.linear_conv_kernel_dim = linear_conv_kernel_dim + + # Kimi LInear + self.short_conv_kernel_size = kl_short_conv_kernel_size + self.head_dim = kl_head_dim + self.num_heads = kl_num_heads + + +LAYER_TYPES = {"t": "full_attention", "swa": "sliding_attention", "gdn": "gated_delta_net", "kl": "kimi_linear"} + + class AprielSSMHybridConfig(MistralConfig): model_type = "apriel_ssm_thinker_hybrid" - def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs): + def __init__(self, hybrid_block_layout=["t"], ssm_cfg=None, gdn_cfg=None, **kwargs): super().__init__(**kwargs) self.hybrid_block_layout = hybrid_block_layout self.head_dim = self.head_dim or self.hidden_size // self.num_attention_heads # as in transformers 4.51.3 self.ssm_cfg = ssm_cfg or ssm_config_default + gdn_config: AprielGDNConfig = ( + AprielGDNConfig(**gdn_cfg) if isinstance(gdn_cfg, dict) else gdn_cfg or AprielGDNConfig() + ) + + # make elements of gdn_config accessible as attributes of self to pass self directly to Qwen3NextGatedDeltaNet + for k, v in vars(gdn_config).items(): + setattr(self, k, v) + for k, v in ssm_config_default.items(): if k not in self.ssm_cfg: self.ssm_cfg[k] = v # to make sure all elements are present in the config + self.layer_types = [LAYER_TYPES[lt] for lt in hybrid_block_layout] # this is for vllm compatibility + self.linear_attn_config = { + "short_conv_kernel_size": gdn_config.short_conv_kernel_size, + "head_dim": gdn_config.head_dim, + "num_heads": gdn_config.num_heads, + } diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index 9f4588a29..38c65bc0a 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -13,19 +13,24 @@ from torch import nn from transformers import GenerationMixin from transformers.cache_utils import Cache, DynamicCache +from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.modeling_utils import PreTrainedModel from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm +from 
diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py
index 9f4588a29..38c65bc0a 100644
--- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py
+++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py
@@ -13,19 +13,24 @@
 from torch import nn
 from transformers import GenerationMixin
 from transformers.cache_utils import Cache, DynamicCache
+from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm
+from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextGatedDeltaNet
 from transformers.processing_utils import Unpack
-from transformers.utils import LossKwargs, can_return_tuple, logging
+from transformers.utils import logging
 from transformers.utils.generic import ModelOutput
 
 from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig
 
-# from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_scan_fn as varlen_selective_scan_fn
-# from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn as varlen_causal_conv1d_fn
-
+try:
+    from fla.modules import FusedRMSNormGated, ShortConvolution
+    from fla.ops.kda import chunk_kda, fused_recurrent_kda
+    from fla.ops.kda.gate import fused_kda_gate
+except ImportError:
+    raise ImportError("Please run `pip install -U fla-core`")
 
 logger = logging.get_logger(__name__)
 
@@ -389,6 +394,162 @@ class AprielHybridCausalOutput(ModelOutput):
     past_key_values: Optional[Cache] = None
 
 
+class KimiDeltaAttention(nn.Module):
+    def __init__(self, config: AprielSSMHybridConfig, layer_idx: int):
+        super().__init__()
+        self.config = config
+        self.mode = "chunk"
+
+        self.hidden_size = config.hidden_size
+        self.conv_size = config.short_conv_kernel_size
+        self.head_dim = config.head_dim
+        self.num_heads = config.num_heads
+        self.head_k_dim = self.head_dim
+        self.num_k_heads = self.num_heads
+
+        self.layer_idx = layer_idx
+
+        assert self.mode in ["chunk", "fused_recurrent"], f"Not supported mode `{self.mode}`."
+
+        projection_k_size = self.head_k_dim * self.num_k_heads
+        projection_size = self.head_dim * self.num_heads
+
+        self.q_proj = nn.Linear(self.hidden_size, projection_k_size, bias=False)
+        self.k_proj = nn.Linear(self.hidden_size, projection_k_size, bias=False)
+        self.v_proj = nn.Linear(self.hidden_size, projection_size, bias=False)
+
+        self.q_conv1d = ShortConvolution(
+            hidden_size=projection_k_size,
+            kernel_size=self.conv_size,
+            activation="silu",
+        )
+        self.k_conv1d = ShortConvolution(
+            hidden_size=projection_k_size,
+            kernel_size=self.conv_size,
+            activation="silu",
+        )
+        self.v_conv1d = ShortConvolution(
+            hidden_size=projection_size,
+            kernel_size=self.conv_size,
+            activation="silu",
+        )
+
+        self.A_log = torch.nn.Parameter(
+            torch.log(torch.empty(self.num_heads, dtype=torch.float32).uniform_(1, 16)).view(1, 1, -1, 1)
+        )
+
+        self.f_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False)
+        self.f_b_proj = nn.Linear(self.head_dim, projection_size, bias=False)
+
+        self.dt_bias = nn.Parameter(torch.empty(projection_size, dtype=torch.float32))
+
+        self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False)
+
+        self.g_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False)
+        self.g_b_proj = nn.Linear(self.head_dim, projection_size, bias=False)
+
+        self.o_norm = FusedRMSNormGated(self.head_dim, eps=config.rms_norm_eps, activation="sigmoid")
+        self.o_proj = nn.Linear(projection_size, self.hidden_size, bias=False)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        cache_params=None,
+        **kwargs: Unpack[dict],
+    ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
+        if attention_mask is not None:
+            if attention_mask.dim() != 2:
+                attention_mask = kwargs.get("padding_mask")
+
+            if attention_mask is not None and attention_mask.dim() != 2:
+                raise ValueError(
+                    "attention_mask must be a 0-1 matrix of shape [batch_size, seq_len] "
+                    "(0 = padding). 
3D masks are not supported here.", + ) + use_cache = cache_params is not None + batch_size, q_len, _ = hidden_states.shape + mode = "fused_recurrent" if q_len <= 64 else self.mode + if self.training: + assert mode == "chunk", "Only chunk mode is supported in training." + + cu_seqlens = kwargs.get("cu_seqlens") + indices = None + if attention_mask is not None: + indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) + hidden_states = index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0) + + conv_state_q, conv_state_k, conv_state_v = None, None, None + recurrent_state = None + if cache_params is not None: + if cache_params.conv_states[self.layer_idx] is not None: + conv_state_q, conv_state_k, conv_state_v = cache_params.conv_states[self.layer_idx] + recurrent_state = cache_params.recurrent_states[self.layer_idx] + q, conv_state_q = self.q_conv1d( + x=self.q_proj(hidden_states), + cache=conv_state_q, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + k, conv_state_k = self.k_conv1d( + x=self.k_proj(hidden_states), + cache=conv_state_k, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + v, conv_state_v = self.v_conv1d( + x=self.v_proj(hidden_states), + cache=conv_state_v, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + g = self.f_b_proj(self.f_a_proj(hidden_states)) + g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias) + beta = self.b_proj(hidden_states).float().sigmoid() + + q, k = map(lambda x: rearrange(x, "... (h d) -> ... h d", d=self.head_k_dim), (q, k)) + v = rearrange(v, "... (h d) -> ... h d", d=self.head_dim) + + if mode == "chunk": + o, recurrent_state = chunk_kda( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=True, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, + ) + else: + o, recurrent_state = fused_recurrent_kda( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=True, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, + ) + if cache_params is not None: + cache_params.recurrent_states[self.layer_idx] = recurrent_state + cache_params.conv_states[self.layer_idx] = (conv_state_q, conv_state_k, conv_state_v) + + g = self.g_b_proj(self.g_a_proj(hidden_states)) + g = rearrange(g, "... (h d) -> ... 
h d", d=self.head_dim) + o = self.o_norm(o, g) + + o = rearrange(o, "b t h d -> b t (h d)") + o = self.o_proj(o) + if attention_mask is not None: + o = pad_input(o.squeeze(0), indices, batch_size, q_len) + + return o + + def segsum(x): """More stable segment sum calculation.""" # [1, 2, 3] @@ -1187,6 +1348,63 @@ def forward( return outputs +class AprielKLDecoderLayer(nn.Module): + _mixer_class = KimiDeltaAttention + + def __init__(self, config: AprielSSMHybridConfig, layer_idx: int, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} + self.hidden_size = config.hidden_size + + self.mixer = self._mixer_class(config, layer_idx=layer_idx) + + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + +class AprielGDNDecoderLayer(nn.Module): + _mixer_class = Qwen3NextGatedDeltaNet + + def __init__(self, config: AprielSSMHybridConfig, layer_idx: int, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} + self.hidden_size = config.hidden_size + + self.mixer = self._mixer_class(config, layer_idx=layer_idx) + + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, hidden_states: torch.Tensor, **kwargs + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + + outputs = {} + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + mixer_outputs = self.mixer( + hidden_states, + **kwargs, + ) + + hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + # outputs["hidden_states"] = hidden_states + outputs = (hidden_states,) + + return outputs + + class AprielSSMM2DecoderLayer(AprielSSMDecoderLayer): _mixer_class = Mamba2 @@ -1200,9 +1418,23 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): return (hidden_states,) +class MistralDecoderLayerSWA(MistralDecoderLayer): + """ + Thin wrapper over `MistralDecoderLayer` that marks layers meant to run with sliding-window attention. + """ + + def __init__(self, config: AprielSSMHybridConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.is_swa_layer = True + + class AprielThinkerSSMHybridModel(MistralModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AprielDecoderLayer`, `AprielSSMDecoderLayer`] + Note regarding SWA: + - in the origiinal Msitral model if self.config.sliding_window is set, all layers use SWA. + - in this hybrid model, only layers marked as `swa` in the `hybrid_block_layout` use SWA ans use the global `self.config.sliding_window`, otherwise 't' layers ignore `self.config.sliding_window` + TODO: this has not been tested yet, focused on vllm compatibility first. 
Args: config: AprielSSMHybridConfig """ @@ -1221,8 +1453,17 @@ def __init__(self, config: AprielSSMHybridConfig, **kwargs): blocks.append(AprielSSMM2DecoderLayer(config, layer_idx)) elif type == "t": blocks.append(MistralDecoderLayer(config, layer_idx)) + elif type == "swa": + blocks.append(MistralDecoderLayerSWA(config, layer_idx)) + assert ( + config.sliding_window is not None + ), "Found `swa` layers in hybrid layout but `swa_sliding_window` is not set." elif type == "i": blocks.append(AprielHybridIdentity(config)) + elif type == "gdn": + blocks.append(AprielGDNDecoderLayer(config, layer_idx)) + elif type == "kl": + blocks.append(AprielKLDecoderLayer(config, layer_idx)) else: raise ValueError(f"Invalid block type: {type}") self.layers = nn.ModuleList(blocks) @@ -1230,7 +1471,7 @@ def __init__(self, config: AprielSSMHybridConfig, **kwargs): # Initialize weights and apply final processing self.post_init() - @can_return_tuple + # @can_return_tuple def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1246,34 +1487,116 @@ def forward( ) -> BaseModelOutputWithPast: use_cache = use_cache if use_cache is not None else self.config.use_cache if use_cache and past_key_values is None: - # for the case where prepare_inputs_for_generation is not called to create the cache (as in fast-llm test) + # for the case where prepare_inputs_for_generation is not called to create the cache (as in fast-llm test) batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] past_key_values = HybridMambaAttentionDynamicCache(self.config, batch_size, self.dtype, device=self.device) - output = super().forward( - input_ids=input_ids, + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, cache_position=cache_position, - **flash_attn_kwargs, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + sliding_causal_mask = None + has_swa_blocks = any(block == "swa" for block in self.config.hybrid_block_layout) + if has_swa_blocks: + if self.config.sliding_window is None: + raise ValueError("Found `swa` layers in hybrid layout but `swa_sliding_window` is not set.") + sliding_causal_mask = create_sliding_window_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + layers = self.layers[: self.config.num_hidden_layers] + for layer_idx, decoder_layer in enumerate(layers): + if output_hidden_states: + all_hidden_states = 
all_hidden_states + (hidden_states,) + + block_type = self.config.hybrid_block_layout[layer_idx] + layer_mask = sliding_causal_mask if block_type == "swa" else causal_mask + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=layer_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + if isinstance(layer_outputs, tuple): + hidden_states = layer_outputs[0] + remaining = layer_outputs[1:] + else: + hidden_states = layer_outputs + remaining = () + + if output_attentions: + attn_tensor = remaining[0] if remaining else None + all_self_attns = all_self_attns + (attn_tensor,) + + hidden_states = self.norm(hidden_states) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, ) - past_key_values: HybridMambaAttentionDynamicCache = output.past_key_values - if past_key_values and not past_key_values.has_previous_state: - past_key_values.has_previous_state = True - return output + + if ( + isinstance(outputs.past_key_values, HybridMambaAttentionDynamicCache) + and not outputs.past_key_values.has_previous_state + ): + outputs.past_key_values.has_previous_state = True + + return outputs -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... +# class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class AprielThinkerSSMHybridPreTrainedModel(PreTrainedModel): config_class = AprielSSMHybridConfig base_model_prefix = "model" - _no_split_modules = ["MistralDecoderLayer", "AprielSSMDecoderLayer", "AprielSSMM2DecoderLayer"] + _no_split_modules = [ + "MistralDecoderLayer", + "MistralDecoderLayerSWA", + "AprielSSMDecoderLayer", + "AprielSSMM2DecoderLayer", + ] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True @@ -1398,7 +1721,7 @@ def forward( output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], + **kwargs, #: Unpack[KwargsForCausalLM], ) -> Union[tuple, CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): diff --git a/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py b/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py index b8e822d9f..5775cbaf6 100644 --- a/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py +++ b/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py @@ -1,8 +1,14 @@ -from transformers import MistralConfig from transformers.configuration_utils import PretrainedConfig from transformers.models.auto import CONFIG_MAPPING from transformers.utils import logging +try: + # In the fast-llm repo, import from the SSM modeling file + from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig +except ImportError: + # In the exported checkpoint, import from local file + from .configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig + logger = logging.get_logger(__name__) # Copied from configuration_ssm_hybrid_apriel15b.py @@ -30,20 +36,6 @@ } -class AprielSSMHybridConfig(MistralConfig): - model_type = "apriel_ssm_thinker_hybrid" - 
- def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs): - super().__init__(**kwargs) - self.hybrid_block_layout = hybrid_block_layout - self.head_dim = self.head_dim or self.hidden_size // self.num_attention_heads # as in transformers 4.51.3 - self.ssm_cfg = ssm_cfg or ssm_config_default - - for k, v in ssm_config_default.items(): - if k not in self.ssm_cfg: - self.ssm_cfg[k] = v # to make sure all elements are present in the config - - class LlavaHybridConfig(PretrainedConfig): """ Configuration class for Llava SSM-Hybrid-decoder model. diff --git a/fast_llm/models/ssm/external/make_llava_hybrid_swa_gdn_checkpoint.py b/fast_llm/models/ssm/external/make_llava_hybrid_swa_gdn_checkpoint.py new file mode 100644 index 000000000..138467046 --- /dev/null +++ b/fast_llm/models/ssm/external/make_llava_hybrid_swa_gdn_checkpoint.py @@ -0,0 +1,182 @@ +import gc +import json +import os +import shutil + +import click +import torch +from transformers import AutoConfig + +from fast_llm.models.ssm.external.apriel_15b_hybrid import ( + configuration_ssm_hybrid_apriel15b, + modeling_ssm_hybrid_apriel15b, +) +from fast_llm.models.ssm.external.llava_hybrid import configuration_llava_hybrid, modeling_llava_hybrid +from fast_llm.models.ssm.external.llava_hybrid.configuration_llava_hybrid import LlavaHybridConfig +from fast_llm.models.ssm.external.llava_hybrid.modeling_llava_hybrid import LlavaHybridForConditionalGeneration + +device = "cuda" if torch.cuda.is_available() else "cpu" +# swa_size = 2048 +dstate = 16 +expand = 1 +# Calculate derived dimensions for the Mamba1 configuration +# d_model = config_base.text_config.hidden_size +d_inner = 4096 # hard code to match thinker #expand * d_model +d_xb = 1024 # hard code to match thinker #config_thinker.num_key_value_heads * (config_thinker.hidden_size // config_thinker.num_attention_heads) + + +def make_hybrid_llava_config(transformer_config, swa_size): + config_dict = transformer_config.to_dict() + config_dict["text_config"]["model_type"] = "apriel_ssm_thinker_hybrid" + if "swa" in transformer_config.text_config.hybrid_block_layout: + config_dict["text_config"]["sliding_window"] = swa_size + if "dtype" not in config_dict["text_config"] or config_dict["text_config"]["dtype"] is None: + config_dict["text_config"]["dtype"] = config_dict["dtype"] + config_dict["text_config"]["ssm_cfg"] = { + "activation": "silu", + "d_state": dstate, + "d_xb": d_xb, + # "d_model": d_model, # will be set automatically + "expand": expand, + "d_conv": 4, + "d_inner": d_inner, # will be same as d_model * expand, + "conv_bias": True, + "bias": False, + } + config_dict["auto_map"] = { + "AutoConfig": "configuration_llava_hybrid.LlavaHybridConfig", + "AutoModel": "modeling_llava_hybrid.LlavaHybridModel", + "AutoModelForCausalLM": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", + "AutoModelForVision2Seq": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", + } + config_dict["text_config"]["auto_map"] = { + "AutoConfig": "configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig", + "AutoModel": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridModel", + "AutoModelForCausalLM": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridForCausalLM", + } + llava_hybrid_config = LlavaHybridConfig(**config_dict) + return llava_hybrid_config + + +def make_hybrid_llava_model(transformer, llava_hybrid_config): + """ + Create a LlavaHybridForConditionalGeneration model with the same configuration as the given transformer model. 
+ """ + llava_hybrid_model = LlavaHybridForConditionalGeneration(llava_hybrid_config) + # llava_hybrid_model.to(dtype=torch.bfloat16).to(device) + llava_hybrid_model.load_state_dict(transformer.state_dict(), strict=False) + return llava_hybrid_model + + +@click.command() +@click.option( + "--base_checkpoint", type=str, required=False, default="/mnt/checkpoints/upstream/Apriel-1.5-15b-Thinker" +) +@click.option("--n_swa", type=int, required=False, default=0) +@click.option("--n_gdn", type=int, default=1, required=False) +@click.option("--n_kl", type=int, default=0, required=False) +@click.option( + "--save_dir", + type=str, + required=False, + default="/mnt/checkpoints/ssm/vllm_checkpoints/apriel_hybrid_throughput_checkpoints/checkpoints_gdn_swa_2048/test", +) +@click.option("--skip_if_exists", is_flag=True, default=False) +@click.option("--tokenizer_dir", type=str, required=False, default="/mnt/checkpoints/upstream/Apriel-1.5-15b-Thinker") +@click.option("--swa_size", type=int, required=False, default=2048) +def main( + base_checkpoint: str, + n_swa: int, + n_gdn: int, + n_kl: int, + save_dir: str, + skip_if_exists: bool, + tokenizer_dir: str, + swa_size: int, +): + """ + base_checkpoint: path to base transformer-model (teacher model) + m2_indices: indices of layers to convert to mamba layers with MiL init + hybrid_checkpoint: path to hybrid model (student model). Can be a hybrid with only transformer layers for the first distillation run. + save_dir: directory to save the converted model. + tokenizer_dir: directory containing tokenizer files to copy over to save_dir. + """ + if skip_if_exists and os.path.exists(save_dir): + print(f"Checkpoint {save_dir} already exists, skipping...") + return + if n_swa + n_gdn + n_kl > 48: + raise ValueError("n_swa + n_gdn + n_kl exceeds total number of layers (48)") + + base_config = AutoConfig.from_pretrained(base_checkpoint, trust_remote_code=True) + + hybrid_block_layout = ["t"] * base_config.text_config.num_hidden_layers + assert ( + n_swa + n_gdn + n_kl <= base_config.text_config.num_hidden_layers + ), "n_swa + n_gdn + n_kl exceeds total number of layers" + + for swa_idx in range(n_swa): + hybrid_block_layout[swa_idx] = "swa" + for gdn_idx in range(n_gdn): + hybrid_block_layout[gdn_idx + n_swa] = "gdn" + for kl_idx in range(n_kl): + hybrid_block_layout[kl_idx + n_swa + n_gdn] = "kl" + + setattr(base_config.text_config, "hybrid_block_layout", hybrid_block_layout) + hybrid_config = make_hybrid_llava_config(base_config, swa_size) + + print(hybrid_config.text_config.hybrid_block_layout) + + hybrid_config.text_config.ssm_cfg["activation"] = "silu" + llava_hybrid_model = LlavaHybridForConditionalGeneration(hybrid_config) + + # Save state-dict + llava_hybrid_model.save_pretrained(save_dir) # here dtype is set to float32 for some reason + # Save new config + hybrid_config.save_pretrained(save_dir) + + # Copy modeling and tokenizer files + modeling_files = [ + configuration_ssm_hybrid_apriel15b.__file__, + configuration_llava_hybrid.__file__, + modeling_llava_hybrid.__file__, + modeling_ssm_hybrid_apriel15b.__file__, + ] + tokenizer_files = [ + f"{tokenizer_dir}/tokenizer.json", + f"{tokenizer_dir}/tokenizer_config.json", + f"{tokenizer_dir}/generation_config.json", + f"{tokenizer_dir}/special_tokens_map.json", + f"{tokenizer_dir}/preprocessor_config.json", + ] + for f in modeling_files + tokenizer_files: + shutil.copy(f, save_dir) + + # Update config with auto_maps + config_file = f"{save_dir}/config.json" + with open(config_file) as f: + dumped_config = 
json.load(f) + + dumped_config["auto_map"] = { + "AutoConfig": "configuration_llava_hybrid.LlavaHybridConfig", + "AutoModel": "modeling_llava_hybrid.LlavaHybridModel", + "AutoModelForVision2Seq": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", + "AutoModelForCausalLM": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", + } + dumped_config["text_config"]["auto_map"] = { + "AutoConfig": "configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig", + "AutoModel": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridModel", + "AutoModelForCausalLM": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridForCausalLM", + } + dumped_config["architectures"] = ["LlavaHybridForConditionalGeneration"] + dumped_config["text_config"]["architectures"] = ["AprielThinkerSSMHybridForCausalLM"] + with open(config_file, "w") as f: + json.dump(dumped_config, f, indent=2) + + print(f"Done to {save_dir}") + + torch.cuda.empty_cache() + gc.collect() + + +if __name__ == "__main__": + main()
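The exported directory is self-contained: the modeling and configuration files are copied next to the weights and wired up through `auto_map`, so the checkpoint can be reloaded with `trust_remote_code=True` (with `fla-core` installed, as the modeling file requires). A minimal loading sketch follows, with a placeholder path standing in for the `--save_dir` passed to the script; the printed layout and layer types are illustrative.

import torch
from transformers import AutoConfig, AutoModelForCausalLM

save_dir = "/path/to/exported/checkpoint"  # placeholder for the --save_dir used above

config = AutoConfig.from_pretrained(save_dir, trust_remote_code=True)
print(config.text_config.hybrid_block_layout)  # e.g. ['swa', 'gdn', 'kl', 't', ...]
print(config.text_config.layer_types)          # names derived from LAYER_TYPES, used by vLLM

model = AutoModelForCausalLM.from_pretrained(
    save_dir,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,  # the script saves float32 weights; downcast on load if desired
)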