@@ -23,7 +23,7 @@
 from transformers.models.falcon import FalconConfig, FalconForCausalLM
 from transformers.models.gemma import GemmaConfig, GemmaForCausalLM
 from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM
-from transformers.models.gemma3 import Gemma3ForCausalLM, Gemma3TextConfig
+from transformers.models.gemma3 import Gemma3Config, Gemma3ForCausalLM, Gemma3ForConditionalGeneration, Gemma3TextConfig
 from transformers.models.gpt_neox import GPTNeoXConfig, GPTNeoXForCausalLM
 from transformers.models.llama import LlamaConfig, LlamaForCausalLM
 from transformers.models.mistral import MistralConfig, MistralForCausalLM
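The widened gemma3 import pulls in two more names: `Gemma3Config`, the composite (text + vision) configuration, and `Gemma3ForConditionalGeneration`, the multimodal wrapper around the language model. A minimal sketch of how they relate, assuming a transformers version that ships Gemma 3 and that the vision config falls back to its defaults when omitted (the tiny hyperparameters are arbitrary, illustration only):

from transformers.models.gemma3 import Gemma3Config, Gemma3TextConfig

# the first positional argument of the composite config is the text (language model) config;
# no vision config is passed here, so the default one is used
config = Gemma3Config(Gemma3TextConfig(num_hidden_layers=2, hidden_size=32, num_attention_heads=4))
assert isinstance(config.text_config, Gemma3TextConfig)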
@@ -872,6 +872,78 @@ def test_against_original_gemma_3(model_name, device, dtype):
     torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5)


+@torch.inference_mode()
+@pytest.mark.parametrize("model_name", ["gemma-3-4b-it", "gemma-3-12b-it", "gemma-3-27b-it"])
+@pytest.mark.parametrize(
+    ("device", "dtype"),
+    [
+        (torch.device("cpu"), torch.float32),
+        pytest.param(
+            torch.device("cuda"),
+            torch.float16,
+            marks=[
+                # the HF reference upcasts the attention softmax to fp32; additionally, the input to the
+                # final layer norm differs slightly, so fp16 outputs can drift past the tolerance below
+                pytest.mark.xfail(raises=AssertionError, strict=False),
+                _RunIf(min_cuda_gpus=1),
+            ],
+        ),
+    ],
+)
+def test_against_multimodal_gemma_3(model_name, device, dtype):
+    torch.set_default_dtype(dtype)
+
+    T = 20
+    ours_config = Config.from_name(
+        model_name,
+        block_size=T,
+        sliding_window_size=T // 2,
+        n_layer=2,
+        n_head=16,
+        n_embd=32,
+        intermediate_size=86,
+    )
+
+    theirs_config = Gemma3Config(
+        Gemma3TextConfig(
+            vocab_size=ours_config.padded_vocab_size,
+            hidden_size=ours_config.n_embd,
+            head_dim=ours_config.head_size,
+            num_attention_heads=ours_config.n_head,
+            num_hidden_layers=ours_config.n_layer,
+            intermediate_size=ours_config.intermediate_size,
+            max_position_embeddings=ours_config.block_size,
+            sliding_window=ours_config.sliding_window_size,
+            rms_norm_eps=ours_config.norm_eps,
+            num_key_value_heads=ours_config.n_query_groups,
+            rope_theta=ours_config.rope_base,
+            attention_bias=ours_config.bias,
+            tie_word_embeddings=True,
+            hidden_act="gelu_pytorch_tanh",
+            attn_implementation="eager",
+            query_pre_attn_scalar=ours_config.attention_scores_scalar,
+            rope_scaling={"factor": 8.0, "rope_type": "linear"},
+            rope_local_base_freq=ours_config.rope_local_base_freq,
+        )
+    )
+
+    theirs_model = Gemma3ForConditionalGeneration(theirs_config).to(device)
+    theirs_state_dict = theirs_model.state_dict()
+
+    state_dict = {}
+
+    copy_weights_gemma_3({}, state_dict, theirs_state_dict, config=ours_config)
+    ours_model = GPT(ours_config).to(device)
+    ours_model.load_state_dict(state_dict)
+
+    # test end to end
+    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
+    assert x.size(1) == T
+    ours_y = ours_model(x)
+    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
+    torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5)
+
+
 @torch.inference_mode()
 @pytest.mark.parametrize(
     "model_name", ["Qwen2.5-1.5B", "Qwen2.5-Coder-1.5B", "Qwen2.5-Math-1.5B", "QwQ-32B-Preview", "QwQ-32B"]