Commit b157e9c

add tests for gemma3 (#2006)
1 parent cabec5f commit b157e9c

File tree: 5 files changed, +319 -3 lines changed

litgpt/scripts/convert_lit_checkpoint.py
Lines changed: 51 additions & 0 deletions

@@ -215,6 +215,57 @@ def copy_weights_gemma_2(
             state_dict[to_name] = param


+def copy_weights_gemma_3(
+    config: Config,
+    state_dict: Dict[str, torch.Tensor],
+    lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
+    untie_weights: bool = False,
+    saver: Optional[incremental_save] = None,
+) -> None:
+    weight_map = {
+        "transformer.wte.weight": "model.embed_tokens.weight",
+        "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight",
+        "transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight",
+        "transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight",
+        "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight",
+        "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight",
+        "transformer.h.{}.post_attention_norm.weight": "model.layers.{}.post_attention_layernorm.weight",
+        "transformer.h.{}.norm_2.weight": "model.layers.{}.pre_feedforward_layernorm.weight",
+        "transformer.h.{}.post_mlp_norm.weight": "model.layers.{}.post_feedforward_layernorm.weight",
+        "transformer.ln_f.weight": "model.norm.weight",
+        "lm_head.weight": "lm_head.weight",
+        "transformer.h.{}.attn.norm_q.weight": "model.layers.{}.self_attn.q_norm.weight",
+        "transformer.h.{}.attn.norm_k.weight": "model.layers.{}.self_attn.k_norm.weight",
+    }
+
+    for from_name, param in lit_weights.items():
+        if from_name == "lm_head.weight" and untie_weights:
+            continue
+        name_template, *ids = layer_template(from_name, num_matches=2)
+        param = load_param(param, from_name, None)
+        if from_name.endswith(".attn.qkv.weight"):
+            to_names = (
+                "model.layers.{}.self_attn.q_proj.weight".format(*ids),
+                "model.layers.{}.self_attn.k_proj.weight".format(*ids),
+                "model.layers.{}.self_attn.v_proj.weight".format(*ids),
+            )
+            params = param.split(
+                (
+                    config.n_head * config.head_size,
+                    config.n_query_groups * config.head_size,
+                    config.n_query_groups * config.head_size,
+                )
+            )
+        else:
+            to_names = (weight_map[name_template].format(*ids),)
+            params = (param,)
+
+        for to_name, param in zip(to_names, params):
+            if saver is not None:
+                param = saver.store_early(param)
+            state_dict[to_name] = param
+
+
 def copy_weights_phi(
     config: Config,
     state_dict: Dict[str, torch.Tensor],
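The only non-trivial step in the new function is splitting litgpt's fused QKV weight into the separate q/k/v projections that the HF layout expects. Below is a minimal sketch of that split, not part of this commit, using hypothetical toy sizes (n_head=8, n_query_groups=4, head_size=16, n_embd=128) rather than a real Gemma 3 config, and assuming the row-wise Q/K/V stacking implied by the split sizes above:

```python
import torch

# Toy sizes for illustration only; not a real Gemma 3 configuration.
n_head, n_query_groups, head_size, n_embd = 8, 4, 16, 128

# Fused weight with Q, K and V stacked row-wise, queries first.
qkv = torch.randn((n_head + 2 * n_query_groups) * head_size, n_embd)

# Same split as in copy_weights_gemma_3 above: queries get n_head blocks of rows,
# keys and values get n_query_groups blocks each (grouped-query attention).
q, k, v = qkv.split(
    (
        n_head * head_size,
        n_query_groups * head_size,
        n_query_groups * head_size,
    )
)
assert q.shape == (n_head * head_size, n_embd)
assert k.shape == v.shape == (n_query_groups * head_size, n_embd)
```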

tests/convert/test_lit_checkpoint.py
Lines changed: 75 additions & 0 deletions

@@ -11,6 +11,7 @@
 from transformers.models.falcon import FalconConfig, FalconForCausalLM
 from transformers.models.gemma import GemmaConfig, GemmaForCausalLM
 from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM
+from transformers.models.gemma3 import Gemma3ForCausalLM, Gemma3TextConfig
 from transformers.models.gpt_neox import GPTNeoXConfig, GPTNeoXForCausalLM
 from transformers.models.llama import LlamaConfig, LlamaForCausalLM
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
@@ -27,6 +28,7 @@
     convert_lit_checkpoint,
     copy_weights_falcon,
     copy_weights_gemma_2,
+    copy_weights_gemma_3,
     copy_weights_gpt_neox,
     copy_weights_llama,
     copy_weights_phi,
@@ -512,6 +514,79 @@ def test_against_original_gemma_2(model_name, device, dtype):
     torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5)


+@torch.inference_mode()
+@pytest.mark.parametrize("model_name", ("gemma-3-1b-it", "gemma-3-4b-it", "gemma-3-12b-it", "gemma-3-27b-it"))
+@pytest.mark.parametrize(
+    ("device", "dtype"),
+    [
+        (torch.device("cpu"), torch.float32),
+        pytest.param(
+            torch.device("cuda"),
+            torch.float16,
+            marks=[
+                # the reference does softmax upscaled to fp32 during attention; additionally, the final layernorm input
+                # is slightly different
+                pytest.mark.xfail(raises=AssertionError, strict=False),
+                _RunIf(min_cuda_gpus=1),
+            ],
+        ),
+    ],
+)
+def test_against_original_gemma_3(model_name, device, dtype):
+    torch.set_default_dtype(dtype)
+
+    T = 20
+    ours_config = Config.from_name(
+        model_name,
+        block_size=T,
+        sliding_window_size=T // 2,
+        n_layer=2,
+        n_head=16,
+        n_embd=32,
+        intermediate_size=86,
+    )
+    theirs_config = Gemma3TextConfig(
+        vocab_size=ours_config.padded_vocab_size,
+        hidden_size=ours_config.n_embd,
+        head_dim=ours_config.head_size,
+        num_attention_heads=ours_config.n_head,
+        num_hidden_layers=ours_config.n_layer,
+        intermediate_size=ours_config.intermediate_size,
+        max_position_embeddings=ours_config.block_size,
+        sliding_window=ours_config.sliding_window_size,
+        rms_norm_eps=ours_config.norm_eps,
+        num_key_value_heads=ours_config.n_query_groups,
+        rope_theta=ours_config.rope_base,
+        attention_bias=ours_config.bias,
+        tie_word_embeddings=True,
+        hidden_act="gelu_pytorch_tanh",
+        attn_logit_softcapping=ours_config.attention_logit_softcapping,
+        final_logit_softcapping=ours_config.final_logit_softcapping,
+        initializer_range=1.0,  # to make the effect of attention_logit_softcapping more prominent
+        attn_implementation="eager",
+        query_pre_attn_scalar=ours_config.attention_scores_scalar,
+    )
+
+    assert ours_config.intermediate_size == theirs_config.intermediate_size
+
+    ours_model = GPT(ours_config).to(device)
+    # tie weights
+    ours_model.lm_head.weight = ours_model.transformer.wte.weight
+    ours_state_dict = ours_model.state_dict()
+    theirs_state_dict = {}
+    copy_weights_gemma_3(ours_config, theirs_state_dict, ours_state_dict, untie_weights=True)
+    theirs_model = Gemma3ForCausalLM(theirs_config).to(device)
+    keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)
+    assert not keys.unexpected_keys
+
+    # test end to end
+    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
+    assert x.size(1) == T
+    ours_y = ours_model(x)
+    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
+    torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5)
+
+
 def test_check_conversion_supported_adapter():
     lit_weights = {"some.key.name": ANY, "error.key.gating_factor": ANY}
     with pytest.raises(NotImplementedError, match="Converting adapter"):
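Outside of pytest, the same litgpt-to-HF export that this test exercises can be driven directly. A hedged sketch, not part of this commit, assuming `GPT` and `Config` are importable from the top-level `litgpt` package and reusing the shrunken "gemma-3-1b-it" hyperparameters from the test above:

```python
import torch

from litgpt import GPT, Config
from litgpt.scripts.convert_lit_checkpoint import copy_weights_gemma_3

# Tiny variant of gemma-3-1b-it, mirroring the overrides used in the test above.
config = Config.from_name(
    "gemma-3-1b-it", block_size=20, sliding_window_size=10, n_layer=2, n_head=16, n_embd=32, intermediate_size=86
)
lit_model = GPT(config)
lit_model.lm_head.weight = lit_model.transformer.wte.weight  # Gemma ties the embedding and output head

hf_state_dict = {}
# untie_weights=True skips lm_head.weight, matching HF Gemma checkpoints that ship without it.
copy_weights_gemma_3(config, hf_state_dict, lit_model.state_dict(), untie_weights=True)
print(sorted(hf_state_dict)[:5])
```

Loading the resulting dict into `Gemma3ForCausalLM(theirs_config)` with `strict=False` and checking `unexpected_keys` then proceeds exactly as in the test.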

tests/test_adapter.py
Lines changed: 73 additions & 1 deletion

@@ -16,14 +16,15 @@
 from torch._dynamo.backends import debugging
 from transformers.models.gemma import GemmaConfig, GemmaForCausalLM
 from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM
+from transformers.models.gemma3 import Gemma3ForCausalLM, Gemma3TextConfig

 import litgpt.adapter as gpt_adapter
 import litgpt.finetune.adapter as module
 import litgpt.model as gpt
 from litgpt.adapter import GPT, CausalSelfAttention, Config, adapter_filter
 from litgpt.args import EvalArgs, TrainArgs
 from litgpt.data import Alpaca
-from litgpt.scripts.convert_hf_checkpoint import copy_weights_gemma_2, copy_weights_hf_llama
+from litgpt.scripts.convert_hf_checkpoint import copy_weights_gemma_2, copy_weights_gemma_3, copy_weights_hf_llama
 from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved
 from litgpt.utils import _RunIf

@@ -361,6 +362,77 @@ def test_against_original_gemma_2(model_name, device, dtype):
     torch.testing.assert_close(ours_y, theirs_y)


+@torch.inference_mode()
+@pytest.mark.parametrize("model_name", ("gemma-3-1b-it", "gemma-3-4b-it", "gemma-3-12b-it", "gemma-3-27b-it"))
+@pytest.mark.parametrize(
+    ("device", "dtype"),
+    [
+        (torch.device("cpu"), torch.float32),
+        pytest.param(
+            torch.device("cuda"),
+            torch.float16,
+            marks=[
+                # the reference does softmax upscaled to fp32 during attention; additionally, the final layernorm input
+                # is slightly different
+                pytest.mark.xfail(raises=AssertionError, strict=False),
+                _RunIf(min_cuda_gpus=1),
+            ],
+        ),
+    ],
+)
+def test_against_original_gemma_3(model_name, device, dtype):
+    torch.set_default_dtype(dtype)
+
+    T = 20
+    ours_config = Config.from_name(
+        model_name,
+        block_size=T,
+        sliding_window_size=T // 2,
+        n_layer=2,
+        n_head=16,
+        n_embd=32,
+        intermediate_size=86,
+    )
+    theirs_config = Gemma3TextConfig(
+        vocab_size=ours_config.padded_vocab_size,
+        hidden_size=ours_config.n_embd,
+        head_dim=ours_config.head_size,
+        num_attention_heads=ours_config.n_head,
+        num_hidden_layers=ours_config.n_layer,
+        intermediate_size=ours_config.intermediate_size,
+        max_position_embeddings=ours_config.block_size,
+        sliding_window=ours_config.sliding_window_size,
+        rms_norm_eps=ours_config.norm_eps,
+        num_key_value_heads=ours_config.n_query_groups,
+        rope_theta=ours_config.rope_base,
+        attention_bias=ours_config.bias,
+        tie_word_embeddings=True,
+        hidden_act="gelu_pytorch_tanh",
+        attn_logit_softcapping=ours_config.attention_logit_softcapping,
+        final_logit_softcapping=ours_config.final_logit_softcapping,
+        initializer_range=1.0,  # to make the effect of attention_logit_softcapping more prominent
+        attn_implementation="eager",
+        query_pre_attn_scalar=ours_config.attention_scores_scalar,
+    )
+    assert ours_config.intermediate_size == theirs_config.intermediate_size
+
+    theirs_model = Gemma3ForCausalLM(theirs_config).to(device)
+    theirs_state_dict = theirs_model.state_dict()
+    # Gemma weights are shipped without `lm_head.weight`
+    theirs_state_dict.pop("lm_head.weight")
+    state_dict = {}
+    copy_weights_gemma_3({}, state_dict, theirs_state_dict)
+    ours_model = GPT(ours_config).to(device)
+    ours_model.load_state_dict(state_dict)
+
+    # test end to end
+    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
+    assert x.size(1) == T
+    ours_y = ours_model(x)
+    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
+    torch.testing.assert_close(ours_y, theirs_y)
+
+
 def test_load_legacy_state_dict():
     """Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers."""
     config = Config(
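Note that this test imports `copy_weights_gemma_3` from `convert_hf_checkpoint` (HF to litgpt), the opposite direction of the function added in `convert_lit_checkpoint.py` above; both rely on the same template-based name mapping. For readers unfamiliar with that pattern, here is a toy illustration, not litgpt's `layer_template`, that re-implements the idea with `re`: numeric path components are replaced with `{}` to look up a template, and the captured indices are formatted back into the target name.

```python
import re


def to_template(name: str) -> tuple[str, list[int]]:
    """Replace numeric path components with '{}' and return the captured layer indices."""
    ids = [int(m) for m in re.findall(r"\.(\d+)\.", name)]
    return re.sub(r"\.\d+\.", ".{}.", name), ids


template, ids = to_template("transformer.h.7.attn.proj.weight")
assert template == "transformer.h.{}.attn.proj.weight"
# weight_map[template] in convert_lit_checkpoint.py is "model.layers.{}.self_attn.o_proj.weight":
assert "model.layers.{}.self_attn.o_proj.weight".format(*ids) == "model.layers.7.self_attn.o_proj.weight"
```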

tests/test_adapter_v2.py
Lines changed: 63 additions & 1 deletion

@@ -15,6 +15,7 @@
 from torch._dynamo.backends import debugging
 from transformers.models.gemma import GemmaConfig, GemmaForCausalLM
 from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM
+from transformers.models.gemma3 import Gemma3ForCausalLM, Gemma3TextConfig
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM

 import litgpt.config as config_module
@@ -24,7 +25,7 @@
 from litgpt.args import EvalArgs, TrainArgs
 from litgpt.data import Alpaca
 from litgpt.model import GPT as BaseGPT
-from litgpt.scripts.convert_hf_checkpoint import copy_weights_gemma_2, copy_weights_hf_llama
+from litgpt.scripts.convert_hf_checkpoint import copy_weights_gemma_2, copy_weights_gemma_3, copy_weights_hf_llama
 from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved
 from litgpt.utils import _RunIf

@@ -316,6 +317,67 @@ def test_against_original_gemma_2(model_name):
     )  # some macOS devices have numerical differences, hence the tol bump


+@torch.inference_mode()
+@pytest.mark.parametrize("model_name", ("gemma-3-1b-it", "gemma-3-4b-it", "gemma-3-12b-it", "gemma-3-27b-it"))
+def test_against_original_gemma_3(model_name):
+    device = torch.device("cpu")
+    dtype = torch.float32
+
+    T = 20
+    ours_config = Config.from_name(
+        model_name,
+        block_size=T,
+        sliding_window_size=T // 2,
+        n_layer=2,
+        n_head=16,
+        n_embd=32,
+        intermediate_size=86,
+    )
+
+    theirs_config = Gemma3TextConfig(
+        vocab_size=ours_config.padded_vocab_size,
+        hidden_size=ours_config.n_embd,
+        head_dim=ours_config.head_size,
+        num_attention_heads=ours_config.n_head,
+        num_hidden_layers=ours_config.n_layer,
+        intermediate_size=ours_config.intermediate_size,
+        max_position_embeddings=ours_config.block_size,
+        sliding_window=ours_config.sliding_window_size,
+        rms_norm_eps=ours_config.norm_eps,
+        num_key_value_heads=ours_config.n_query_groups,
+        rope_theta=ours_config.rope_base,
+        attention_bias=ours_config.bias,
+        tie_word_embeddings=True,
+        hidden_act="gelu_pytorch_tanh",
+        attn_implementation="eager",
+        query_pre_attn_scalar=ours_config.attention_scores_scalar,
+        rope_scaling={"factor": 8.0, "rope_type": "linear"},
+        rope_local_base_freq=ours_config.rope_local_base_freq,
+    )
+
+    theirs_model = Gemma3ForCausalLM(theirs_config).to(device)
+    theirs_state_dict = theirs_model.state_dict()
+    # Gemma weights are shipped without `lm_head.weight`
+    theirs_state_dict.pop("lm_head.weight")
+    state_dict = {}
+
+    copy_weights_gemma_3({}, state_dict, theirs_state_dict)
+    ours_model = AdapterV2GPT(ours_config).to(device)
+    keys = ours_model.load_state_dict(state_dict, strict=False)
+    assert not keys.unexpected_keys
+    for k in keys.missing_keys:
+        assert adapter_filter(k, None)
+
+    # test end to end
+    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
+    assert x.size(1) == T
+    ours_y = ours_model(x)
+    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
+    torch.testing.assert_close(
+        ours_y, theirs_y, rtol=3e-5, atol=3e-5
+    )  # some macOS devices have numerical differences, hence the tol bump
+
+
 @_RunIf(min_cuda_gpus=1)
 def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_path):
     if not _BITSANDBYTES_AVAILABLE:
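The adapter tests above all follow the same `load_state_dict(..., strict=False)` round trip, where any missing key must be an adapter-only parameter. A generic sketch of that pattern, not part of this commit; `load_base_weights` and `is_adapter_param` are placeholder names for illustration, not litgpt APIs:

```python
from typing import Callable, Dict

import torch
import torch.nn as nn


def load_base_weights(
    model: nn.Module,
    state_dict: Dict[str, torch.Tensor],
    is_adapter_param: Callable[[str], bool],
) -> None:
    report = model.load_state_dict(state_dict, strict=False)
    # Every converted base weight must have been consumed by the model ...
    assert not report.unexpected_keys
    # ... and anything still missing must be an adapter-only parameter,
    # which a converted base checkpoint legitimately does not contain.
    for key in report.missing_keys:
        assert is_adapter_param(key), f"unexpectedly missing: {key}"
```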
