 
 """Common utility functions for model conversion."""
 
+import enum
 import os
 import pathlib
 from typing import Optional, Union
@@ -42,6 +43,27 @@ def forward(self, *export_args, **export_kwargs):
     return self.module(*export_args, **full_kwargs)
 
 
+class QuantizationName(str, enum.Enum):
+  """Strings for all supported quantization recipes.
+
+  none: No quantization.
+  dynamic_int8: Dynamic range quantization with int8 weights.
+  weight_only_int8: Weight only quantization with int8 weights.
+  fp16: Float16 quantization.
+  dynamic_int4_block32: Dynamic range quantization with int4 weights and block
+    size of 32, better model quality but slower inference.
+  dynamic_int4_block128: Dynamic range quantization with int4 weights and block
+    size of 128, faster inference but worse model quality.
+  """
+
+  NONE = 'none'
+  DYNAMIC_INT8 = 'dynamic_int8'
+  WEIGHT_ONLY_INT8 = 'weight_only_int8'
+  FP16 = 'fp16'
+  DYNAMIC_INT4_BLOCK32 = 'dynamic_int4_block32'
+  DYNAMIC_INT4_BLOCK128 = 'dynamic_int4_block128'
+
+
 def define_conversion_flags(
     model_name: str,
     default_mask_as_input: bool = False,
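Note: because QuantizationName mixes in str, each member compares equal to its raw flag string, which is what lets the match statements added further down accept the plain string value of the quantize flag directly. A minimal sketch, assuming the module above is importable; it uses only names present in this diff:

    assert QuantizationName.DYNAMIC_INT8 == 'dynamic_int8'

    # A value pattern in a match statement compares with ==, so a raw flag
    # string selects the corresponding case arm:
    match 'weight_only_int8':
      case QuantizationName.WEIGHT_ONLY_INT8:
        print('weight-only int8 selected')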
@@ -74,10 +96,10 @@ def define_conversion_flags(
       1280,
       'The maximum size of KV cache buffer, including both prefill and decode.',
   )
-  flags.DEFINE_bool(
+  flags.DEFINE_string(
       'quantize',
-      True,
-      'Whether the model should be quantized.',
+      'dynamic_int8',
+      'How the model should be quantized.',
   )
   flags.DEFINE_multi_integer(
       'lora_ranks',
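With the flag switched from DEFINE_bool to DEFINE_string, conversion scripts now select a recipe by name (default 'dynamic_int8') instead of toggling quantization on and off. A sketch of how a caller might read and validate the new flag; the validation itself is illustrative and not part of this diff:

    value = flags.FLAGS.quantize
    allowed = [q.value for q in QuantizationName]
    if value not in allowed:
      raise ValueError(f'--quantize must be one of {allowed}, got {value!r}')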
@@ -99,6 +121,66 @@
   return flags
 
 
+def get_quant_recipe_from_flag(
+    quantize: str,
+) -> Optional[quant_recipes.QuantizationRecipe]:
+  """Processes the quantization flag and returns the corresponding recipe.
+
+  Args:
+    quantize: The quantization type.
+
+  Returns:
+    The quantization recipe, or None if no quantization is needed.
+
+  Raises:
+    ValueError: If the quantization type is not supported.
+  """
+  match quantize:
+    case QuantizationName.NONE:
+      return None
+    case QuantizationName.DYNAMIC_INT8:
+      return quant_recipes.full_int8_dynamic_recipe()
+    case QuantizationName.WEIGHT_ONLY_INT8:
+      return quant_recipes.full_int8_weight_only_recipe()
+    case QuantizationName.FP16:
+      return quant_recipes.full_fp16_recipe()
+    case QuantizationName.DYNAMIC_INT4_BLOCK32:
+      return quant_recipes.full_int4_dynamic_block_recipe(32)
+    case QuantizationName.DYNAMIC_INT4_BLOCK128:
+      return quant_recipes.full_int4_dynamic_block_recipe(128)
+    case _:
+      raise ValueError(f'Unsupported quantization flag: {quantize}')
+
+
+def create_quantize_suffix(quantize: str) -> str:
+  """Creates a suffix for the output file name based on the quantization type.
+
+  Args:
+    quantize: The quantization type.
+
+  Returns:
+    A string representing the quantization suffix.
+
+  Raises:
+    ValueError: If the quantization type is not supported.
+  """
+  match quantize:
+    case QuantizationName.NONE:
+      return 'f32'
+    case QuantizationName.DYNAMIC_INT8:
+      return 'q8'
+    case QuantizationName.WEIGHT_ONLY_INT8:
+      return 'q8_wo'
+    case QuantizationName.FP16:
+      return 'fp16'
+    case QuantizationName.DYNAMIC_INT4_BLOCK32:
+      return 'q4_block32'
+    case QuantizationName.DYNAMIC_INT4_BLOCK128:
+      return 'q4_block128'
+    case _:
+      raise ValueError(f'Unsupported quantization flag: {quantize}')
+
+
 def _build_mask(mask_len, kv_cache_max_len, causal_mask_value) -> torch.Tensor:
   if isinstance(mask_len, list):
     return [
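Taken together, the two helpers above map a single flag string to both a quantization recipe and a file-name suffix. An illustrative pairing, with results read directly off the match arms above:

    recipe = get_quant_recipe_from_flag('dynamic_int4_block32')  # int4 dynamic recipe, block size 32
    suffix = create_quantize_suffix('dynamic_int4_block32')      # 'q4_block32'

    recipe = get_quant_recipe_from_flag('none')                  # None, i.e. no quantization
    suffix = create_quantize_suffix('none')                      # 'f32'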
@@ -118,7 +200,7 @@ def convert_to_tflite(
     prefill_seq_len: Union[int, list[int]],
     pixel_values_size: torch.Size = None,
     pixel_seq_len: int = 0,
-    quantize: bool = True,
+    quantize: str = 'dynamic_int8',
     config: cfg.ModelConfig = None,
     lora_ranks: Optional[list[int]] = None,
     export_config: ExportConfig = None,
@@ -164,8 +246,8 @@
       embeddings generated by the image encoder with pixel values. The actual
       length of prefill_seq_len will be added by pixel_seq_len when pixel
       values are passed.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
+    quantize (str, optional): The quantization type. Defaults to
+      'dynamic_int8'.
     config (cfg.ModelConfig, optional): The model config used to configure KV
       cache. If None, it uses the config of the pytorch_model.
     lora_ranks (list[int], optional): The ranks of the LORA layers. If None,
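Call sites now pass a recipe name rather than a bool. A hypothetical example; only the quantize argument reflects this diff, and the remaining arguments are placeholders standing in for whatever the caller already passes:

    convert_to_tflite(
        pytorch_model,                  # placeholder model argument
        prefill_seq_len=[256, 512],
        quantize='weight_only_int8',    # was: quantize=True or quantize=False
        export_config=export_config,
    )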
@@ -186,7 +268,7 @@ def convert_to_tflite(
       lora = lora_utils.LoRA.zeros(rank, config)
       loras.append(lora)
 
-  quant_suffix = 'q8' if quantize else 'f32'
+  quant_suffix = create_quantize_suffix(quantize)
   kv_size = config.kv_cache_max_len
   lora_suffix = (
       '' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
@@ -220,7 +302,7 @@ def _export_helper(
     prefill_seq_lens: list[int],
     pixel_values_size: torch.Size,
     pixel_seq_len: int,
-    quantize: bool,
+    quantize: str,
     config: cfg.ModelConfig,
     loras: list[None | lora_utils.LoRA],
     export_config: ExportConfig,
@@ -269,7 +351,7 @@ def _export_helper(
       kv_layout=export_config.kvcache_layout,
   )
 
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+  quant_config = get_quant_recipe_from_flag(quantize)
   quant_config._model_config = config
 
   # For export, we create a module that captures any non-exportable,