Skip to content

Commit ebb8fa6

Browse files
talumbaucopybara-github
authored andcommitted
Add mask as optional input for converter
PiperOrigin-RevId: 718980201
1 parent 2b97001 commit ebb8fa6

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

ai_edge_torch/generative/utilities/converter.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ def _export_helper(
174174
'input_pos': prefill_input_pos,
175175
'kv_cache': kv,
176176
}
177+
if export_config.prefill_mask is not None:
178+
sample_kwargs['mask'] = export_config.prefill_mask
179+
177180
if lora is not None:
178181
prefill_signature_name += f'_lora_r{lora.get_rank()}'
179182
sample_kwargs['lora'] = lora
@@ -199,6 +202,8 @@ def _export_helper(
199202
'input_pos': decode_input_pos,
200203
'kv_cache': kv,
201204
}
205+
if export_config.decode_mask is not None:
206+
sample_kwargs['mask'] = export_config.decode_mask
202207
if lora is not None:
203208
sample_kwargs['lora'] = lora
204209

ai_edge_torch/generative/utilities/model_builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ class ExportConfig:
5555
# On prefill signatures, should the model produce logit output?
5656
# When False, only decode signatures will produce output.
5757
output_logits_on_prefill: bool = False
58+
# Attention masks given as inputs to the model.
59+
prefill_mask: Optional[torch.Tensor] = None
60+
decode_mask: Optional[torch.Tensor] = None
5861

5962

6063
class DecoderOnlyModel(nn.Module):

0 commit comments

Comments
 (0)