@@ -93,21 +93,27 @@ def generate(
9393class UnifiedGemma3Wrapper (verifier .ReauthoredModelWrapper ):
9494 """Unified Gemma3 model wrapper for verification."""
9595
  def __init__(
      self,
      model: torch.nn.Module,
      kv_cache_max_len: int = verifier.DEFAULT_KV_CACHE_MAX_LEN,
  ):
    """Initializes the wrapper around a reauthored Gemma3 model.

    Args:
      model: The reauthored Gemma3 model to wrap for verification.
      kv_cache_max_len: Maximum length of the KV cache; defaults to
        verifier.DEFAULT_KV_CACHE_MAX_LEN so existing callers keep the
        previous behavior.
    """
    # Gemma3 uses the transposed KV-cache layout; the max length is now
    # forwarded to the base wrapper instead of being read from the model
    # config (see the matching changes in _init_kv_cache and forward).
    super().__init__(
        model,
        kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED,
        kv_cache_max_len=kv_cache_max_len,
    )
98106
  def _init_kv_cache(self):
    """Builds a fresh KV cache sized by the wrapper's kv_cache_max_len.

    Returns:
      A kv_utils.KVCache built from the wrapped model's config using this
      wrapper's KV layout.
    """
    # NOTE(review): kv_cache_max_len is passed positionally ahead of the
    # model config — this assumes KVCache.from_model_config's first
    # positional parameter is the cache length; confirm against kv_utils.
    return kv_utils.KVCache.from_model_config(
        self.kv_cache_max_len, self.model.model.config, kv_layout=self.kv_layout
    )
103111
104112 def forward (
105113 self , tokens : torch .Tensor , pixel_values : torch .Tensor = None
106114 ) -> torch .Tensor :
107115 """Forwards the model."""
108- mask = attn_utils .build_causal_mask_cache (
109- self .model .model .config .kv_cache_max_len
110- )
116+ mask = attn_utils .build_causal_mask_cache (self .kv_cache_max_len )
111117 input_pos = torch .arange (0 , tokens .shape [1 ], dtype = torch .int )
112118 mask = mask .index_select (2 , input_pos )
113119 output = self .model .model .forward (
@@ -127,9 +133,7 @@ def generate(
127133 tokens = torch .tensor ([input_ids ])
128134 input_pos = torch .arange (0 , tokens .shape [1 ], dtype = torch .int )
129135 kv_cache = self ._init_kv_cache ()
130- mask_cache = attn_utils .build_causal_mask_cache (
131- self .model .model .config .kv_cache_max_len
132- )
136+ mask_cache = attn_utils .build_causal_mask_cache (self .kv_cache_max_len )
133137 for _ in range (max_new_tokens ):
134138 mask = mask_cache .index_select (2 , input_pos )
135139 output = self .model .model .forward (
@@ -245,7 +249,11 @@ def verify_gemma3(
245249
246250 if variant == "1b" :
247251 reauthored_model = UnifiedGemma3Wrapper (
248- gemma3 .build_model_1b (gemma3_model_path , custom_loader )
252+ gemma3 .build_model_1b (
253+ gemma3_model_path ,
254+ custom_loader ,
255+ mask_cache_size = verifier .DEFAULT_KV_CACHE_MAX_LEN ,
256+ )
249257 )
250258 else :
251259 raise ValueError (f"Unsupported Gemma3 variant: { variant } " )
0 commit comments