Skip to content

Commit 45017e4

Browse files
protobird-gitcopybara-github
authored andcommitted
Set default of mask_as_input and transport_kv_cache true for gemma2 & 3
PiperOrigin-RevId: 756532578
1 parent 0cdcda0 commit 45017e4

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
from ai_edge_torch.generative.utilities import converter
2121
from ai_edge_torch.generative.utilities import export_config
2222

23-
flags = converter.define_conversion_flags("gemma2-2b")
23+
flags = converter.define_conversion_flags(
24+
"gemma2-2b", default_mask_as_input=True, default_transpose_kv_cache=True
25+
)
2426

2527

2628
def main(_):

ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
from ai_edge_torch.generative.utilities import converter
2121
from ai_edge_torch.generative.utilities import export_config
2222

23-
flags = converter.define_conversion_flags('gemma3-1b')
23+
flags = converter.define_conversion_flags(
24+
'gemma3-1b', default_mask_as_input=True, default_transpose_kv_cache=True
25+
)
2426

2527
_MODEL_SIZE = flags.DEFINE_string(
2628
'model_size',

ai_edge_torch/generative/utilities/converter.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,11 @@ def forward(self, *export_args, **export_kwargs):
4242
return self.module(*export_args, **full_kwargs)
4343

4444

45-
def define_conversion_flags(model_name: str):
45+
def define_conversion_flags(
46+
model_name: str,
47+
default_mask_as_input: bool = False,
48+
default_transpose_kv_cache: bool = False,
49+
):
4650
"""Defines common flags used for model conversion."""
4751

4852
flags.DEFINE_string(
@@ -83,13 +87,13 @@ def define_conversion_flags(model_name: str):
8387
)
8488
flags.DEFINE_bool(
8589
'mask_as_input',
86-
False,
90+
default_mask_as_input,
8791
'If true, the mask will be passed in as input. Otherwise, mask will be '
8892
'built by the model internally.',
8993
)
9094
flags.DEFINE_bool(
9195
'transpose_kv_cache',
92-
False,
96+
default_transpose_kv_cache,
9397
'If true, the model will be converted with transposed KV cache.',
9498
)
9599
return flags

0 commit comments

Comments
 (0)