
Commit c1e5ce7

hheydary authored and copybara-github committed
Harmonize the prefill signature name.
PiperOrigin-RevId: 715917372
1 parent c8d0b19 commit c1e5ce7

File tree

3 files changed: +6 −9 lines changed

ai_edge_torch/generative/README.md
ai_edge_torch/generative/examples/README.md
ai_edge_torch/generative/utilities/converter.py


ai_edge_torch/generative/README.md

Lines changed: 3 additions & 3 deletions
@@ -18,7 +18,7 @@ The system is designed to help ML practitioners deploy their trained Large Langu
 * [Convert](#convert-pytorch-llm-to-a-tflite-model) the model, and get a TFLite Flatbuffer representing the mobile model.
 * Choose either approach below to deploy the end to end [LLM Inference Pipeline](#end-to-end-inference-pipeline).
 
-For a more detailed explaination of how the system works, please refer to the [System Overview](doc/system_overview.md).
+For a more detailed explanation of how the system works, please refer to the [System Overview](doc/system_overview.md).
 
 ### Model Authoring using Edge Generative API
 
@@ -67,7 +67,7 @@ https://github.com/google-ai-edge/ai-edge-torch/blob/853301630f2b2455bd2e2f73d8a
 Then export the model to TFLite with:
 https://github.com/google-ai-edge/ai-edge-torch/blob/853301630f2b2455bd2e2f73d8a47e1a1534c91c/ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py#L133-L139
 
-Please note that using the `prefill` and `decode` method conventions are required for easy integration into the Mediapipe LLM Inference API.
+Please note that using the `prefill_{SEQ-LEN}` and `decode` method conventions are required for easy integration into the Mediapipe LLM Inference API.
 
 To further optimize the on-device execution, a model can be exported with more than one prefill signature. As such, we use `prefill_{SEQ-LENS}` to export models with multiple prefill sequence lengths. During inference, the signature closest the input sequence length is used to minimize throwaway results.
 
@@ -137,7 +137,7 @@ For an end-to-end example showing how to author, convert, quantize and execute,
 ## What to expect
 
 ### Future Roadmap
-* Expanded accleration support on mobile, and web GPUs, and mobile NPUs.
+* Expanded acceleration support on mobile, and web GPUs, and mobile NPUs.
 * Advanced quantization approaches suitable for LLMs.
 * Expanded support of models, including Diffusion models.
 * LoRA support.
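The README hunk above notes that, during inference, the signature closest to the input sequence length is used to minimize throwaway results. Below is a minimal sketch of that selection idea, assuming the rule is "smallest exported length that still fits the prompt"; the actual MediaPipe LLM Inference runtime may use a different heuristic.

```python
# Illustrative only: pick which exported prefill signature to run for a prompt.
# Assumes the heuristic "smallest exported sequence length that fits the prompt";
# the real runtime's rule may differ.
def pick_prefill_signature(prompt_len: int, exported_seq_lens: list[int]) -> str:
  fitting = [n for n in sorted(exported_seq_lens) if n >= prompt_len]
  chosen = fitting[0] if fitting else max(exported_seq_lens)
  return f'prefill_{chosen}'

# With prefill_128, prefill_512, and prefill_1024 exported:
assert pick_prefill_signature(90, [128, 512, 1024]) == 'prefill_128'
assert pick_prefill_signature(300, [128, 512, 1024]) == 'prefill_512'
```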

ai_edge_torch/generative/examples/README.md

Lines changed: 2 additions & 2 deletions
@@ -200,8 +200,8 @@ export PYTHONPATH=$PWD/gemma_pytorch:$PYTHONPATH
 In this step, we use the `ai_edge_torch`'s standard multi-signature conversion API to convert PyTorch `nn.Module` to a single TFLite flatbuffer for on-device execution. For example, in `tiny_llama/convert_to_tflite.py`, we use this python code to convert the `TinyLlama` model to a multi-signature TFLite model:
 https://github.com/google-ai-edge/ai-edge-torch/blob/853301630f2b2455bd2e2f73d8a47e1a1534c91c/ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py#L26-L61
 
-Once converted, you will get a `.tflite` model which will be ready for on-device execution. Note that the `.tflite` model generated uses static shapes. Inside the generated `.tflite` model, there will be two signatures defined (two entrypoints to the model):
-1) `prefill`: taking 2 tensor inputs `prefill_tokens`, `prefill_input_pos`. With shape `(BATCH_SIZE, PREFILL_SEQ_LEN)` and `(PREFILL_SEQ_LEN)`.
+Once converted, you will get a `.tflite` model which will be ready for on-device execution. Note that the `.tflite` model generated uses static shapes. Inside the generated `.tflite` model, there will be two signatures defined (two entry points to the model):
+1) `prefill_*`: taking 2 tensor inputs `prefill_tokens`, `prefill_input_pos`. With shape `(BATCH_SIZE, PREFILL_SEQ_LEN)` and `(PREFILL_SEQ_LEN)`.
 2) `decode`: taking 2 tensor inputs `decode_token`, `decode_input_pos`. With shape `(1, 1)` and `(1)`.
 To learn more about TFLite signatures, please refer to this [article](https://www.tensorflow.org/lite/guide/signatures).
 
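Because the prefill signature now always carries a sequence-length suffix, callers look it up by its full name rather than a bare `prefill`. A minimal sketch using TensorFlow's `tf.lite` Python API; the model path and sequence length here are illustrative assumptions, not values taken from this repository.

```python
import tensorflow as tf

# Illustrative: a converted model exported with a single 1024-token prefill
# signature plus the decode signature.
interpreter = tf.lite.Interpreter(model_path='tiny_llama.tflite')

# e.g. {'prefill_1024': {...}, 'decode': {...}}
print(interpreter.get_signature_list())

prefill = interpreter.get_signature_runner('prefill_1024')
decode = interpreter.get_signature_runner('decode')
```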
ai_edge_torch/generative/utilities/converter.py

Lines changed: 1 addition & 4 deletions
@@ -167,10 +167,7 @@ def _export_helper(
     prefill_seq_len = prefill_seq_lens[i]
     prefill_tokens = prefill_tokens_list[i]
     prefill_input_pos = prefill_input_pos_list[i]
-    if i == 0 and len(prefill_seq_lens) == 1:
-      prefill_signature_name = 'prefill'
-    else:
-      prefill_signature_name = f'prefill_{prefill_seq_len}'
+    prefill_signature_name = f'prefill_{prefill_seq_len}'
 
     sample_kwargs = {
         'tokens': prefill_tokens,
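The net effect of the change above is that the `i == 0` special case, which emitted a bare `prefill` name for single-length exports, is gone, so every prefill signature is length-suffixed. A minimal sketch of the resulting naming rule (illustrative only, not the converter's actual API surface):

```python
# Illustrative sketch of the naming rule after this commit: every prefill
# export is suffixed with its sequence length, even when only one is requested.
def prefill_signature_names(prefill_seq_lens: list[int]) -> list[str]:
  return [f'prefill_{seq_len}' for seq_len in prefill_seq_lens]

assert prefill_signature_names([1024]) == ['prefill_1024']  # previously just 'prefill'
assert prefill_signature_names([128, 512]) == ['prefill_128', 'prefill_512']
```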
