Commit f114cce

hheydary authored and copybara-github committed
Support custom batch size in decode signature.
PiperOrigin-RevId: 731316538
1 parent 6003b17 commit f114cce

8 files changed (+58, -34 lines)

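At a glance: the converter scripts gain a --decode_batch_size flag that flows through ExportConfig into the exported decode signature, while prefill stays at batch size 1. A minimal sketch of the intended call, mirroring the diffs below (the checkpoint path, sequence lengths, and batch size are illustrative, and build_model is assumed to mirror build_model_v2's signature):

from ai_edge_torch.generative.examples.smollm import smollm
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities import model_builder

# Illustrative values; see convert_to_tflite.py below for the real flag wiring.
pytorch_model = smollm.build_model('/path/to/checkpoint', kv_cache_max_len=1024)
converter.convert_to_tflite(
    pytorch_model,
    output_path='/tmp/',
    output_name_prefix='smollm',
    prefill_seq_len=[8, 64, 128],
    quantize=True,
    # Decode signature with batch size 2; the prefill signature keeps batch 1.
    export_config=model_builder.ExportConfig(decode_batch_size=2),
)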

ai_edge_torch/generative/examples/smollm/convert_to_tflite.py

Lines changed: 9 additions & 2 deletions
@@ -22,7 +22,7 @@
 from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.utilities import converter
-from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+from ai_edge_torch.generative.utilities import model_builder

 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
@@ -59,6 +59,11 @@
     None,
     'If set, the model will be converted with the provided list of LoRA ranks.',
 )
+_DECODE_BATCH_SIZE = flags.DEFINE_integer(
+    'decode_batch_size',
+    1,
+    'The batch size for the decode signature.',
+)


 def main(_):
@@ -72,7 +77,9 @@ def main(_):
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
       lora_ranks=_LORA_RANKS.value,
-      export_config=ExportConfig(),
+      export_config=model_builder.ExportConfig(
+          decode_batch_size=_DECODE_BATCH_SIZE.value
+      ),
   )

ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py

Lines changed: 25 additions & 8 deletions
@@ -22,17 +22,22 @@
 from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.utilities import converter
-from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+from ai_edge_torch.generative.utilities import model_builder

 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm2'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'smollm2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,21 +54,33 @@
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
+_DECODE_BATCH_SIZE = flags.DEFINE_integer(
+    'decode_batch_size',
+    1,
+    'The batch size for the decode signature.',
+)


 def main(_):
   pytorch_model = smollm.build_model_v2(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )

-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'smollm2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
-      export_config=ExportConfig(),
+      lora_ranks=_LORA_RANKS.value,
+      export_config=model_builder.ExportConfig(
+          decode_batch_size=_DECODE_BATCH_SIZE.value
+      ),
   )

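Note the flag rename here: the file name is now assembled inside converter.convert_to_tflite from output_path and output_name_prefix rather than in the script. A small sketch of the resulting path, following the naming logic in converter.py further down (values illustrative):

import os

output_path = '/tmp/'
output_name_prefix = 'smollm2'
quant_suffix = 'q8'  # 'q8' when --quantize is set, 'f32' otherwise
kv_size = 1024       # KV cache size, rendered as the 'ekv' marker
output_filename = f'{output_name_prefix}_{quant_suffix}_ekv{kv_size}.tflite'
print(os.path.join(output_path, output_filename))  # /tmp/smollm2_q8_ekv1024.tflite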
ai_edge_torch/generative/fx_passes/test/test_remove_sdpa_zero_mask_pass.py

Lines changed: 0 additions & 1 deletion
@@ -65,7 +65,6 @@ def __init__(self, config: unet_cfg.AttentionBlock2DConfig):
     super().__init__()
     self.config = config
     self.attention = SelfAttention(
-        config.attention_batch_size,
         config.dim,
         config.attention_config,
         enable_hlfb=config.enable_hlfb,

ai_edge_torch/generative/layers/attention.py

Lines changed: 0 additions & 11 deletions
@@ -48,7 +48,6 @@ def __init__(
         config.pre_attention_norm_config,
     )
     self.atten_func = CausalSelfAttention(
-        model_config.batch_size,
         model_config.embedding_dim,
         config.attn_config,
         model_config.enable_hlfb,
@@ -115,22 +114,19 @@ class CausalSelfAttention(nn.Module):

   def __init__(
       self,
-      batch_size: int,
       dim: int,
       config: cfg.AttentionConfig,
       enable_hlfb: bool,
   ) -> None:
     """Initialize an instance of CausalSelfAttention.

     Args:
-      batch_size (int): batch size of the input tensor.
       dim (int): causal attention's input/output dimmension.
       config (cfg.AttentionConfig): attention specific configurations.
       enable_hlfb (bool): whether hlfb is enabled or not.
     """
     super().__init__()
     self.kv_cache = None
-    self.batch_size = batch_size
     qkv_shape = (
         config.num_heads + 2 * config.num_query_groups
     ) * config.head_dim
@@ -179,11 +175,6 @@ def forward(
     """
     # Batch size, sequence length, embedding dimensionality.
     B, T, E = x.size()
-    assert B == self.batch_size, (
-        "batch size of input tensor must match with the batch size specified in"
-        " the model configuration."
-    )
-
     qkv = self.qkv_projection(x)

     # Assemble into a number of query groups to support MHA, MQA and GQA.
@@ -290,7 +281,6 @@ class CrossAttention(nn.Module):

   def __init__(
       self,
-      batch_size: int,
       query_dim: int,
       cross_dim: int,
       hidden_dim: int,
@@ -301,7 +291,6 @@ def __init__(
     """Initialize an instance of CrossAttention.

     Args:
-      batch_size (int): batch size of the input tensor.
       query_dim (int): query tensor's dimension.
       cross_dim (int): cross attention's dimensions, for key and value tensors.
       hidden_dim (int): hidden dimension that q, k, v tensors project to.
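The upshot of dropping the stored batch_size and the assert: the layer now reads its batch size off the input tensor alone, so one module can back both a batch-1 prefill signature and a batched decode signature. A toy stand-in (not the library class, which adds query groups and KV-cache handling) showing the shape behavior:

import torch
from torch import nn

class ToySelfAttention(nn.Module):
  """Minimal sketch of batch-agnostic self-attention."""

  def __init__(self, dim: int):
    super().__init__()
    self.qkv_projection = nn.Linear(dim, 3 * dim)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    B, T, E = x.size()  # batch size comes from the input, not the config
    q, k, v = self.qkv_projection(x).chunk(3, dim=-1)
    return nn.functional.scaled_dot_product_attention(q, k, v)

attn = ToySelfAttention(dim=64)
for batch in (1, 2, 4):
  # The same module instance serves any batch size without re-construction.
  assert attn(torch.randn(batch, 8, 64)).shape == (batch, 8, 64)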

ai_edge_torch/generative/layers/kv_cache.py

Lines changed: 8 additions & 5 deletions
@@ -18,14 +18,11 @@
 import dataclasses
 from typing import List, Tuple

-from ai_edge_torch import hlfb
 from ai_edge_torch.generative.layers import model_config
 from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
 import torch
 import torch.utils._pytree as pytree

-BATCH_SIZE = 1
-

 @dataclasses.dataclass
 class KVCacheEntry:
@@ -45,9 +42,10 @@ def from_model_config(
       config: model_config.AttentionConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheEntry":
     """Build an instance of the class based on model config."""
-    shape = (BATCH_SIZE, kv_cache_max, config.num_query_groups, config.head_dim)
+    shape = (batch_size, kv_cache_max, config.num_query_groups, config.head_dim)
     k = torch.zeros(shape, dtype=dtype, device=device)
     v = torch.zeros(shape, dtype=dtype, device=device)
     obj = cls(k_cache=k, v_cache=v)
@@ -66,6 +64,7 @@ def from_model_config(
       config: model_config.ModelConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCache":
     """Build an instance of the class based on model config.
@@ -75,17 +74,21 @@
         Defaults to torch.float32.
       device (torch.device, optional): The device placement of the cache
         tensors. Defaults to None.
+      batch_size (int, optional): The batch size of the cache tensors.
+        Defaults to 1.

     Returns:
       KVCache: The created cache object.
     """
     caches = [
         KVCacheEntry.from_model_config(
-            config.kv_cache_max if not config.block_config(idx).kv_cache_max_len
+            config.kv_cache_max
+            if not config.block_config(idx).kv_cache_max_len
            else config.block_config(idx).kv_cache_max_len,
             config.block_config(idx).attn_config,
             dtype,
             device,
+            batch_size,
         )
         for idx in range(config.num_layers)
     ]
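For reference, the entry allocation reduces to a pair of zero tensors whose leading dimension is now the caller-supplied batch size rather than the old module-level BATCH_SIZE constant. A standalone sketch of the same shape logic (dimensions illustrative):

import torch

def make_kv_entry(kv_cache_max: int, num_query_groups: int, head_dim: int,
                  batch_size: int = 1):
  # Mirrors KVCacheEntry.from_model_config: zero-filled K and V buffers.
  shape = (batch_size, kv_cache_max, num_query_groups, head_dim)
  return torch.zeros(shape), torch.zeros(shape)

k, v = make_kv_entry(kv_cache_max=1024, num_query_groups=4, head_dim=64,
                     batch_size=2)
assert k.shape == v.shape == (2, 1024, 4, 64)  # batch now leads the shape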

ai_edge_torch/generative/layers/model_config.py

Lines changed: 0 additions & 3 deletions
@@ -220,9 +220,6 @@ class ModelConfig:
   # The maximum sequence length of the KV cache. Should not exceed max_seq_len.
   kv_cache_max_len: int = 0

-  # Default batch size of the exported model. Default value is 1.
-  batch_size: int = 1
-
   # Softcap on the model output logits.
   final_logit_softcap: Optional[float] = None

ai_edge_torch/generative/utilities/converter.py

Lines changed: 14 additions & 4 deletions
@@ -110,6 +110,11 @@ def convert_to_tflite(
   lora_suffix = (
       '' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
   )
+
+  if export_config is not None:
+    if export_config.decode_batch_size > 1:
+      output_name_prefix += f'_dbs{export_config.decode_batch_size}'
+
   output_filename = (
       f'{output_name_prefix}_{quant_suffix}_ekv{kv_size}{lora_suffix}.tflite'
   )
@@ -162,9 +167,14 @@ def _export_helper(
   if prefill_masks:
     assert len(prefill_masks) == len(prefill_seq_lens)

-  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_token = torch.tensor(
+      [[0] for _ in range(export_config.decode_batch_size)], dtype=torch.int
+  )
   decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = export_config.kvcache_cls.from_model_config(config)
+  prefill_kv = export_config.kvcache_cls.from_model_config(config)
+  decode_kv = export_config.kvcache_cls.from_model_config(
+      config, batch_size=export_config.decode_batch_size
+  )

   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
@@ -183,7 +193,7 @@
     sample_kwargs = {
         'tokens': prefill_tokens,
         'input_pos': prefill_input_pos,
-        'kv_cache': kv,
+        'kv_cache': prefill_kv,
     }
     if prefill_masks is not None:
       sample_kwargs['mask'] = prefill_masks[i]
@@ -211,7 +221,7 @@
     sample_kwargs = {
         'tokens': decode_token,
         'input_pos': decode_input_pos,
-        'kv_cache': kv,
+        'kv_cache': decode_kv,
     }
     if export_config.decode_mask is not None:
       sample_kwargs['mask'] = export_config.decode_mask
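Two details above are worth calling out: the decode sample input grows one row per batch element, and (per the first hunk) models exported with a non-default batch size pick up a _dbs marker in the file name. A quick sketch (batch size illustrative):

import torch

decode_batch_size = 2

# Shape (batch, 1): one zero token per batch element, as in decode_token above.
decode_token = torch.tensor(
    [[0] for _ in range(decode_batch_size)], dtype=torch.int
)
assert decode_token.shape == (decode_batch_size, 1)

# File-name marker added in convert_to_tflite when the batch size is not 1.
output_name_prefix = 'smollm2'
if decode_batch_size > 1:
  output_name_prefix += f'_dbs{decode_batch_size}'
assert output_name_prefix == 'smollm2_dbs2'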

ai_edge_torch/generative/utilities/model_builder.py

Lines changed: 2 additions & 0 deletions
@@ -60,6 +60,8 @@ class ExportConfig:
   decode_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
   # The KV Cache class for K and V buffers in attention.
   kvcache_cls: type = kv_utils.KVCache
+  # The batch size of the decode signature.
+  decode_batch_size: int = 1


 class DecoderOnlyModel(nn.Module):
