
Commit e69c0f8

ai-edge-bot authored and copybara-github committed

Update Hammer model convert_to_tflite.py to use shared ExportConfig utilities.

PiperOrigin-RevId: 754054026

1 parent b200614 · commit e69c0f8

1 file changed: +2 −40 lines

ai_edge_torch/generative/examples/hammer/convert_to_tflite.py (2 additions, 40 deletions)
@@ -17,15 +17,10 @@
 
 from absl import app
 from ai_edge_torch.generative.examples.hammer import hammer
-from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import converter
-from ai_edge_torch.generative.utilities import export_config as export_cfg
-import torch
-
+from ai_edge_torch.generative.utilities import export_config
 
 flags = converter.define_conversion_flags('hammer')
-ExportConfig = export_cfg.ExportConfig
-
 
 _MODEL_SIZE = flags.DEFINE_enum(
     'model_size',
@@ -40,35 +35,6 @@
 }
 
 
-def _create_mask(mask_len, kv_cache_max_len):
-  mask = torch.full(
-      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
-  )
-  mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
-  return mask
-
-
-def _create_export_config(
-    prefill_seq_lens: list[int], kv_cache_max_len: int
-) -> ExportConfig:
-  """Creates the export config for the model."""
-  export_config = ExportConfig()
-  if isinstance(prefill_seq_lens, list):
-    prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
-  else:
-    prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)
-
-  export_config.prefill_mask = prefill_mask
-
-  decode_mask = torch.full(
-      (1, kv_cache_max_len), float('-inf'), dtype=torch.float32
-  )
-  decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
-  export_config.decode_mask = decode_mask
-  export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
-  return export_config
-
-
 def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
@@ -80,11 +46,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
      quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=_create_export_config(
-          flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
-      )
-      if flags.FLAGS.transpose_kv_cache
-      else ExportConfig(),
+      export_config=export_config.get_from_flags(),
   )
 
 