 from tensorrt_llm.bindings.executor import KvCacheConfig
 from tensorrt_llm.mapping import Mapping
 
-# This is copied from https://huggingface.co/google/gemma-3-1b-it/blob/main/config.json.
-# Updated to have 1 local layer and 1 global layer. Sliding window size updated to 4.
-GEMMA3_1B_MINI_CONFIG = {
+GEMMA3_1B_CONFIG = {
     "architectures": ["Gemma3ForCausalLM"],
     "attention_bias": False,
     "attention_dropout": 0.0,
     "max_position_embeddings": 32768,
     "model_type": "gemma3_text",
     "num_attention_heads": 4,
-    "num_hidden_layers": 2,  # Modified for testing.
+    "num_hidden_layers": 26,
     "num_key_value_heads": 1,
     "pad_token_id": 0,
     "query_pre_attn_scalar": 256,
     "rms_norm_eps": 1e-06,
     "rope_local_base_freq": 10000,
     "rope_scaling": None,
     "rope_theta": 1000000,
-    "sliding_window": 4,  # Modified for testing.
-    "sliding_window_pattern": 2,  # Modified for testing.
+    "sliding_window": 512,
+    "sliding_window_pattern": 6,
     "torch_dtype": "bfloat16",
     "transformers_version": "4.50.0.dev0",
     "use_cache": True,
     "vocab_size": 262144
 }
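+# The values above appear to match the published
+# https://huggingface.co/google/gemma-3-1b-it/blob/main/config.json; tests that
+# need a trimmed model now shrink it in place (see test_gemma3_allclose_to_hf).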
 
+GEMMA3_27B_CONFIG = {
+    "architectures": ["Gemma3ForConditionalGeneration"],
+    "boi_token_index": 255999,
+    "eoi_token_index": 256000,
+    "eos_token_id": [1, 106],
+    "image_token_index": 262144,
+    "initializer_range": 0.02,
+    "mm_tokens_per_image": 256,
+    "model_type": "gemma3",
+    "text_config": {
+        "head_dim": 128,
+        "hidden_size": 5376,
+        "intermediate_size": 21504,
+        "model_type": "gemma3_text",
+        "num_attention_heads": 32,
+        "num_hidden_layers": 62,
+        "num_key_value_heads": 16,
+        "query_pre_attn_scalar": 168,
+        "rope_scaling": {
+            "factor": 8.0,
+            "rope_type": "linear"
+        },
+        "sliding_window": 1024
+    },
+    "torch_dtype": "bfloat16",
+    "transformers_version": "4.50.0.dev0",
+    "vision_config": {
+        "hidden_size": 1152,
+        "image_size": 896,
+        "intermediate_size": 4304,
+        "model_type": "siglip_vision_model",
+        "num_attention_heads": 16,
+        "num_hidden_layers": 27,
+        "patch_size": 14,
+        "vision_use_head": False
+    }
+}
+
 
 @dataclass(repr=False)
 class Scenario:
     backend: str
+    config_name: str
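+    # config_name selects which model config drives the test ("1B" or "27B")
+    # and is folded into __repr__ so parameterized test names stay unique.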
 
     def __repr__(self) -> str:
-        return f"backend:{self.backend.lower()}"
+        return f"backend:{self.backend.lower()}_config:{self.config_name.lower()}"
 
 
 class TestGemma3(unittest.TestCase):
@@ -95,7 +132,8 @@ def get_kv_cache_manager(self, dtype: torch.dtype, config: Gemma3Config,
 
     def test_gemma3_sanity(self):
 
-        config_dict = deepcopy(GEMMA3_1B_MINI_CONFIG)
+        config_dict = deepcopy(
+            GEMMA3_1B_CONFIG)  # Using 1B config for sanity test.
         gemma3_config = Gemma3Config.from_dict(config_dict)
 
         dtype = gemma3_config.torch_dtype
@@ -174,8 +212,12 @@ def test_gemma3_sanity(self):
         kv_cache_manager.shutdown()
 
     @parameterized.expand([
-        Scenario(backend="TRTLLM"),
-        Scenario(backend="VANILLA"),
+        Scenario(backend="TRTLLM", config_name="1B"),
+        Scenario(backend="VANILLA", config_name="1B"),
+        Scenario(backend="FLASHINFER", config_name="1B"),
+        Scenario(backend="TRTLLM", config_name="27B"),
+        Scenario(backend="VANILLA", config_name="27B"),
+        Scenario(backend="FLASHINFER", config_name="27B"),
     ], lambda testcase_func, param_num, param:
         f"{testcase_func.__name__}[{param.args[0]}]")
     @torch.no_grad()
@@ -184,14 +226,31 @@ def test_gemma3_allclose_to_hf(self, scenario: Scenario) -> None:
         Compare output to HF.
         """
         backend = scenario.backend
+        config_name = scenario.config_name
         metadata_cls = get_attention_backend(backend).Metadata
 
         torch.random.manual_seed(0)
-        config_dict = deepcopy(GEMMA3_1B_MINI_CONFIG)
+
+        # Select the appropriate config based on the scenario.
+        if config_name == "1B":
+            config_dict = deepcopy(GEMMA3_1B_CONFIG)
+        elif config_name == "27B":
+            config_dict = deepcopy(GEMMA3_27B_CONFIG)
+        else:
+            raise ValueError(f"Unknown config_name: {config_name}")
+
         gemma3_config = Gemma3Config.from_dict(config_dict)
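+        # The 27B checkpoint is multimodal (Gemma3ForConditionalGeneration);
+        # only its text sub-config is exercised here. torch_dtype lives on the
+        # top-level config, so it is copied down before swapping configs.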
+        if config_name == "27B":
+            gemma3_config.text_config.torch_dtype = gemma3_config.torch_dtype
+            gemma3_config = gemma3_config.text_config
         dtype = gemma3_config.torch_dtype
         device = torch.device('cuda')
 
+        # 2-layer network with one local (sliding window=4) and one global layer.
+        gemma3_config.num_hidden_layers = 2
+        gemma3_config.sliding_window = 4
+        gemma3_config.sliding_window_pattern = 2
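+        # (In HF Gemma3, layer i uses sliding-window attention unless
+        # (i + 1) % sliding_window_pattern == 0, so with pattern=2 layer 0
+        # is local and layer 1 is global.)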
+
         num_blocks = 1
         tokens_per_block = 128
         max_seq_len = num_blocks * tokens_per_block
@@ -253,6 +312,7 @@ def test_gemma3_allclose_to_hf(self, scenario: Scenario) -> None:
                                  position_ids=position_ids,
                                  past_key_values=hf_cache,
                                  use_cache=True)
+
         torch.testing.assert_close(logits,
                                    ref.logits[:, -1].float(),
                                    atol=0.05,