Skip to content

Commit e029f9b

Browse files
talumbaucopybara-github
authored andcommitted
Prefill signatures don't do head FC or output logits
PiperOrigin-RevId: 704378247
1 parent 86e07ca commit e029f9b

File tree

18 files changed

+111
-13
lines changed

18 files changed

+111
-13
lines changed

ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from absl import flags
2323
from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
2424
from ai_edge_torch.generative.utilities import converter
25+
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
2526

2627
_CHECKPOINT_PATH = flags.DEFINE_string(
2728
'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
6162
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
6263
prefill_seq_len=_PREFILL_SEQ_LEN.value,
6364
quantize=_QUANTIZE.value,
65+
export_config=ExportConfig(),
6466
)
6567

6668

ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from absl import flags
2323
from ai_edge_torch.generative.examples.gemma import gemma1
2424
from ai_edge_torch.generative.utilities import converter
25+
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
2526

2627
_CHECKPOINT_PATH = flags.DEFINE_string(
2728
'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
6162
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
6263
prefill_seq_len=_PREFILL_SEQ_LENS.value,
6364
quantize=_QUANTIZE.value,
65+
export_config=ExportConfig(),
6466
)
6567

6668

ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from absl import flags
2323
from ai_edge_torch.generative.examples.gemma import gemma2
2424
from ai_edge_torch.generative.utilities import converter
25+
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
2526

2627
_CHECKPOINT_PATH = flags.DEFINE_string(
2728
'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
6162
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
6263
prefill_seq_len=_PREFILL_SEQ_LENS.value,
6364
quantize=_QUANTIZE.value,
65+
export_config=ExportConfig(),
6466
)
6567

6668

ai_edge_torch/generative/examples/gemma/gemma2.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import ai_edge_torch.generative.layers.attention_utils as attn_utils
2424
import ai_edge_torch.generative.layers.model_config as cfg
2525
import ai_edge_torch.generative.utilities.loader as loading_utils
26+
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
2627
import torch
2728
from torch import nn
2829

@@ -132,6 +133,7 @@ def forward(
132133
tokens: torch.Tensor,
133134
input_pos: torch.Tensor,
134135
kv_cache: kv_utils.KVCache,
136+
export_config: Optional[ExportConfig] = None,
135137
) -> dict[torch.Tensor, kv_utils.KVCache]:
136138
_, seq_len = tokens.size()
137139
assert self.config.max_seq_len >= seq_len, (
@@ -162,6 +164,13 @@ def forward(
162164
updated_kv_entires.append(kv_entry)
163165
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
164166

167+
if export_config is not None:
168+
if (
169+
torch.numel(input_pos) > 1
170+
and not export_config.output_logits_on_prefill
171+
):
172+
return {"kv_cache": updated_kv_cache}
173+
165174
x = self.final_norm(x)
166175
res = self.lm_head(x) # (b, t, vocab_size)
167176
if self.config.final_logit_softcap is not None:

ai_edge_torch/generative/examples/llama/convert_to_tflite.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from absl import flags
2323
from ai_edge_torch.generative.examples.llama import llama
2424
from ai_edge_torch.generative.utilities import converter
25+
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
2526

2627
_MODEL_SIZE = flags.DEFINE_enum(
2728
'model_size',
@@ -72,6 +73,7 @@ def main(_):
7273
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
7374
prefill_seq_len=_PREFILL_SEQ_LENS.value,
7475
quantize=_QUANTIZE.value,
76+
export_config=ExportConfig(),
7577
)
7678

7779

ai_edge_torch/generative/examples/openelm/convert_to_tflite.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from absl import flags
2323
from ai_edge_torch.generative.examples.openelm import openelm
2424
from ai_edge_torch.generative.utilities import converter
25+
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
2526

2627
_CHECKPOINT_PATH = flags.DEFINE_string(
2728
'checkpoint_path',
@@ -64,6 +65,7 @@ def main(_):
6465
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
6566
prefill_seq_len=_PREFILL_SEQ_LENS.value,
6667
quantize=_QUANTIZE.value,
68+
export_config=ExportConfig(),
6769
)
6870

6971

ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from absl import flags
2727
from ai_edge_torch.generative.examples.paligemma import paligemma
2828
from ai_edge_torch.generative.utilities import converter
29+
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
2930
import torch
3031

3132
_CHECKPOINT_PATH = flags.DEFINE_string(
@@ -73,6 +74,7 @@ def main(_):
7374
pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
7475
quantize=_QUANTIZE.value,
7576
config=pytorch_model.config.decoder_config,
77+
export_config=ExportConfig(),
7678
)
7779

7880

ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from absl import flags
2323
from ai_edge_torch.generative.examples.phi import phi3
2424
from ai_edge_torch.generative.utilities import converter
25+
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
2526

2627
_CHECKPOINT_PATH = flags.DEFINE_string(
2728
'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
6162
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
6263
prefill_seq_len=_PREFILL_SEQ_LENS.value,
6364
quantize=_QUANTIZE.value,
65+
export_config=ExportConfig(),
6466
)
6567

6668

ai_edge_torch/generative/examples/phi/convert_to_tflite.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from absl import flags
2323
from ai_edge_torch.generative.examples.phi import phi2
2424
from ai_edge_torch.generative.utilities import converter
25+
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
2526

2627
_CHECKPOINT_PATH = flags.DEFINE_string(
2728
'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
6162
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
6263
prefill_seq_len=_PREFILL_SEQ_LENS.value,
6364
quantize=_QUANTIZE.value,
65+
export_config=ExportConfig(),
6466
)
6567

6668

ai_edge_torch/generative/examples/qwen/convert_to_tflite.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from absl import flags
2323
from ai_edge_torch.generative.examples.qwen import qwen
2424
from ai_edge_torch.generative.utilities import converter
25+
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
2526

2627
_MODEL_SIZE = flags.DEFINE_enum(
2728
'model_size',
@@ -76,6 +77,7 @@ def main(_):
7677
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
7778
prefill_seq_len=_PREFILL_SEQ_LENS.value,
7879
quantize=_QUANTIZE.value,
80+
export_config=ExportConfig(),
7981
)
8082

8183

0 commit comments

Comments
 (0)