|
29 | 29 | os.path.join(pathlib.Path.home(), 'Downloads/llm_data/amd-llama-135m'), |
30 | 30 | 'The path to the model checkpoint, or directory holding the checkpoint.', |
31 | 31 | ) |
32 | | -_TFLITE_PATH = flags.DEFINE_string( |
33 | | - 'tflite_path', |
34 | | - '/tmp/', |
35 | | - 'The tflite file path to export.', |
36 | | -) |
37 | | -_PREFILL_SEQ_LEN = flags.DEFINE_integer( |
38 | | - 'prefill_seq_len', |
39 | | - 1024, |
40 | | - 'The maximum size of prefill input tensor.', |
41 | | -) |
42 | 32 | _KV_CACHE_MAX_LEN = flags.DEFINE_integer( |
43 | 33 | 'kv_cache_max_len', |
44 | 34 | 1280, |
45 | 35 | 'The maximum size of KV cache buffer, including both prefill and decode.', |
46 | 36 | ) |
| 37 | +_OUTPUT_PATH = flags.DEFINE_string( |
| 38 | + 'output_path', |
| 39 | + '/tmp/', |
| 40 | + 'The path to export the tflite model.', |
| 41 | +) |
| 42 | +_OUTPUT_NAME_PREFIX = flags.DEFINE_string( |
| 43 | + 'output_name_prefix', |
| 44 | + 'deepseek', |
| 45 | + 'The prefix of the output tflite model name.', |
| 46 | +) |
| 47 | +_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer( |
| 48 | + 'prefill_seq_lens', |
| 49 | + (8, 64, 128, 256, 512, 1024), |
| 50 | + 'List of the maximum sizes of prefill input tensors.', |
| 51 | +) |
47 | 52 | _QUANTIZE = flags.DEFINE_bool( |
48 | 53 | 'quantize', |
49 | 54 | True, |
50 | 55 | 'Whether the model should be quantized.', |
51 | 56 | ) |
52 | | - |
| 57 | +_LORA_RANKS = flags.DEFINE_multi_integer( |
| 58 | + 'lora_ranks', |
| 59 | + None, |
| 60 | + 'If set, the model will be converted with the provided list of LoRA ranks.', |
| 61 | +) |
53 | 62 |
|
54 | 63 | def main(_): |
55 | 64 | pytorch_model = amd_llama_135m.build_model( |
56 | 65 | _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value |
57 | 66 | ) |
58 | | - quant_suffix = 'q8' if _QUANTIZE.value else 'f32' |
59 | | - output_filename = f'amd-llama-135m_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite' |
60 | 67 | converter.convert_to_tflite( |
61 | 68 | pytorch_model, |
62 | | - tflite_path=os.path.join(_TFLITE_PATH.value, output_filename), |
63 | | - prefill_seq_len=_PREFILL_SEQ_LEN.value, |
| 69 | + output_path=_OUTPUT_PATH.value, |
| 70 | + output_name_prefix=_OUTPUT_NAME_PREFIX.value, |
| 71 | + prefill_seq_len=_PREFILL_SEQ_LENS.value, |
64 | 72 | quantize=_QUANTIZE.value, |
| 73 | + lora_ranks=_LORA_RANKS.value, |
65 | 74 | export_config=ExportConfig(), |
66 | 75 | ) |
67 | 76 |
|
|
0 commit comments