Skip to content

Commit 7ec9a19

Browse files
siraki and copybara-github
authored and committed
No public description
PiperOrigin-RevId: 865424485
1 parent fd5c47f commit 7ec9a19

File tree

4 files changed

+39
-3
lines changed

4 files changed

+39
-3
lines changed

litert_torch/generative/export_hf/core/export_lib.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def export_embedder_model(
326326
sample_kwargs=sample_inputs,
327327
)
328328
lrt_model = converter.convert(strict_export=False)
329-
model_path = os.path.join(work_dir, 'model.tflite')
329+
model_path = os.path.join(work_dir, 'embedder.tflite')
330330
lrt_model.export(model_path)
331331
quantization_recipe_list = (
332332
quantization_recipe.split(',') if quantization_recipe else [None]
@@ -359,7 +359,10 @@ def export_auxiliary_model(
359359
sample_kwargs=sample_input,
360360
)
361361
# Attention Mask
362-
attention_mask_module = split_cache_module.SplitAttentionMaskBuilder(model)
362+
attention_mask_module = split_cache_module.SplitAttentionMaskBuilder(
363+
export_config.cache_length,
364+
# TODO(weiyiw): Add sliding window sizes.
365+
)
363366
sample_inputs = attention_mask_module.get_sample_inputs(
364367
text_model_config, export_config
365368
)
@@ -370,7 +373,7 @@ def export_auxiliary_model(
370373
sample_kwargs=sample_input,
371374
)
372375
# Cache Update
373-
cache_update_module = split_cache_module.CacheUpdate(model)
376+
cache_update_module = split_cache_module.CacheUpdate()
374377
sample_inputs = cache_update_module.get_sample_inputs(
375378
text_model_config, export_config
376379
)

litert_torch/generative/export_hf/core/exportable_module_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class ExportableModuleConfig:
3131

3232
# Export configs
3333
externalize_embedder: bool = False
34+
single_token_embedder: bool = False
3435
externalize_rope: bool = False
3536

3637
split_cache: bool = False

litert_torch/generative/export_hf/core/external_emb/exportable_module.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,33 @@ def forward(
9494
token_ids = torch.maximum(token_ids, torch.tensor(0, dtype=torch.int32))
9595
output = self.model(token_ids)
9696
return {"embeddings": output}
97+
98+
@classmethod
99+
def get_sample_inputs(
100+
cls,
101+
model_config,
102+
export_config: base_exportable_module.ExportableModuleConfig,
103+
):
104+
"""Gets sample inputs."""
105+
batch_size = export_config.batch_size
106+
prefill_length = export_config.prefill_lengths[0]
107+
prefill_length_dim = export_config.prefill_length_dim
108+
del model_config # Unused.
109+
tokens = {"token_ids": torch.ones((batch_size, 1), dtype=torch.int32)}
110+
tokens_dynamic_shape = {"token_ids": {1: 1}} if prefill_length_dim else {}
111+
if export_config.single_token_embedder:
112+
return {"embedder": (tokens, tokens_dynamic_shape)}
113+
else:
114+
ret = {}
115+
ret["decode_embedder"] = (tokens, tokens_dynamic_shape)
116+
117+
tokens = {
118+
"token_ids": torch.ones(
119+
(batch_size, prefill_length), dtype=torch.int32
120+
)
121+
}
122+
tokens_dynamic_shape = (
123+
{"token_ids": {1: prefill_length_dim}} if prefill_length_dim else {}
124+
)
125+
ret[f"prefill_embedder_{prefill_length}"] = (tokens, tokens_dynamic_shape)
126+
return ret

litert_torch/generative/export_hf/export.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def export(
3131
quantization_recipe: str = 'dynamic_wi8_afp32',
3232
enable_dynamic_shape: bool = False,
3333
externalize_embedder: bool = False,
34+
single_token_embedder: bool = False,
3435
key_ts_idx: int = 2,
3536
value_ts_idx: int = 3,
3637
split_cache: bool = False,
@@ -62,6 +63,7 @@ def export(
6263
if enable_dynamic_shape
6364
else None,
6465
externalize_embedder=externalize_embedder,
66+
single_token_embedder=single_token_embedder,
6567
k_ts_idx=key_ts_idx,
6668
v_ts_idx=value_ts_idx,
6769
split_cache=split_cache,

0 commit comments

Comments (0)