
Commit d09e97a

talumbau authored and copybara-github committed
Support multiple KV Cache layouts based on dimension of K and V tensors
PiperOrigin-RevId: 719386925
1 parent 243fae8 commit d09e97a
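
The commit title says the KV cache layout is chosen from the dimensions of the K and V tensors. A minimal sketch of that idea (hypothetical shapes and helper, not code from this commit): a cache can hold K with the sequence axis before the head axis, or "transposed" with the head axis first, and the two layouts are distinguishable from dimension order alone.

    import torch

    # Hypothetical shapes for illustration; not taken from this commit.
    batch, num_kv_heads, head_dim, kv_cache_max_len = 1, 4, 256, 1280

    # Layout A: [batch, kv_cache_max_len, num_kv_heads, head_dim].
    k_default = torch.zeros(batch, kv_cache_max_len, num_kv_heads, head_dim)

    # Layout B ("transposed"): [batch, num_kv_heads, kv_cache_max_len, head_dim].
    k_transposed = torch.zeros(batch, num_kv_heads, kv_cache_max_len, head_dim)

    def is_transposed(k: torch.Tensor, num_kv_heads: int) -> bool:
      # Infer the layout purely from the K tensor's dimension order.
      return k.shape[1] == num_kv_heads

The new file below (one of the 12 changed) wires one such layout, kv_cache.KVCacheTransposed, into the export config of a GPU-targeted Gemma2 conversion example.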

File tree

12 files changed: +1759 −7 lines

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Example of converting a Gemma2 model to a multi-signature tflite model."""

import os
import pathlib

from absl import app
from absl import flags
from ai_edge_torch.generative.examples.experimental.gemma import gemma2_gpu
from ai_edge_torch.generative.layers.experimental import kv_cache
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
import torch

_CHECKPOINT_PATH = flags.DEFINE_string(
    'checkpoint_path',
    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
    'The path to the model checkpoint, or directory holding the checkpoint.',
)
_OUTPUT_PATH = flags.DEFINE_string(
    'output_path',
    '/tmp/',
    'The path to export the tflite model.',
)
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
    'output_name_prefix',
    'gemma2',
    'The prefix of the output tflite model name.',
)
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
    'prefill_seq_lens',
    (8, 64, 128, 256, 512, 1024),
    'List of the maximum sizes of prefill input tensors.',
)
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
    'kv_cache_max_len',
    1280,
    'The maximum size of KV cache buffer, including both prefill and decode.',
)
_QUANTIZE = flags.DEFINE_bool(
    'quantize',
    True,
    'Whether the model should be quantized.',
)
_LORA_RANKS = flags.DEFINE_multi_integer(
    'lora_ranks',
    None,
    'If set, the model will be converted with the provided list of LoRA ranks.',
)


def _create_mask(mask_len, kv_cache_max_len):
  """Returns a causal mask of shape [1, 1, mask_len, kv_cache_max_len]."""
  mask = torch.full(
      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
  )
  # Zero out the lower triangle so position i can attend to positions 0..i;
  # everything after i stays -inf.
  mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
  return mask


def _create_export_config(
    prefill_seq_lens: list[int], kv_cache_max_len: int
) -> ExportConfig:
  """Creates the export config for the model."""
  export_config = ExportConfig()
  if isinstance(prefill_seq_lens, list):
    prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
  else:
    prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)

  export_config.prefill_mask = prefill_mask

  decode_mask = torch.full(
      (1, kv_cache_max_len), float('-inf'), dtype=torch.float32
  )
  decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
  export_config.decode_mask = decode_mask
  # Select the transposed KV cache layout for this export.
  export_config.kvcache_cls = kv_cache.KVCacheTransposed
  return export_config


def main(_):
  pytorch_model = gemma2_gpu.build_2b_model(
      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
  )
  converter.convert_to_tflite(
      pytorch_model,
      output_path=_OUTPUT_PATH.value,
      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
      prefill_seq_len=_PREFILL_SEQ_LENS.value,
      quantize=_QUANTIZE.value,
      lora_ranks=_LORA_RANKS.value,
      export_config=_create_export_config(
          _PREFILL_SEQ_LENS.value, _KV_CACHE_MAX_LEN.value
      ),
  )


if __name__ == '__main__':
  app.run(main)
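
A quick check of what the mask helper in this file produces (toy sizes; the script itself takes them from the prefill_seq_lens and kv_cache_max_len flags): for a prefill mask of length 4 over an 8-slot cache, position i can attend to cache slots 0..i and is blocked with -inf everywhere else.

    import torch

    # Same construction as _create_mask above, with mask_len=4, kv_cache_max_len=8.
    mask = torch.full((4, 8), float('-inf'), dtype=torch.float32)
    mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
    print(mask.shape)     # torch.Size([1, 1, 4, 8])
    print(mask[0, 0, 1])  # tensor([0., 0., -inf, -inf, -inf, -inf, -inf, -inf])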
