Commit 603e8ea

hheydary authored and copybara-github committed
Add LoRA support to AI Edge Transformers.
PiperOrigin-RevId: 713026247
1 parent b183411 commit 603e8ea

16 files changed (+1009, -127 lines)
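
For context: LoRA (Low-Rank Adaptation) keeps a pretrained weight matrix W frozen and learns a low-rank correction, computing y = Wx + (alpha/r) * B(Ax), where A is r x d_in, B is d_out x r, and the rank r is much smaller than d_in or d_out. A minimal PyTorch sketch of the general technique follows; it is not AI Edge Torch's implementation, and the class name and init scale are illustrative assumptions.

import torch

class LoRALinear(torch.nn.Module):
  """Generic LoRA sketch: a frozen base linear plus a low-rank update."""

  def __init__(self, base: torch.nn.Linear, rank: int, alpha: float = 1.0):
    super().__init__()
    self.base = base
    for p in self.base.parameters():
      p.requires_grad = False  # the pretrained weights stay frozen
    # A: (rank, in_features); B: (out_features, rank). B starts at zero so
    # the adapter is a no-op before training.
    self.lora_a = torch.nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
    self.lora_b = torch.nn.Parameter(torch.zeros(base.out_features, rank))
    self.scale = alpha / rank

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    # y = base(x) + scale * (x A^T) B^T
    return self.base(x) + self.scale * (x @ self.lora_a.T) @ self.lora_b.T

Because the rank fixes the adapter tensor shapes, an exported TFLite model must know the ranks ahead of time, which is presumably why each conversion script below gains a lora_ranks flag (absl's DEFINE_multi_integer collects repeated occurrences, e.g. --lora_ranks=4 --lora_ranks=16, into [4, 16]).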

ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py

Lines changed: 16 additions & 6 deletions
@@ -29,10 +29,15 @@
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = gemma1.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
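
The same edit repeats in each text-generation script below: the tflite_path flag (a full file path) is split into output_path (a directory) and output_name_prefix, the inline quant_suffix/output_filename logic is deleted, and the new lora_ranks list is threaded through. A minimal sketch of the resulting call with placeholder values; only the keyword arguments visible in these diffs are assumed.

converter.convert_to_tflite(
    pytorch_model,
    output_path='/tmp/',            # directory; replaces the old full tflite_path
    output_name_prefix='gemma',     # the converter now derives the final filename
    prefill_seq_len=[256, 512],     # placeholder prefill lengths
    quantize=True,
    lora_ranks=[4, 16],             # new in this commit; None skips LoRA export
    export_config=ExportConfig(),
)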

ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py

Lines changed: 16 additions & 6 deletions
@@ -29,10 +29,15 @@
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = gemma2.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

ai_edge_torch/generative/examples/llama/convert_to_tflite.py

Lines changed: 16 additions & 6 deletions
@@ -35,10 +35,15 @@
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'llama',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -55,6 +60,11 @@
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)

 _BUILDER = {
     '1b': llama.build_1b_model,
@@ -66,13 +76,13 @@ def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

ai_edge_torch/generative/examples/openelm/convert_to_tflite.py

Lines changed: 16 additions & 9 deletions
@@ -29,10 +29,15 @@
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'openelm',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,22 +54,24 @@
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = openelm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = (
-      f'openelm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
-
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py

Lines changed: 11 additions & 6 deletions
@@ -40,10 +40,15 @@
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'paligemma',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
@@ -73,11 +78,11 @@ def main(_):
       version=int(_VERSION.value),
       kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=f'{_OUTPUT_NAME_PREFIX.value}_{_VERSION.value}',
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
       quantize=_QUANTIZE.value,
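
PaliGemma is the one script here that composes its prefix dynamically, folding the model version into output_name_prefix, and it is the only script shown here that does not gain a lora_ranks flag in this commit. A sketch of the resulting call with placeholder values; the hunk above is truncated after quantize, so any later arguments are omitted.

converter.convert_to_tflite(
    pytorch_model,
    output_path='/tmp/',
    output_name_prefix='paligemma_2',                # prefix + '_' + version
    prefill_seq_len=256,                             # a single int: DEFINE_integer here
    pixel_values_size=torch.Size([1, 3, 224, 224]),  # placeholder image input shape
    quantize=True,
)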

ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py

Lines changed: 17 additions & 7 deletions
@@ -26,13 +26,18 @@

 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi3'),
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi3',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = phi3.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi3_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

ai_edge_torch/generative/examples/phi/convert_to_tflite.py

Lines changed: 16 additions & 6 deletions
@@ -29,10 +29,15 @@
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = phi2.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

ai_edge_torch/generative/examples/qwen/convert_to_tflite.py

Lines changed: 17 additions & 9 deletions
@@ -35,10 +35,15 @@
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'qwen',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -55,6 +60,12 @@
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
+

 _BUILDER = {
     '0.5b': qwen.build_0_5b_model,
@@ -67,16 +78,13 @@ def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  model_size = _MODEL_SIZE.value.replace('.', '_')
-  output_filename = (
-      f'qwen_{model_size}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
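
With the per-script templates gone (gemma_{q8|f32}_ekv{N}.tflite, qwen_{size}_{q8|f32}_ekv{N}.tflite, and so on), filename construction presumably moves inside converter.convert_to_tflite. Below is a hypothetical helper mirroring the deleted logic; build_output_filename is not an API introduced by this commit.

def build_output_filename(prefix: str, quantize: bool, kv_cache_max_len: int) -> str:
  """Hypothetical: mirrors the naming each example used to inline."""
  quant_suffix = 'q8' if quantize else 'f32'
  return f'{prefix}_{quant_suffix}_ekv{kv_cache_max_len}.tflite'

# Example: build_output_filename('qwen', True, 1280) -> 'qwen_q8_ekv1280.tflite'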
