
Commit 552ac10

feat: load only text weights from multimodal gemma (#2008)
1 parent 15461ff commit 552ac10

File tree

2 files changed (+86 lines, -3 lines)

litgpt/scripts/convert_hf_checkpoint.py

Lines changed: 13 additions & 2 deletions
@@ -4,6 +4,7 @@
 import json
 import os
 import re
+import warnings
 from collections import defaultdict
 from functools import partial
 from pathlib import Path
@@ -294,6 +295,7 @@ def copy_weights_gemma_3(
     pbar: Optional[tqdm] = None,
     progress_per_file: Optional[float] = None,
     debug_mode: Optional[bool] = False,
+    config: Optional[Config] = None,
 ) -> None:
     weight_map = {
         "model.embed_tokens.weight": "transformer.wte.weight",
@@ -316,11 +318,20 @@ def copy_weights_gemma_3(
 
     if progress_per_file is not None:
         progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))
-
+    # gemma3 4b+ are multimodal models, but we are only loading the text weights
+    is_multimodal = any(k.startswith("language_model") for k in hf_weights)
+    if is_multimodal:
+        warnings.warn("For Gemma3 models only the text component is supported.")
+        weight_map = {f"language_model.{k}": v for k, v in weight_map.items()}
     for from_name, param in hf_weights.items():
+        if from_name.startswith("vision_tower") or from_name.startswith("multi_modal_projector"):
+            continue
         name_template, *ids = layer_template(from_name, num_matches=2)
         to_name = weight_map[name_template]
         param = load_param(param, from_name, dtype, verbose=debug_mode)
+        # in multimodal models, the text weights are the first part of the weights
+        if is_multimodal and to_name == "transformer.wte.weight" and config is not None:
+            param = param[: config.vocab_size]
         if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
             qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))
             weight_name, weight_type = from_name.split(".")[-2:]
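Taken together, this hunk detects a multimodal checkpoint via the `language_model.` key prefix, warns that only the text component is supported, skips the `vision_tower.*` and `multi_modal_projector.*` tensors, and truncates the embedding table to the text vocabulary. Below is a minimal standalone sketch of that flow with toy tensors; the shapes and key names beyond the prefixes shown above are illustrative, and the real converter resolves names through `layer_template` before the lookup.

import torch

# Toy stand-ins for a multimodal-style HF state dict (illustrative shapes, not the real model).
text_vocab_size = 8          # plays the role of config.vocab_size
hf_weights = {
    "language_model.model.embed_tokens.weight": torch.randn(10, 4),  # text vocab + extra rows
    "vision_tower.patch_embedding.weight": torch.randn(4, 4),
    "multi_modal_projector.linear.weight": torch.randn(4, 4),
}
weight_map = {"model.embed_tokens.weight": "transformer.wte.weight"}

is_multimodal = any(k.startswith("language_model") for k in hf_weights)
if is_multimodal:
    # prefix every known text mapping so "language_model.model..." keys resolve
    weight_map = {f"language_model.{k}": v for k, v in weight_map.items()}

converted = {}
for from_name, param in hf_weights.items():
    if from_name.startswith(("vision_tower", "multi_modal_projector")):
        continue  # vision weights are intentionally dropped
    to_name = weight_map[from_name]
    if is_multimodal and to_name == "transformer.wte.weight":
        param = param[:text_vocab_size]  # keep only the text part of the embedding table
    converted[to_name] = param

print(sorted(converted))                          # ['transformer.wte.weight']
print(converted["transformer.wte.weight"].shape)  # torch.Size([8, 4])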
@@ -604,7 +615,7 @@ def convert_hf_checkpoint(
         copy_fn = partial(copy_weights_gemma_2, qkv_weights)
     elif model_name.lower().startswith("gemma-3"):
         qkv_weights = {}
-        copy_fn = partial(copy_weights_gemma_3, qkv_weights)
+        copy_fn = partial(copy_weights_gemma_3, qkv_weights, config=config)
     elif model_name.lower().startswith("phi"):
         # holder to reconstitute the split q, k, v
         qkv_weights = {}
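With `config` now forwarded, the converter can be pointed at a multimodal Gemma 3 checkpoint and will convert only its text weights. A hedged sketch of how the updated function might be invoked follows; the positional `checkpoint_dir` argument, the `model_name` keyword, and the local directory layout are assumptions, since only the function's internals appear in this diff.

from pathlib import Path

from litgpt.scripts.convert_hf_checkpoint import convert_hf_checkpoint

# Assumes the Hugging Face weights for a multimodal Gemma 3 model were already
# downloaded into this directory (the path is illustrative).
checkpoint_dir = Path("checkpoints/google/gemma-3-4b-it")

# With this commit, the resolved Config reaches copy_weights_gemma_3, which skips
# vision_tower/multi_modal_projector tensors and truncates the embedding table
# to the text vocabulary.
convert_hf_checkpoint(checkpoint_dir, model_name="gemma-3-4b-it")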

tests/test_model.py

Lines changed: 73 additions & 1 deletion
@@ -23,7 +23,7 @@
 from transformers.models.falcon import FalconConfig, FalconForCausalLM
 from transformers.models.gemma import GemmaConfig, GemmaForCausalLM
 from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM
-from transformers.models.gemma3 import Gemma3ForCausalLM, Gemma3TextConfig
+from transformers.models.gemma3 import Gemma3Config, Gemma3ForCausalLM, Gemma3ForConditionalGeneration, Gemma3TextConfig
 from transformers.models.gpt_neox import GPTNeoXConfig, GPTNeoXForCausalLM
 from transformers.models.llama import LlamaConfig, LlamaForCausalLM
 from transformers.models.mistral import MistralConfig, MistralForCausalLM
@@ -872,6 +872,78 @@ def test_against_original_gemma_3(model_name, device, dtype):
     torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5)
 
 
+@torch.inference_mode()
+@pytest.mark.parametrize("model_name", ["gemma-3-4b-it", "gemma-3-12b-it", "gemma-3-27b-it"])
+@pytest.mark.parametrize(
+    ("device", "dtype"),
+    [
+        (torch.device("cpu"), torch.float32),
+        pytest.param(
+            torch.device("cuda"),
+            torch.float16,
+            marks=[
+                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
+                # is slightly different
+                pytest.mark.xfail(raises=AssertionError, strict=False),
+                _RunIf(min_cuda_gpus=1),
+            ],
+        ),
+    ],
+)
+def test_against_multimodal_gemma_3(model_name, device, dtype):
+    torch.set_default_dtype(dtype)
+
+    T = 20
+    ours_config = Config.from_name(
+        model_name,
+        block_size=T,
+        sliding_window_size=T // 2,
+        n_layer=2,
+        n_head=16,
+        n_embd=32,
+        intermediate_size=86,
+    )
+
+    theirs_config = Gemma3Config(
+        Gemma3TextConfig(
+            vocab_size=ours_config.padded_vocab_size,
+            hidden_size=ours_config.n_embd,
+            head_dim=ours_config.head_size,
+            num_attention_heads=ours_config.n_head,
+            num_hidden_layers=ours_config.n_layer,
+            intermediate_size=ours_config.intermediate_size,
+            max_position_embeddings=ours_config.block_size,
+            sliding_window=ours_config.sliding_window_size,
+            rms_norm_eps=ours_config.norm_eps,
+            num_key_value_heads=ours_config.n_query_groups,
+            rope_theta=ours_config.rope_base,
+            attention_bias=ours_config.bias,
+            tie_word_embeddings=True,
+            hidden_act="gelu_pytorch_tanh",
+            attn_implementation="eager",
+            query_pre_attn_scalar=ours_config.attention_scores_scalar,
+            rope_scaling={"factor": 8.0, "rope_type": "linear"},
+            rope_local_base_freq=ours_config.rope_local_base_freq,
+        )
+    )
+
+    theirs_model = Gemma3ForConditionalGeneration(theirs_config).to(device)
+    theirs_state_dict = theirs_model.state_dict()
+
+    state_dict = {}
+
+    copy_weights_gemma_3({}, state_dict, theirs_state_dict, config=ours_config)
+    ours_model = GPT(ours_config).to(device)
+    ours_model.load_state_dict(state_dict)
+
+    # test end to end
+    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
+    assert x.size(1) == T
+    ours_y = ours_model(x)
+    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
+    torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5)
+
+
 @torch.inference_mode()
 @pytest.mark.parametrize(
     "model_name", ["Qwen2.5-1.5B", "Qwen2.5-Coder-1.5B", "Qwen2.5-Math-1.5B", "QwQ-32B-Preview", "QwQ-32B"]
