
Commit 7fcec7e

haozha111 authored and copybara-github committed
* In ModelConfig, allow the user to specify the negative-infinity mask value, overriding the default float("-inf"), since certain accelerators cannot handle inf values well.
* Also updated the mask computation logic in ExportConfig.

PiperOrigin-RevId: 754207101
1 parent e69c0f8 · commit 7fcec7e

6 files changed: +47 additions, -37 deletions
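As context for the change: per the commit message, certain accelerators do not handle inf values well, so a large finite negative number can stand in for float("-inf") as the "masked" logit. A minimal sketch (not part of the commit; values are illustrative) of why a finite substitute still masks correctly under softmax:

import torch

# The most negative representable float32 is one common finite substitute.
finite_mask_value = torch.finfo(torch.float32).min

scores = torch.randn(1, 4, 4)
causal = torch.tril(torch.ones(4, 4, dtype=torch.bool))

# Additive masking: disallowed positions get the finite value, and softmax
# drives their attention weights to (numerically) zero, just as -inf would.
masked = scores.masked_fill(~causal, finite_mask_value)
probs = torch.softmax(masked, dim=-1)
assert probs[0, 0, 1:].sum() == 0  # row 0 attends only to position 0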

ai_edge_torch/generative/examples/gemma3/decoder.py

Lines changed: 6 additions & 2 deletions
@@ -199,7 +199,11 @@ def create_sliding_mask(
     sliding_mask = torch.where(
         sliding_mask_bool,
         torch.zeros_like(sliding_mask_bool, dtype=torch.float),
-        torch.full_like(sliding_mask_bool, float("-inf"), dtype=torch.float),
+        torch.full_like(
+            sliding_mask_bool,
+            self.config.get_causal_mask_value(),
+            dtype=torch.float,
+        ),
     )

     return sliding_mask
@@ -215,7 +219,7 @@ def compose_mask(
       mask = torch.logical_and(mask, pixel_mask)
     else:
       mask = torch.logical_or(mask, pixel_mask)
-    mask = torch.where(mask, 0, float("-inf"))
+    mask = torch.where(mask, 0, self.config.get_causal_mask_value())
     return mask

   def build_pixel_mask(self, image_indices: torch.Tensor):
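For readers skimming the hunks above: torch.where converts a boolean "may attend" mask into an additive float mask, 0.0 where attention is allowed and the configured value elsewhere. A standalone sketch of that pattern (mine, not library code; mask_value stands in for self.config.get_causal_mask_value()):

import torch

def to_additive_mask(allowed: torch.Tensor, mask_value: float) -> torch.Tensor:
  # allowed: boolean tensor, True where a query may attend to a key.
  return torch.where(
      allowed,
      torch.zeros_like(allowed, dtype=torch.float),
      torch.full_like(allowed, mask_value, dtype=torch.float),
  )

allowed = torch.tensor([[True, False], [True, True]])
print(to_additive_mask(allowed, -1e9))  # 0.0 where True, -1e9 where False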

ai_edge_torch/generative/examples/paligemma/decoder.py

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ def forward(
     if mask is None:
       embeds_len = input_embeds.shape[1]
       mask = torch.zeros(embeds_len, self.config.kv_cache_max)
-      mask[:, embeds_len:] = float("-inf")
+      mask[:, embeds_len:] = attn_config.causal_mask_value

     return self._forward_with_embeds(
         input_embeds,

ai_edge_torch/generative/examples/paligemma/decoder2.py

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ def forward(
     # By default, don't mask image embeds with a diagonal causal mask.
     embeds_len = input_embeds.shape[1]
     mask = torch.zeros(embeds_len, self.config.kv_cache_max)
-    mask[:, embeds_len:] = float("-inf")
+    mask[:, embeds_len:] = attn_config.causal_mask_value

     return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
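Both PaliGemma decoders build the same default prefill mask: the first embeds_len KV slots are fully visible to every query row, and everything past them is masked. A small sketch with assumed sizes (embeds_len=3, kv_cache_max=6) and the default mask value:

import torch

embeds_len, kv_cache_max = 3, 6   # assumed example sizes
mask_value = float("-inf")        # or attn_config.causal_mask_value

mask = torch.zeros(embeds_len, kv_cache_max)
mask[:, embeds_len:] = mask_value
# Each of the 3 query rows sees KV slots 0..2 (0.0) and is blocked from 3..5.
print(mask)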

ai_edge_torch/generative/layers/model_config.py

Lines changed: 5 additions & 0 deletions
@@ -116,6 +116,8 @@ class AttentionConfig:
   attn_type: Optional[AttentionType] = None
   # The size of the sliding window used for local attention.
   sliding_window_size: Optional[int] = None
+  # The default causal mask value used by the attention layer.
+  causal_mask_value: float = float("-inf")


 @dataclasses.dataclass
@@ -247,3 +249,6 @@ def block_config(self, idx: int) -> TransformerBlockConfig:
         f"Index {idx} is out of range for layer configs: {self.block_configs}"
     )
     return self.block_configs[idx]
+
+  def get_causal_mask_value(self) -> float:
+    return self.block_config(0).attn_config.causal_mask_value
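A self-contained sketch of the config plumbing this file adds; the dataclass skeletons below are pared down to the new field (the real AttentionConfig and ModelConfig carry many more fields), and the -1e4 override is an illustrative assumption:

import dataclasses
from typing import List

@dataclasses.dataclass
class AttentionConfig:
  # Mirrors the new field: the additive mask value for disallowed positions.
  causal_mask_value: float = float("-inf")

@dataclasses.dataclass
class TransformerBlockConfig:
  attn_config: AttentionConfig

@dataclasses.dataclass
class ModelConfig:
  block_configs: List[TransformerBlockConfig]

  def block_config(self, idx: int) -> TransformerBlockConfig:
    return self.block_configs[idx]

  def get_causal_mask_value(self) -> float:
    # Mirrors the new accessor: read the value from the first block's config.
    return self.block_config(0).attn_config.causal_mask_value

# A model that opts into a finite mask value for accelerator friendliness.
config = ModelConfig(
    block_configs=[TransformerBlockConfig(AttentionConfig(causal_mask_value=-1e4))]
)
assert config.get_causal_mask_value() == -1e4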

ai_edge_torch/generative/utilities/converter.py

Lines changed: 32 additions & 10 deletions
@@ -95,6 +95,18 @@ def define_conversion_flags(model_name: str):
   return flags


+def _build_mask(mask_len, kv_cache_max_len, causal_mask_value) -> torch.Tensor:
+  if isinstance(mask_len, list):
+    return [
+        _build_mask(i, kv_cache_max_len, causal_mask_value) for i in mask_len
+    ]
+
+  mask = torch.full(
+      (mask_len, kv_cache_max_len), causal_mask_value, dtype=torch.float32
+  )
+  return torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+
+
 def convert_to_tflite(
     pytorch_model: torch.nn.Module,
     output_path: str,
@@ -229,14 +241,15 @@ def _export_helper(
       torch.arange(0, seq_len + pixel_seq_len, dtype=torch.int)
   )

-  if export_config.prefill_mask is None:
-    prefill_masks = None
-  elif isinstance(export_config.prefill_mask, torch.Tensor):
-    prefill_masks = [export_config.prefill_mask]
-  elif isinstance(export_config.prefill_mask, list):
-    prefill_masks = export_config.prefill_mask
-  else:
-    raise ValueError('Prefill masks unrecognized.')
+  prefill_masks = None
+  if flags.FLAGS.mask_as_input:
+    prefill_masks = [
+        _build_mask(
+            flags.FLAGS.prefill_seq_lens,
+            flags.FLAGS.kv_cache_max_len,
+            config.get_causal_mask_value(),
+        )
+    ]

   if prefill_masks:
     assert len(prefill_masks) == len(prefill_seq_lens)
@@ -299,8 +312,17 @@ def _export_helper(
       'input_pos': decode_input_pos,
       'kv_cache': decode_kv,
   }
-  if export_config.decode_mask is not None:
-    sample_kwargs['mask'] = export_config.decode_mask
+  if flags.FLAGS.mask_as_input:
+    # Note that the decode mask is not a correct causal mask, but that is okay
+    # for conversion purposes because only the shape matters in conversion.
+    # A correct decode causal mask for a given token position would be built
+    # like:
+    #
+    #   torch.triu(mask, diagonal=decode_position).unsqueeze(0).unsqueeze(0)
+    #
+    sample_kwargs['mask'] = _build_mask(
+        1, flags.FLAGS.kv_cache_max_len, config.get_causal_mask_value()
+    )
   if lora is not None:
     sample_kwargs['lora'] = lora
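A usage sketch for the _build_mask helper added above (assuming it is in scope; the sizes and the -1e4 value are assumed): for a single length it returns a (1, 1, mask_len, kv_cache_max_len) additive causal mask, and for a list of prefill lengths it returns one mask per length:

import torch

m = _build_mask(4, 8, -1e4)
print(m.shape)  # torch.Size([1, 1, 4, 8])
# Row i is 0.0 through column i and -1e4 strictly above the diagonal, so
# query position i may attend to KV slots 0..i.

ms = _build_mask([4, 16], 32, -1e4)
print([t.shape for t in ms])
# [torch.Size([1, 1, 4, 32]), torch.Size([1, 1, 16, 32])]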

ai_edge_torch/generative/utilities/export_config.py

Lines changed: 2 additions & 23 deletions
@@ -33,6 +33,8 @@ class ExportConfig:
   # When False, only decode signatures will produce output.
   output_logits_on_prefill: bool = False
   # Attention masks given as inputs to the model.
+  # Note that `prefill_mask`, `decode_mask`, and `kvcache_cls` are deprecated
+  # and will be removed in a future version.
   prefill_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
   decode_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
   # The KV Cache layout for K and V buffers in attention.
@@ -43,33 +45,10 @@
   decode_batch_size: int = 1


-def _build_mask(mask_len, kv_cache_max_len) -> torch.Tensor:
-  if isinstance(mask_len, list):
-    return [_build_mask(i, kv_cache_max_len) for i in mask_len]
-
-  mask = torch.full(
-      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
-  )
-  return torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
-
-
 def get_from_flags() -> ExportConfig:
   """Builds an export config according to the commandline flags."""
   export_config = ExportConfig()

-  if flags.FLAGS.mask_as_input:
-    export_config.prefill_mask = _build_mask(
-        flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
-    )
-    # Note that the decode mask is not a correct causal mask, but it is okay
-    # for the conversion purpose because only the shape matters in conversion.
-    # A correct causal mask of decode for a given token position of decode, it
-    # should be built like:
-    #
-    #   torch.triu(mask, diagonal=decode_position).unsqueeze(0).unsqueeze(0)
-    #
-    export_config.decode_mask = _build_mask(1, flags.FLAGS.kv_cache_max_len)
-
   if flags.FLAGS.transpose_kv_cache:
     export_config.kvcache_layout = kv_utils.KV_LAYOUT_TRANSPOSED
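The comment retained in converter.py notes that the exported decode mask is shape-only; a sketch (with assumed sizes and position) of the "correct" per-position decode mask that the comment's formula describes:

import torch

kv_cache_max_len = 8
decode_position = 5                # assumed current decode position
mask_value = float("-inf")         # or config.get_causal_mask_value()

mask = torch.full((1, kv_cache_max_len), mask_value, dtype=torch.float32)
mask = torch.triu(mask, diagonal=decode_position).unsqueeze(0).unsqueeze(0)
# KV slots before decode_position become 0.0 (visible); the rest keep
# mask_value, matching the formula in the comment above.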
