Skip to content

Commit 2189970

Browse files
talumbau authored and copybara-github committed
Optional Mask input for LM examples
PiperOrigin-RevId: 713729372
1 parent 4bf9d76 commit 2189970

File tree

7 files changed

+56
-22
lines changed

7 files changed

+56
-22
lines changed

ai_edge_torch/generative/examples/gemma/gemma2.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def forward(
129129
tokens: torch.Tensor,
130130
input_pos: torch.Tensor,
131131
kv_cache: kv_utils.KVCache,
132+
mask: Optional[torch.Tensor] = None,
132133
export_config: Optional[model_builder.ExportConfig] = None,
133134
) -> dict[torch.Tensor, kv_utils.KVCache]:
134135
_, seq_len = tokens.size()
@@ -175,7 +176,15 @@ def _forward_with_embeds(
175176
input_embeds = input_embeds * self.config.embedding_scale
176177
x = input_embeds
177178
updated_kv_entries = []
179+
mask_input = mask is not None
178180
for i, block in enumerate(self.transformer_blocks):
181+
mask = (
182+
mask
183+
if mask_input
184+
else self.get_attention_mask(
185+
block.config.attn_config.attn_type, input_pos
186+
)
187+
)
179188
kv_entry = kv_cache.caches[i] if kv_cache else None
180189
x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
181190
if kv_entry:

ai_edge_torch/generative/examples/paligemma/decoder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def forward(
5454
input_pos: torch.Tensor,
5555
kv_cache: kv_utils.KVCache,
5656
input_embeds: torch.Tensor = None,
57+
mask: Optional[torch.Tensor] = None,
5758
export_config: Optional[model_builder.ExportConfig] = None,
5859
called_by_generate: bool = True,
5960
) -> dict[torch.Tensor, kv_utils.KVCache]:
@@ -73,8 +74,9 @@ def forward(
7374
# The first part of input_embeds are image embeddings. Diagonal causal mask
7475
# doesn't work here.
7576
embeds_len = input_embeds.shape[1]
76-
mask = torch.zeros(embeds_len, self.config.kv_cache_max)
77-
mask[:, embeds_len:] = float("-inf")
77+
if mask is None:
78+
mask = torch.zeros(embeds_len, self.config.kv_cache_max)
79+
mask[:, embeds_len:] = float("-inf")
7880

7981
return self._forward_with_embeds(
8082
input_embeds, rope, mask, input_pos, kv_cache

ai_edge_torch/generative/examples/paligemma/decoder2.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def forward(
5757
input_pos: torch.Tensor,
5858
kv_cache: kv_utils.KVCache,
5959
input_embeds: torch.Tensor = None,
60+
mask: Optional[torch.Tensor] = None,
6061
export_config: Optional[model_builder.ExportConfig] = None,
6162
called_by_generate: bool = True,
6263
) -> dict[torch.Tensor, kv_utils.KVCache]:
@@ -73,17 +74,21 @@ def forward(
7374
repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
7475
)
7576

76-
if called_by_generate:
77-
# PaliGemma2 generate() use a diagonal causal mask even with image embeds.
78-
mask = [self.get_attention_mask(
79-
self.config.block_config(i).attn_config.attn_type, input_pos
80-
) for i in range(self.config.num_layers)]
81-
else:
82-
# By default, don't mask image embeds with a diagonal causal mask.
83-
embeds_len = input_embeds.shape[1]
84-
mask = torch.zeros(embeds_len, self.config.kv_cache_max)
85-
mask[:, embeds_len:] = float("-inf")
86-
mask = [mask] * self.config.num_layers
77+
if mask is None:
78+
if called_by_generate:
79+
# PaliGemma2 generate() use a diagonal causal mask even with image embeds.
80+
mask = [
81+
self.get_attention_mask(
82+
self.config.block_config(i).attn_config.attn_type, input_pos
83+
)
84+
for i in range(self.config.num_layers)
85+
]
86+
else:
87+
# By default, don't mask image embeds with a diagonal causal mask.
88+
embeds_len = input_embeds.shape[1]
89+
mask = torch.zeros(embeds_len, self.config.kv_cache_max)
90+
mask[:, embeds_len:] = float("-inf")
91+
mask = [mask] * self.config.num_layers
8792

8893
return self._forward_with_embeds(
8994
input_embeds, rope, mask, input_pos, kv_cache, export_config

ai_edge_torch/generative/examples/paligemma/paligemma.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def forward(
7070
tokens: torch.Tensor,
7171
input_pos: torch.Tensor,
7272
kv_cache: kv_utils.KVCache,
73+
mask: Optional[torch.Tensor] = None,
7374
pixel_values: torch.Tensor = None,
7475
export_config: Optional[model_builder.ExportConfig] = None,
7576
called_by_generate: bool = True,
@@ -79,6 +80,7 @@ def forward(
7980
tokens=tokens,
8081
input_pos=input_pos,
8182
kv_cache=kv_cache,
83+
mask=mask,
8284
input_embeds=None,
8385
export_config=export_config,
8486
called_by_generate=called_by_generate,
@@ -111,6 +113,7 @@ def forward(
111113
tokens=None,
112114
input_pos=input_pos,
113115
kv_cache=kv_cache,
116+
mask=mask,
114117
input_embeds=input_embeds,
115118
export_config=export_config,
116119
called_by_generate=called_by_generate,

ai_edge_torch/generative/examples/test_models/toy_model.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
# A toy example which has a single-layer transformer block.
16-
from typing import Tuple
16+
from typing import Optional, Tuple
1717

1818
from ai_edge_torch.generative.layers import builder
1919
from ai_edge_torch.generative.layers.attention import TransformerBlock
@@ -52,14 +52,20 @@ def __init__(self, config: cfg.ModelConfig) -> None:
5252
self.config = config
5353

5454
@torch.inference_mode
55-
def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
55+
def forward(
56+
self,
57+
idx: torch.Tensor,
58+
input_pos: torch.Tensor,
59+
mask: Optional[torch.Tensor] = None,
60+
) -> torch.Tensor:
5661
x = self.tok_embedding(idx)
5762
cos, sin = self.rope_cache
5863

5964
cos = cos.index_select(0, input_pos)
6065
sin = sin.index_select(0, input_pos)
61-
mask = self.mask_cache.index_select(2, input_pos)
62-
mask = mask[:, :, :, : self.config.max_seq_len]
66+
if mask is None:
67+
mask = self.mask_cache.index_select(2, input_pos)
68+
mask = mask[:, :, :, : self.config.max_seq_len]
6369

6470
x = self.transformer_block(x, (cos, sin), mask, input_pos)
6571
x = self.final_norm(x)
@@ -98,7 +104,12 @@ def __init__(self, config: cfg.ModelConfig) -> None:
98104
self.config = config
99105

100106
@torch.inference_mode
101-
def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
107+
def forward(
108+
self,
109+
idx: torch.Tensor,
110+
input_pos: torch.Tensor,
111+
mask: Optional[torch.Tensor] = None,
112+
) -> torch.Tensor:
102113
x = self.tok_embedding(idx)
103114
cos, sin = self.rope_cache
104115

ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,16 @@ def forward(
6363
tokens: torch.Tensor,
6464
input_pos: torch.Tensor,
6565
kv_cache: kv_utils.KVCache,
66+
mask: Optional[torch.Tensor] = None,
6667
export_config: Optional[ExportConfig] = None,
6768
) -> Tuple[torch.Tensor, kv_utils.KVCache]:
6869
x = self.tok_embedding(tokens)
6970
cos, sin = self.rope_cache
7071
cos = cos.index_select(0, input_pos)
7172
sin = sin.index_select(0, input_pos)
72-
mask = self.mask_cache.index_select(2, input_pos)
73-
mask = mask[:, :, :, : self.config.max_seq_len]
73+
if mask is None:
74+
mask = self.mask_cache.index_select(2, input_pos)
75+
mask = mask[:, :, :, : self.config.max_seq_len]
7476

7577
updated_kv_entries = []
7678
for i, block in enumerate(self.transformer_blocks):

ai_edge_torch/generative/utilities/model_builder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def forward(
9999
tokens: torch.Tensor,
100100
input_pos: torch.Tensor,
101101
kv_cache: kv_utils.KVCache,
102+
mask: Optional[torch.Tensor] = None,
102103
lora: Optional[lora_utils.LoRA] = None,
103104
export_config: Optional[ExportConfig] = None,
104105
) -> dict[torch.Tensor, kv_utils.KVCache]:
@@ -122,8 +123,9 @@ def forward(
122123
# input_pos=input_pos, n_elem=n_elem, base=attn_config.rotary_base
123124
)
124125

125-
mask = self.mask_cache.index_select(2, input_pos)
126-
mask = mask[:, :, :, : self.config.kv_cache_max]
126+
if mask is None:
127+
mask = self.mask_cache.index_select(2, input_pos)
128+
mask = mask[:, :, :, : self.config.kv_cache_max]
127129

128130
return self.forward_with_embeds(
129131
input_embeds, rope, mask, input_pos, kv_cache, lora, export_config

0 commit comments

Comments (0)