File tree Expand file tree Collapse file tree 1 file changed +2
-3
lines changed
ai_edge_torch/generative/examples/gemma3 Expand file tree Collapse file tree 1 file changed +2
-3
lines changed Original file line number Diff line number Diff line change @@ -261,7 +261,6 @@ def forward(
261261 pixel_mask = self .build_pixel_mask (image_indices )
262262 # RoPE parameters are the same for all blocks. Use the first layer.
263263 attn_config = self .config .block_config (0 ).attn_config
264- n_elem = int (attn_config .rotary_percentage * attn_config .head_dim )
265264 # Different rotary base for global and local attention
266265 # based on attention pattern
267266 rope = [
@@ -305,7 +304,7 @@ def _forward_with_embeds(
305304 if pixel_mask is None :
306305 mask = [
307306 self .get_local_global_attention_mask (
308- mask ,
307+ mask [ i ] if isinstance ( mask , list ) else mask ,
309308 self .config .block_config (i ).attn_config .attn_type ,
310309 input_pos ,
311310 self .config .block_config (i ).attn_config .sliding_window_size ,
@@ -316,7 +315,7 @@ def _forward_with_embeds(
316315 pixel_mask = pixel_mask .index_select (2 , input_pos )
317316 mask = [
318317 self .compose_mask (
319- mask [i ],
318+ mask [i ] if isinstance ( mask , list ) else mask ,
320319 pixel_mask ,
321320 self .config .block_config (i ).attn_config .attn_type ,
322321 )
You can’t perform that action at this time.
0 commit comments