111 changes: 111 additions & 0 deletions fast_llm/models/ssm/external/15B_hybrid.ipynb
@@ -1,5 +1,116 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Add MIL innitialized SSM layers to exsiting SSM checkpoint"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/toolkit/.local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from transformers import MistralForCausalLM\n",
"\n",
"from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig\n",
"from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridForCausalLM, AprielSSMM2DecoderLayer, AprielSSMDecoderLayer\n",
"from transformers.models.mistral.modeling_mistral import MistralDecoderLayer\n",
"\n",
"# enable file reload \n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"path_thinker = \"/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker\"\n",
"n_ssm = 25\n",
"new_ssm_layers = [3]\n",
"\n",
"config_thinker = AutoConfig.from_pretrained(path_thinker)\n",
"# config_thinker.num_hidden_layers = 5\n",
"hybrid_block_layout = [\"t\"] * config_thinker.num_hidden_layers\n",
"hybrid_block_layout[3] = \"m2\"\n",
"\n",
"\n",
"dstate = 16\n",
"expand = 1\n",
"# Calculate derived dimensions for the Mamba1 configuration\n",
"d_model = config_thinker.hidden_size\n",
"d_inner = 4096 # hard code to match thinker #expand * d_model\n",
"d_xb = 1024 # hard code to match thinker #config_thinker.num_key_value_heads * (config_thinker.hidden_size // config_thinker.num_attention_heads)\n",
"\n",
"config_hybrid = AprielSSMHybridConfig(\n",
" **config_thinker.to_dict(),\n",
" hybrid_block_layout=hybrid_block_layout,\n",
" # discrete mamba2\n",
" # ssm_cfg = {\n",
" # \"d_state\": dstate,\n",
" # \"n_v_heads\": 32,\n",
" # \"n_qk_heads\": 32,\n",
" # \"expand\": 1,\n",
" # \"chunk_size\": 128,\n",
" # \"activation\": \"identity\",\n",
" # \"bias\": False,\n",
" # \"d_conv\": 4,\n",
" # \"d_inner\": 32 * 128,\n",
" # }\n",
" # mamba 2: uses expantion nternally\n",
" # https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_config.py\n",
" \n",
" ssm_cfg = {\n",
" \"d_state\": dstate,\n",
" \"d_xb\": d_xb,\n",
" # \"d_model\": d_model, # will be set automatically\n",
" \"expand\": expand,\n",
" \"d_conv\": 4,\n",
" \"d_inner\": d_inner, # will be same as d_model * expand,\n",
" \"conv_bias\": True,\n",
" \"bias\": False,\n",
" }\n",
")\n",
"# model_hybrid = AprielThinkerSSMHybridForCausalLM(config_hybrid)\n",
"# transformer = AutoModelForCausalLM.from_pretrained(path_thinker)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 1,
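For context on the notebook title: MIL (Mamba-in-Llama-style) initialization, as used in the M1 repository linked in the cell above, seeds the new SSM layer's projections from the attention weights of the transformer layer it replaces instead of using a random init. The collapsed cells carry out that transfer; the lines below are only a rough sketch of the idea, and the mixer attribute names (C_proj, B_proj, x_proj, out_proj) are hypothetical, not the real AprielSSMM2DecoderLayer API.

import torch

def mil_init_from_attention(ssm_layer, attn_layer):
    # Hypothetical sketch: copy the attention projections of the replaced MistralDecoderLayer
    # into the MIL-style slots of the new SSM mixer (attribute names are assumed).
    attn = attn_layer.self_attn
    with torch.no_grad():
        ssm_layer.mixer.C_proj.weight.copy_(attn.q_proj.weight)    # C (query-like, d_inner rows)
        ssm_layer.mixer.B_proj.weight.copy_(attn.k_proj.weight)    # B (key-like, d_xb rows)
        ssm_layer.mixer.x_proj.weight.copy_(attn.v_proj.weight)    # x (value-like, d_xb rows)
        ssm_layer.mixer.out_proj.weight.copy_(attn.o_proj.weight)  # output projection

# transformer = MistralForCausalLM.from_pretrained(path_thinker)
# model_hybrid = AprielThinkerSSMHybridForCausalLM(config_hybrid)
# model_hybrid.load_state_dict(transformer.state_dict(), strict=False)  # copy the shared weights
# for idx in new_ssm_layers:
#     mil_init_from_attention(model_hybrid.model.layers[idx], transformer.model.layers[idx])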
@@ -26,15 +26,56 @@
}


class AprielGDNConfig:
    """Sub-config for the linear-attention hybrid blocks: Gated DeltaNet ("gdn") and Kimi Linear ("kl")."""
def __init__(
self,
linear_num_key_heads=16,
linear_num_value_heads=32,
linear_key_head_dim=128,
linear_value_head_dim=128,
linear_conv_kernel_dim=4,
kl_short_conv_kernel_size=4,
kl_num_heads=32,
kl_head_dim=128,
):
self.linear_num_key_heads = linear_num_key_heads
self.linear_num_value_heads = linear_num_value_heads
self.linear_key_head_dim = linear_key_head_dim
self.linear_value_head_dim = linear_value_head_dim
self.linear_conv_kernel_dim = linear_conv_kernel_dim

# Kimi Linear
self.short_conv_kernel_size = kl_short_conv_kernel_size
self.head_dim = kl_head_dim
self.num_heads = kl_num_heads


LAYER_TYPES = {"t": "full_attention", "swa": "sliding_attention", "gdn": "gated_delta_net", "kl": "kimi_linear"}


class AprielSSMHybridConfig(MistralConfig):
model_type = "apriel_ssm_thinker_hybrid"

def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs):
def __init__(self, hybrid_block_layout=["t"], ssm_cfg=None, gdn_cfg=None, **kwargs):
super().__init__(**kwargs)
self.hybrid_block_layout = hybrid_block_layout
self.head_dim = self.head_dim or self.hidden_size // self.num_attention_heads # as in transformers 4.51.3
self.ssm_cfg = ssm_cfg or ssm_config_default

gdn_config: AprielGDNConfig = (
AprielGDNConfig(**gdn_cfg) if isinstance(gdn_cfg, dict) else gdn_cfg or AprielGDNConfig()
)

# expose gdn_config fields as attributes of self so the config can be passed directly to Qwen3NextGatedDeltaNet
for k, v in vars(gdn_config).items():
setattr(self, k, v)

for k, v in ssm_config_default.items():
if k not in self.ssm_cfg:
self.ssm_cfg[k] = v # to make sure all elements are present in the config
self.layer_types = [LAYER_TYPES[lt] for lt in hybrid_block_layout] # this is for vllm compatibility
self.linear_attn_config = {
"short_conv_kernel_size": gdn_config.short_conv_kernel_size,
"head_dim": gdn_config.head_dim,
"num_heads": gdn_config.num_heads,
}
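As a quick illustration of the extended config (illustrative values, not taken from this PR): building a layout that mixes full attention, Gated DeltaNet and Kimi Linear blocks, and checking the vLLM-facing fields the constructor derives.

from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig

# Illustrative only: the layout and gdn_cfg values below are examples, not defaults from this PR.
config = AprielSSMHybridConfig(
    hybrid_block_layout=["t", "gdn", "kl", "t"],
    gdn_cfg={"linear_num_value_heads": 32, "kl_num_heads": 32},
)
assert config.layer_types == ["full_attention", "gated_delta_net", "kimi_linear", "full_attention"]
assert config.linear_attn_config["num_heads"] == 32
assert config.linear_num_value_heads == 32  # gdn_cfg fields are also flattened onto the config itself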