Commit cd1df18

Multimodal-SSM fixes and utils (#357)
1 parent 4e860cc commit cd1df18

File tree: 6 files changed, +324 −3 lines


fast_llm/data/dataset/gpt/memmap.py

Lines changed: 3 additions & 1 deletion
@@ -65,7 +65,9 @@ def _init(
         offset = stream.tell()

         if num_documents is not None:
-            assert self._num_documents == num_documents
+            assert (
+                self._num_documents == num_documents
+            ), f"Inconsistent num_documents for dataset {self.name} - {self._prefix}. Expected {num_documents}, got {self._num_documents}."

         self._index_bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".idx"), mode="r", order="C")
         self._index_bin_buffer = memoryview(self._index_bin_buffer_mmap)
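The assertion now reports which dataset and prefix are inconsistent instead of failing with a bare `AssertionError`. A minimal sketch of the resulting failure, with hypothetical values standing in for `self.name`, `self._prefix`, and the document counts:

```python
# Hypothetical values; only the message format mirrors the commit.
num_documents = 1000     # expected by the caller
found_documents = 998    # value read from the .idx header
name, prefix = "pile_shard_0", "/data/pile_shard_0"

assert found_documents == num_documents, (
    f"Inconsistent num_documents for dataset {name} - {prefix}. "
    f"Expected {num_documents}, got {found_documents}."
)
# AssertionError: Inconsistent num_documents for dataset pile_shard_0 - /data/pile_shard_0.
# Expected 1000, got 998.
```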

fast_llm/functional/cross_entropy.py

Lines changed: 1 addition & 1 deletion
@@ -160,7 +160,7 @@ def _fused_cross_entropy_forward_backward(

     per_sample_loss = sum_exp_logits.log() - predicted_logits
     if loss_mask is not None:
-        per_sample_loss = per_sample_loss[loss_mask]
+        per_sample_loss = per_sample_loss * loss_mask

     loss = per_sample_loss.mean()
     if target_format != TargetFormat.labels and group is not None:
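Multiplying by the mask keeps `per_sample_loss` at its original shape (masked positions contribute zero) rather than gathering only the unmasked entries, which presumably keeps the fused forward/backward and any cross-rank reduction working on a fixed-size tensor. Note that the subsequent `.mean()` then averages over all positions, masked ones included. The toy comparison below is illustrative only and not part of the commit:

```python
import torch

per_sample_loss = torch.tensor([2.0, 3.0, 4.0, 5.0])
loss_mask = torch.tensor([1.0, 0.0, 1.0, 0.0])

indexed = per_sample_loss[loss_mask.bool()]  # shape (2,): keeps only unmasked entries
masked = per_sample_loss * loss_mask         # shape (4,): zeroes out masked entries

print(indexed.mean())  # tensor(3.)     -- mean over kept positions only
print(masked.mean())   # tensor(1.5000) -- mean over all positions, zeros included
```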

fast_llm/layers/vision_encoder/adapter.py

Lines changed: 2 additions & 0 deletions
@@ -26,13 +26,15 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace):
             bias=True,
             weight_init_method=init_normal_(std=config.adapter_init_method_std),
             bias_init_method=init_normal_(std=config.adapter_init_method_std),
+            lr_scale=config.adapter_lr_scale,
         )
         self.layer_2 = Linear(
             tensor_space[VisionEncoderDimNames.adapter_size],
             tensor_space[TransformerDimNames.hidden],
             bias=True,
             weight_init_method=init_normal_(std=config.adapter_init_method_std),
             bias_init_method=init_normal_(std=config.adapter_init_method_std),
+            lr_scale=config.adapter_lr_scale,
         )

     def forward(
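Both adapter projections now receive `config.adapter_lr_scale`, so the vision adapter's learning rate can be tuned independently of the rest of the model. As a generic illustration only (this is not fast_llm's internal mechanism), a per-module `lr_scale` usually amounts to placing those parameters in their own optimizer group with a scaled learning rate:

```python
import torch

base_lr = 3e-4
adapter_lr_scale = 0.1  # hypothetical value for config.adapter_lr_scale

adapter = torch.nn.Linear(1024, 4096)
backbone = torch.nn.Linear(4096, 4096)

optimizer = torch.optim.AdamW(
    [
        {"params": backbone.parameters(), "lr": base_lr},
        {"params": adapter.parameters(), "lr": base_lr * adapter_lr_scale},
    ]
)
print([g["lr"] for g in optimizer.param_groups])  # [0.0003, 3e-05]
```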

fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py

Lines changed: 2 additions & 1 deletion
@@ -76,6 +76,7 @@ def prepare_inputs_for_generation(
         cache_position=None,
         position_ids=None,
         use_cache=True,
+        pixel_values=None,
         **kwargs,
     ):
         # Copy of the method from `AprielThinkerSSMHybridForCausalLM`
@@ -95,7 +96,7 @@ def prepare_inputs_for_generation(
             input_ids = input_ids[:, cache_position]
         else:
             past_key_values = HybridMambaAttentionDynamicCache(
-                self.config, input_ids.shape[0], self.dtype, device=self.device
+                self.config.text_config, input_ids.shape[0], self.dtype, device=self.device
             )

         if attention_mask is not None and position_ids is None:
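Two fixes here: `pixel_values` becomes an explicit argument of `prepare_inputs_for_generation` instead of falling into `**kwargs`, and the SSM cache is built from `self.config.text_config`, since in the multimodal wrapper `self.config` is the composite LLaVA config while `HybridMambaAttentionDynamicCache` needs the language-model (hybrid SSM) configuration. A hypothetical end-to-end usage sketch; the checkpoint path, image, and prompt are placeholders, not from the commit:

```python
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

checkpoint = "/path/to/llava-hybrid"  # placeholder
model = AutoModelForVision2Seq.from_pretrained(
    checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True
)
processor = AutoProcessor.from_pretrained(checkpoint)

image = Image.open("example.jpg")
inputs = processor(text="Describe the image.", images=image, return_tensors="pt")

# pixel_values from the processor flow through prepare_inputs_for_generation
# on the first decoding step; later steps rely on the hybrid SSM cache.
output_ids = model.generate(**inputs, max_new_tokens=32)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])
```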
fast_llm/models/ssm/external/make_hybrid_checkpoint.py (new file)

Lines changed: 163 additions & 0 deletions

import gc

import click
import torch
from transformers import AutoModelForCausalLM

from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig
from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import (
    AprielSSMM2DecoderLayer,
    AprielThinkerSSMHybridForCausalLM,
)

device = "cuda" if torch.cuda.is_available() else "cpu"

dstate = 16
expand = 1
# Derived dimensions for the Mamba configuration.
# d_model = config_base.text_config.hidden_size
d_inner = 4096  # hard-coded to match the thinker (expand * d_model)
d_xb = 1024  # hard-coded to match the thinker (num_key_value_heads * head_dim)


def convert_layers(
    transformer_config,
    transformer_model,
    mamba_config,
    hybrid_block_layout,
    init_with_kqvo,
    torch_dtype=torch.bfloat16,
):
    config = transformer_config
    embed_dim = config.hidden_size
    num_heads = config.num_attention_heads
    num_heads_kv = config.num_key_value_heads
    head_dim = embed_dim // num_heads
    head_dim * num_heads
    head_dim * num_heads_kv

    for layer_idx, type in enumerate(hybrid_block_layout):
        print("Converting layer %d..." % layer_idx)
        # Fetch the layer module for easier access
        layer_module = transformer_model.layers._modules[f"{layer_idx}"]
        if type == "t":
            print("Skipping transformer layer %d..." % layer_idx)
        elif type == "m2":
            print("Converting layer %d..." % layer_idx)
            # Use MambaDecoderLayer for the remaining layers
            mamba_encoder = AprielSSMM2DecoderLayer(
                mamba_config,
                layer_idx,
                device="cpu",
                dtype=torch_dtype,
            )

            mamba_encoder.mlp.load_state_dict(layer_module.mlp.state_dict())
            mamba_encoder.input_layernorm.load_state_dict(layer_module.input_layernorm.state_dict())
            mamba_encoder.post_attention_layernorm.load_state_dict(layer_module.post_attention_layernorm.state_dict())
            mamba_encoder.mixer.out_proj.load_state_dict(layer_module.self_attn.o_proj.state_dict())

            if init_with_kqvo:
                # Copy weights: in_proj rows are laid out as [z, x, B, C, dt];
                # x <- v_proj, B <- k_proj, C <- q_proj
                mamba_encoder.mixer.in_proj.weight.data[
                    mamba_config.ssm_cfg["d_inner"] : mamba_config.ssm_cfg["d_inner"] + mamba_config.ssm_cfg["d_xb"], :
                ].copy_(layer_module.self_attn.v_proj.weight.data)
                mamba_encoder.mixer.in_proj.weight.data[
                    mamba_config.ssm_cfg["d_inner"]
                    + mamba_config.ssm_cfg["d_xb"] : mamba_config.ssm_cfg["d_inner"]
                    + 2 * mamba_config.ssm_cfg["d_xb"],
                    :,
                ].copy_(layer_module.self_attn.k_proj.weight.data)
                mamba_encoder.mixer.in_proj.weight.data[
                    mamba_config.ssm_cfg["d_inner"]
                    + 2 * mamba_config.ssm_cfg["d_xb"] : 2 * mamba_config.ssm_cfg["d_inner"]
                    + 2 * mamba_config.ssm_cfg["d_xb"],
                    :,
                ].copy_(layer_module.self_attn.q_proj.weight.data)

                print("Init Mamba using Attention")

            transformer_model.layers[layer_idx] = mamba_encoder

        else:
            raise ValueError(f"Invalid layer type: {type}")


def make_hybrid_config(transformer):
    config_dict = transformer.config.to_dict()
    config_dict["hybrid_block_layout"] = ["t"] * transformer.config.num_hidden_layers
    config_dict["model_type"] = "apriel_ssm_thinker_hybrid"
    config_dict["ssm_cfg"] = {
        "activation": "silu",
        "d_state": dstate,
        "d_xb": d_xb,
        "expand": expand,
        "d_conv": 4,
        "d_inner": d_inner,
        "conv_bias": True,
        "bias": False,
    }
    hybrid_config = AprielSSMHybridConfig.from_dict(**config_dict)
    return hybrid_config


@click.command()
@click.option(
    "--base_checkpoint", type=str, required=False, default="/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker"
)
@click.option("--m2_indices", type=int, multiple=True, required=True)
@click.option("--hybrid_checkpoint", type=str, required=True)
@click.option("--save_dir", type=str, required=True)
def main(base_checkpoint: str, m2_indices: list, hybrid_checkpoint: str, save_dir: str):
    """
    base_checkpoint: path to the base transformer model (teacher model)
    m2_indices: indices of layers to convert to Mamba layers with MiL init
    hybrid_checkpoint: path to the hybrid model (student model)
    save_dir: directory to save the converted model

    TODO: base_checkpoint can actually be a hybrid; rename the `transformer` variable accordingly.
    """
    m2_indices = list(m2_indices)  # convert tuple -> list
    transformer = AutoModelForCausalLM.from_pretrained(base_checkpoint, trust_remote_code=True)
    if hybrid_checkpoint == "none":
        print("No hybrid checkpoint provided, creating new config from base model.")
        hybrid_config = make_hybrid_config(transformer)
    else:
        hybrid_config = AprielSSMHybridConfig.from_pretrained(hybrid_checkpoint)

    hybrid_block_layout = hybrid_config.hybrid_block_layout
    for m2_index in m2_indices:
        hybrid_block_layout[m2_index] = "m2"
    print(hybrid_block_layout)

    convert_layers(
        transformer.config,
        transformer.model,
        hybrid_config,
        hybrid_block_layout,
        init_with_kqvo=True,
        torch_dtype=torch.bfloat16,
    )
    hybrid_config.ssm_cfg["activation"] = "silu"

    # Load all existing SSM layers
    if hybrid_checkpoint != "none":
        hybrid_model = AprielThinkerSSMHybridForCausalLM.from_pretrained(hybrid_checkpoint)
        state_dict = hybrid_model.state_dict()
        missing, unexpected = transformer.load_state_dict(state_dict, strict=False)
        for m2_index in m2_indices:
            assert f"model.layers.{m2_index}.mixer.A_log" in missing
            assert f"model.layers.{m2_index}.self_attn.q_proj.weight" in unexpected
        print("MISSING", missing)
        print("UNEXPECTED", unexpected)

    # Save state dict and config
    transformer.save_pretrained(save_dir)
    hybrid_config.save_pretrained(save_dir)

    gc.collect()


if __name__ == "__main__":
    main()
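The script performs MiL init (Mamba-in-Llama-style attention-to-SSM weight transfer): for every layer marked "m2" it builds an `AprielSSMM2DecoderLayer`, copies the MLP and layer norms, and maps the attention projections into the fused `in_proj` of the SSM mixer. A small sketch of where the copied rows land, using the hard-coded `d_inner = 4096` and `d_xb = 1024` from the script (the explicit row arithmetic below is illustrative, not part of the commit):

```python
# in_proj rows are laid out as [z, x, B, C, dt]; MiL init copies
# v_proj -> x, k_proj -> B, q_proj -> C.
d_inner, d_xb = 4096, 1024

z_rows = (0, d_inner)                                  # left at its random init
x_rows = (d_inner, d_inner + d_xb)                     # receives v_proj
b_rows = (d_inner + d_xb, d_inner + 2 * d_xb)          # receives k_proj
c_rows = (d_inner + 2 * d_xb, 2 * d_inner + 2 * d_xb)  # receives q_proj

print(x_rows, b_rows, c_rows)  # (4096, 5120) (5120, 6144) (6144, 10240)
```

Since `--m2_indices` is declared with `multiple=True`, it is repeated on the command line once per layer to convert (for example, `--m2_indices 10 --m2_indices 20`).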
New file (path not shown in the diff)

Lines changed: 153 additions & 0 deletions

import gc
import json
import shutil

import click
import torch
from transformers import AutoModelForVision2Seq

from fast_llm.models.ssm.external.apriel_15b_hybrid import modeling_ssm_hybrid_apriel15b
from fast_llm.models.ssm.external.llava_hybrid import configuration_llava_hybrid, modeling_llava_hybrid
from fast_llm.models.ssm.external.llava_hybrid.configuration_llava_hybrid import LlavaHybridConfig
from fast_llm.models.ssm.external.llava_hybrid.modeling_llava_hybrid import LlavaHybridForConditionalGeneration
from fast_llm.models.ssm.external.make_hybrid_checkpoint import convert_layers

device = "cuda" if torch.cuda.is_available() else "cpu"

dstate = 16
expand = 1
# Derived dimensions for the Mamba configuration.
# d_model = config_base.text_config.hidden_size
d_inner = 4096  # hard-coded to match the thinker (expand * d_model)
d_xb = 1024  # hard-coded to match the thinker (num_key_value_heads * head_dim)


def make_hybrid_llava_config(transformer):
    config_dict = transformer.config.to_dict()
    config_dict["text_config"]["hybrid_block_layout"] = ["t"] * transformer.config.text_config.num_hidden_layers
    config_dict["text_config"]["model_type"] = "apriel_ssm_thinker_hybrid"
    config_dict["text_config"]["ssm_cfg"] = {
        "activation": "silu",
        "d_state": dstate,
        "d_xb": d_xb,
        # "d_model": d_model,  # will be set automatically
        "expand": expand,
        "d_conv": 4,
        "d_inner": d_inner,  # same as d_model * expand
        "conv_bias": True,
        "bias": False,
    }
    llava_hybrid_config = LlavaHybridConfig(**config_dict)
    return llava_hybrid_config


def make_hybrid_llava_model(transformer, llava_hybrid_config):
    """
    Create a LlavaHybridForConditionalGeneration model with the same configuration as the given transformer model.
    """
    llava_hybrid_model = LlavaHybridForConditionalGeneration(llava_hybrid_config)
    # llava_hybrid_model.to(dtype=torch.bfloat16).to(device)
    llava_hybrid_model.load_state_dict(transformer.state_dict(), strict=False)
    return llava_hybrid_model


@click.command()
@click.option("--base_checkpoint", type=str, required=False, default="ServiceNow-AI/Apriel-Nemotron-15b-Thinker")
@click.option("--m2_indices", type=int, multiple=True, required=True)
@click.option("--hybrid_checkpoint", type=str, required=True)
@click.option("--save_dir", type=str, required=True)
@click.option(
    "--tokenizer_dir", type=str, required=False, default="/mnt/plato/checkpoints/upstream/Mistral-Nemo-Base-2407/"
)
def main(base_checkpoint: str, m2_indices: list[int], hybrid_checkpoint: str, save_dir: str, tokenizer_dir: str):
    """
    base_checkpoint: path to the base transformer model (teacher model)
    m2_indices: indices of layers to convert to Mamba layers with MiL init
    hybrid_checkpoint: path to the hybrid model (student model); can be a hybrid with only transformer layers for the first distillation run
    save_dir: directory to save the converted model
    tokenizer_dir: directory containing tokenizer files to copy over to save_dir
    """
    m2_indices = list(m2_indices)  # convert tuple -> list
    transformer = AutoModelForVision2Seq.from_pretrained(base_checkpoint, trust_remote_code=True)
    if hybrid_checkpoint == "none":
        print("No hybrid checkpoint provided, creating new config from base model.")
        hybrid_config = make_hybrid_llava_config(transformer)
    else:
        hybrid_config = LlavaHybridConfig.from_pretrained(hybrid_checkpoint)

    hybrid_block_layout = hybrid_config.text_config.hybrid_block_layout
    for m2_index in m2_indices:
        hybrid_block_layout[m2_index] = "m2"
    print(hybrid_block_layout)

    # MiL init
    convert_layers(
        transformer.model.language_model.config,
        transformer.model.language_model,
        hybrid_config.text_config,
        hybrid_block_layout,
        init_with_kqvo=True,
        torch_dtype=torch.bfloat16,
    )
    hybrid_config.text_config.ssm_cfg["activation"] = "silu"

    # Load existing SSM layers
    if hybrid_checkpoint != "none":
        hybrid_llava_model = AutoModelForVision2Seq.from_pretrained(
            hybrid_checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True
        )
        llava_state_dict = hybrid_llava_model.state_dict()
        missing, unexpected = transformer.load_state_dict(llava_state_dict, strict=False)
        for m2_index in m2_indices:
            assert f"model.layers.{m2_index}.mixer.A_log" in missing
            assert f"model.layers.{m2_index}.self_attn.q_proj.weight" in unexpected
        print("MISSING", missing)
        print("UNEXPECTED", unexpected)

    # Save state dict
    transformer.save_pretrained(save_dir)
    # Save the new config
    hybrid_config.save_pretrained(save_dir)

    # Copy modeling and tokenizer files
    modeling_files = [
        configuration_llava_hybrid.__file__,
        modeling_llava_hybrid.__file__,
        modeling_ssm_hybrid_apriel15b.__file__,
    ]
    tokenizer_files = [
        f"{tokenizer_dir}/tokenizer.json",
        f"{tokenizer_dir}/tokenizer_config.json",
        f"{tokenizer_dir}/generation_config.json",
        f"{tokenizer_dir}/special_tokens_map.json",
    ]
    for f in modeling_files + tokenizer_files:
        shutil.copy(f, save_dir)

    # Update config with auto_map entries
    config_file = f"{save_dir}/config.json"
    with open(config_file) as f:
        dumped_config = json.load(f)

    dumped_config["auto_map"] = {
        "AutoConfig": "configuration_llava_hybrid.LlavaHybridConfig",
        "AutoModel": "modeling_llava_hybrid.LlavaHybridModel",
        "AutoModelForVision2Seq": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration",
        "AutoModelForCausalLM": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration",
    }
    dumped_config["text_config"]["auto_map"] = {
        "AutoConfig": "configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig",
        "AutoModel": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridModel",
        "AutoModelForCausalLM": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridForCausalLM",
    }
    dumped_config["architectures"] = ["LlavaHybridForConditionalGeneration"]
    dumped_config["text_config"]["architectures"] = ["AprielThinkerSSMHybridForCausalLM"]
    with open(config_file, "w") as f:
        json.dump(dumped_config, f, indent=2)

    torch.cuda.empty_cache()
    gc.collect()


if __name__ == "__main__":
    main()
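The script ends by copying the modeling/configuration/tokenizer files next to the saved weights and writing `auto_map` entries into `config.json`, so the converted checkpoint is self-contained for `trust_remote_code` loading. A hypothetical sanity check (not part of the commit; the path is a placeholder) of what that enables:

```python
import torch
from transformers import AutoConfig, AutoModelForVision2Seq

save_dir = "/path/to/save_dir"  # placeholder for the --save_dir used above
config = AutoConfig.from_pretrained(save_dir, trust_remote_code=True)
assert config.text_config.model_type == "apriel_ssm_thinker_hybrid"

# auto_map routes AutoModelForVision2Seq to LlavaHybridForConditionalGeneration.
model = AutoModelForVision2Seq.from_pretrained(
    save_dir, torch_dtype=torch.bfloat16, trust_remote_code=True
)
print(type(model).__name__)  # expected: LlavaHybridForConditionalGeneration
```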
