Skip to content

Commit 8c52eb0

Browse files
haozha111 authored and copybara-github committed
Fix conversion issue of amd-llama-135M.
PiperOrigin-RevId: 728375648
1 parent 8eac09b commit 8c52eb0

File tree

2 files changed

+25
-16
lines changed

2 files changed

+25
-16
lines changed

ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,39 +29,48 @@
2929
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/amd-llama-135m'),
3030
'The path to the model checkpoint, or directory holding the checkpoint.',
3131
)
32-
_TFLITE_PATH = flags.DEFINE_string(
33-
'tflite_path',
34-
'/tmp/',
35-
'The tflite file path to export.',
36-
)
37-
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
38-
'prefill_seq_len',
39-
1024,
40-
'The maximum size of prefill input tensor.',
41-
)
4232
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
4333
'kv_cache_max_len',
4434
1280,
4535
'The maximum size of KV cache buffer, including both prefill and decode.',
4636
)
37+
_OUTPUT_PATH = flags.DEFINE_string(
38+
'output_path',
39+
'/tmp/',
40+
'The path to export the tflite model.',
41+
)
42+
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
43+
'output_name_prefix',
44+
'deepseek',
45+
'The prefix of the output tflite model name.',
46+
)
47+
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
48+
'prefill_seq_lens',
49+
(8, 64, 128, 256, 512, 1024),
50+
'List of the maximum sizes of prefill input tensors.',
51+
)
4752
_QUANTIZE = flags.DEFINE_bool(
4853
'quantize',
4954
True,
5055
'Whether the model should be quantized.',
5156
)
52-
57+
_LORA_RANKS = flags.DEFINE_multi_integer(
58+
'lora_ranks',
59+
None,
60+
'If set, the model will be converted with the provided list of LoRA ranks.',
61+
)
5362

5463
def main(_):
5564
pytorch_model = amd_llama_135m.build_model(
5665
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
5766
)
58-
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
59-
output_filename = f'amd-llama-135m_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
6067
converter.convert_to_tflite(
6168
pytorch_model,
62-
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
63-
prefill_seq_len=_PREFILL_SEQ_LEN.value,
69+
output_path=_OUTPUT_PATH.value,
70+
output_name_prefix=_OUTPUT_NAME_PREFIX.value,
71+
prefill_seq_len=_PREFILL_SEQ_LENS.value,
6472
quantize=_QUANTIZE.value,
73+
lora_ranks=_LORA_RANKS.value,
6574
export_config=ExportConfig(),
6675
)
6776

ai_edge_torch/generative/examples/amd_llama_135m/verify.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def main(_):
5151
)
5252
reauthored_checkpoint = pathlib.Path(cached_config_file).parent
5353
logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
54-
reauthored_model = amd_llama_135m.build_model(reauthored_checkpoint)
54+
reauthored_model = amd_llama_135m.build_model(str(reauthored_checkpoint))
5555

5656
logging.info("Loading the tokenizer from: %s", checkpoint)
5757
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

0 commit comments

Comments (0)