Skip to content

Commit 0cdcda0

Browse files
protobird-git authored and copybara-github committed
Convert gemma3 with mask_as_input=false
PiperOrigin-RevId: 756518297
1 parent 1c4100d commit 0cdcda0

File tree

1 file changed

+2
-3
lines changed
  • ai_edge_torch/generative/examples/gemma3

1 file changed

+2
-3
lines changed

ai_edge_torch/generative/examples/gemma3/decoder.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,6 @@ def forward(
261261
pixel_mask = self.build_pixel_mask(image_indices)
262262
# RoPE parameters are the same for all blocks. Use the first layer.
263263
attn_config = self.config.block_config(0).attn_config
264-
n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
265264
# Different rotary base for global and local attention
266265
# based on attention pattern
267266
rope = [
@@ -305,7 +304,7 @@ def _forward_with_embeds(
305304
if pixel_mask is None:
306305
mask = [
307306
self.get_local_global_attention_mask(
308-
mask,
307+
mask[i] if isinstance(mask, list) else mask,
309308
self.config.block_config(i).attn_config.attn_type,
310309
input_pos,
311310
self.config.block_config(i).attn_config.sliding_window_size,
@@ -316,7 +315,7 @@ def _forward_with_embeds(
316315
pixel_mask = pixel_mask.index_select(2, input_pos)
317316
mask = [
318317
self.compose_mask(
319-
mask[i],
318+
mask[i] if isinstance(mask, list) else mask,
320319
pixel_mask,
321320
self.config.block_config(i).attn_config.attn_type,
322321
)

0 commit comments

Comments (0)