
Commit c741a95

ai-edge-bot authored and copybara-github committed
Handle mask properly in gemma2
PiperOrigin-RevId: 716258803
1 parent 44d4487 commit c741a95
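
Summary: Gemma2's forward() now honors a caller-supplied attention mask. Per-layer masks are built from each block's attn_type only when mask is None, _forward_with_embeds accepts either a single mask tensor or a per-layer list, and the PaliGemma decoder2 drops its now-redundant per-layer mask broadcast.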

2 files changed: +10 −17 lines

ai_edge_torch/generative/examples/gemma/gemma2.py

Lines changed: 10 additions & 16 deletions
@@ -144,12 +144,13 @@ def forward(
     attn_config = self.config.block_config(0).attn_config
     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
     rope = rotary_pos_emb.build_rope(input_pos, n_elem, attn_config.rotary_base)
-    mask = [
-        self.get_attention_mask(
-            self.config.block_config(i).attn_config.attn_type, input_pos
-        )
-        for i in range(self.config.num_layers)
-    ]
+    if mask is None:
+      mask = [
+          self.get_attention_mask(
+              self.config.block_config(i).attn_config.attn_type, input_pos
+          )
+          for i in range(self.config.num_layers)
+      ]
 
     return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
@@ -159,7 +160,7 @@ def _forward_with_embeds(
     self,
     input_embeds: torch.Tensor,
     rope: Tuple[torch.Tensor, torch.Tensor],
-    mask: List[torch.Tensor],
+    mask: torch.Tensor | List[torch.Tensor],
     input_pos: torch.Tensor,
     kv_cache: kv_utils.KVCache,
     export_config: Optional[model_builder.ExportConfig] = None,
@@ -174,17 +175,10 @@ def _forward_with_embeds(
     input_embeds = input_embeds * self.config.embedding_scale
     x = input_embeds
     updated_kv_entries = []
-    mask_input = mask is not None
     for i, block in enumerate(self.transformer_blocks):
-      mask = (
-          mask
-          if mask_input
-          else self.get_attention_mask(
-              block.config.attn_config.attn_type, input_pos
-          )
-      )
+      mask_entry = mask[i] if isinstance(mask, list) else mask
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
+      x, kv_entry = block(x, rope, mask_entry, input_pos, kv_entry)
       if kv_entry:
         updated_kv_entries.append(kv_entry)
     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

ai_edge_torch/generative/examples/paligemma/decoder2.py

Lines changed: 0 additions & 1 deletion
@@ -86,7 +86,6 @@ def forward(
     embeds_len = input_embeds.shape[1]
     mask = torch.zeros(embeds_len, self.config.kv_cache_max)
     mask[:, embeds_len:] = float("-inf")
-    mask = [mask] * self.config.num_layers
 
     return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
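
This deletion works because of the gemma2.py change above: _forward_with_embeds now broadcasts a non-list mask across blocks itself. A sketch of the mask decoder2 builds, with the values taken from the diff; embeds_len and kv_cache_max here are example numbers, not the real config.

    import torch

    embeds_len, kv_cache_max = 3, 16      # example sizes
    mask = torch.zeros(embeds_len, kv_cache_max)
    mask[:, embeds_len:] = float("-inf")  # block KV positions past the prefix
    # Previously decoder2 expanded this to [mask] * num_layers; passing the
    # bare tensor now suffices, since the Gemma2 loop reuses it per block.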
