diff --git a/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py b/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
index 17d68015..336a3fe6 100644
--- a/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
@@ -138,9 +138,7 @@ def convert_stable_diffusion_to_tflite(
   if not os.path.exists(output_dir):
     pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

-  quant_config = (
-      quant_recipes.full_int8_weight_only_recipe() if quantize else None
-  )
+  quant_config = quant_recipes.full_weight_only_recipe() if quantize else None

   # TODO(yichunk): convert to multi signature tflite model.
   # CLIP text encoder
diff --git a/ai_edge_torch/generative/quantize/example.py b/ai_edge_torch/generative/quantize/example.py
index b7bccc92..0162ed3d 100644
--- a/ai_edge_torch/generative/quantize/example.py
+++ b/ai_edge_torch/generative/quantize/example.py
@@ -33,7 +33,7 @@ def main():
   kv = kv_utils.KVCache.from_model_config(config)

   # Create a quantization recipe to be applied to the model
-  quant_config = quant_recipes.full_int8_dynamic_recipe()
+  quant_config = quant_recipes.full_dynamic_recipe()
   print(quant_config)

   # Convert with quantization
diff --git a/ai_edge_torch/generative/quantize/quant_attrs.py b/ai_edge_torch/generative/quantize/quant_attrs.py
index 8a869382..1eccf2e4 100644
--- a/ai_edge_torch/generative/quantize/quant_attrs.py
+++ b/ai_edge_torch/generative/quantize/quant_attrs.py
@@ -63,8 +63,15 @@ class Granularity(enum.Enum):
     NONE: Granularity not applicable to this quantization scheme.
     CHANNELWISE: Or per-channel quantization. Each channel of relevant tensors
       is quantized independently of one another.
+    BLOCKWISE_32: Blockwise quantization with block size 32.
+    BLOCKWISE_64: Blockwise quantization with block size 64.
+    BLOCKWISE_128: Blockwise quantization with block size 128.
+    BLOCKWISE_256: Blockwise quantization with block size 256.
   """

   NONE = enum.auto()
   CHANNELWISE = enum.auto()
-  BLOCKWISE = enum.auto()
+  BLOCKWISE_32 = enum.auto()
+  BLOCKWISE_64 = enum.auto()
+  BLOCKWISE_128 = enum.auto()
+  BLOCKWISE_256 = enum.auto()
diff --git a/ai_edge_torch/generative/quantize/quant_recipe.py b/ai_edge_torch/generative/quantize/quant_recipe.py
index a287bfc2..c80d8288 100644
--- a/ai_edge_torch/generative/quantize/quant_recipe.py
+++ b/ai_edge_torch/generative/quantize/quant_recipe.py
@@ -39,7 +39,6 @@ class LayerQuantRecipe:
     mode: Type of quantization.
    algorithm: Algorithm for calculating quantization parameters.
     granularity: Granularity of quantization.
-    block_size: Size of the block for blockwise quantization.
""" activation_dtype: quant_attrs.Dtype @@ -47,7 +46,6 @@ class LayerQuantRecipe: mode: quant_attrs.Mode algorithm: quant_attrs.Algorithm granularity: quant_attrs.Granularity - block_size: int = 0 def __str__(self): base_str = ( @@ -56,7 +54,6 @@ def __str__(self): f'{self.mode.name}, ' f'{self.algorithm.name}, ' f'{self.granularity.name}, ' - f'{self.block_size}' ) return f'{base_str})' @@ -77,16 +74,6 @@ def verify(self): and self.algorithm == supported[3] and self.granularity == supported[4] ): - if self.block_size > 0: - if ( - self.block_size % 32 == 0 - and self.granularity == quant_attrs.Granularity.BLOCKWISE - ): - is_valid = True - break - else: - is_valid = False - break is_valid = True break diff --git a/ai_edge_torch/generative/quantize/quant_recipe_utils.py b/ai_edge_torch/generative/quantize/quant_recipe_utils.py index 19b27ef7..4b47e584 100644 --- a/ai_edge_torch/generative/quantize/quant_recipe_utils.py +++ b/ai_edge_torch/generative/quantize/quant_recipe_utils.py @@ -32,23 +32,29 @@ from ai_edge_torch.generative.quantize import quant_recipe -def create_layer_quant_int8_dynamic() -> quant_recipe.LayerQuantRecipe: +def create_layer_quant_dynamic( + weight_dtype: quant_attrs.Dtype = quant_attrs.Dtype.INT8, + granularity: quant_attrs.Granularity = quant_attrs.Granularity.CHANNELWISE, +) -> quant_recipe.LayerQuantRecipe: return quant_recipe.LayerQuantRecipe( activation_dtype=quant_attrs.Dtype.FP32, - weight_dtype=quant_attrs.Dtype.INT8, + weight_dtype=weight_dtype, mode=quant_attrs.Mode.DYNAMIC_RANGE, algorithm=quant_attrs.Algorithm.MIN_MAX, - granularity=quant_attrs.Granularity.CHANNELWISE, + granularity=granularity, ) -def create_layer_quant_int8_weight_only() -> quant_recipe.LayerQuantRecipe: +def create_layer_quant_weight_only( + weight_dtype: quant_attrs.Dtype = quant_attrs.Dtype.INT8, + granularity: quant_attrs.Granularity = quant_attrs.Granularity.CHANNELWISE, +) -> quant_recipe.LayerQuantRecipe: return quant_recipe.LayerQuantRecipe( activation_dtype=quant_attrs.Dtype.FP32, - weight_dtype=quant_attrs.Dtype.INT8, + weight_dtype=weight_dtype, mode=quant_attrs.Mode.WEIGHT_ONLY, algorithm=quant_attrs.Algorithm.MIN_MAX, - granularity=quant_attrs.Granularity.CHANNELWISE, + granularity=granularity, ) @@ -60,16 +66,3 @@ def create_layer_quant_fp16() -> quant_recipe.LayerQuantRecipe: algorithm=quant_attrs.Algorithm.FLOAT_CAST, granularity=quant_attrs.Granularity.NONE, ) - - -def create_layer_quant_int4_dynamic_block( - block_size: int, -) -> quant_recipe.LayerQuantRecipe: - return quant_recipe.LayerQuantRecipe( - activation_dtype=quant_attrs.Dtype.FP32, - weight_dtype=quant_attrs.Dtype.INT4, - mode=quant_attrs.Mode.DYNAMIC_RANGE, - algorithm=quant_attrs.Algorithm.MIN_MAX, - granularity=quant_attrs.Granularity.BLOCKWISE, - block_size=block_size, - ) diff --git a/ai_edge_torch/generative/quantize/quant_recipes.py b/ai_edge_torch/generative/quantize/quant_recipes.py index 8e2ecb70..e4d9614e 100644 --- a/ai_edge_torch/generative/quantize/quant_recipes.py +++ b/ai_edge_torch/generative/quantize/quant_recipes.py @@ -29,28 +29,37 @@ from typing import Optional from ai_edge_torch.generative.layers import model_config +from ai_edge_torch.generative.quantize import quant_attrs from ai_edge_torch.generative.quantize import quant_recipe from ai_edge_torch.generative.quantize import quant_recipe_utils from ai_edge_torch.quantize import quant_config -def full_int8_dynamic_recipe( +def full_dynamic_recipe( mcfg: Optional[model_config.ModelConfig] = None, + weight_dtype: quant_attrs.Dtype = 
quant_attrs.Dtype.INT8, + granularity: quant_attrs.Granularity = quant_attrs.Granularity.CHANNELWISE, ) -> quant_config.QuantConfig: return quant_config.QuantConfig( generative_recipe=quant_recipe.GenerativeQuantRecipe( - default=quant_recipe_utils.create_layer_quant_int8_dynamic(), + default=quant_recipe_utils.create_layer_quant_dynamic( + weight_dtype, granularity + ), _model_config=mcfg, ) ) -def full_int8_weight_only_recipe( +def full_weight_only_recipe( mcfg: Optional[model_config.ModelConfig] = None, + weight_dtype: quant_attrs.Dtype = quant_attrs.Dtype.INT8, + granularity: quant_attrs.Granularity = quant_attrs.Granularity.CHANNELWISE, ) -> quant_config.QuantConfig: return quant_config.QuantConfig( generative_recipe=quant_recipe.GenerativeQuantRecipe( - default=quant_recipe_utils.create_layer_quant_int8_weight_only(), + default=quant_recipe_utils.create_layer_quant_weight_only( + weight_dtype, granularity + ), _model_config=mcfg, ) ) @@ -65,17 +74,3 @@ def full_fp16_recipe( _model_config=mcfg, ) ) - - -def all_supported_int4_dynamic_block_recipe( - block_size: int, - mcfg: Optional[model_config.ModelConfig] = None, -) -> quant_config.QuantConfig: - return quant_config.QuantConfig( - generative_recipe=quant_recipe.GenerativeQuantRecipe( - default=quant_recipe_utils.create_layer_quant_int4_dynamic_block( - block_size - ), - _model_config=mcfg, - ) - ) diff --git a/ai_edge_torch/generative/quantize/supported_schemes.py b/ai_edge_torch/generative/quantize/supported_schemes.py index 2b8bbb7a..ac1ab8d2 100644 --- a/ai_edge_torch/generative/quantize/supported_schemes.py +++ b/ai_edge_torch/generative/quantize/supported_schemes.py @@ -29,5 +29,7 @@ def get_supported_layer_schemes(): (_t.FP32, _t.INT8, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.CHANNELWISE), (_t.FP32, _t.INT8, _m.WEIGHT_ONLY, _a.MIN_MAX, _g.CHANNELWISE), (_t.FP32, _t.FP16, _m.WEIGHT_ONLY, _a.FLOAT_CAST, _g.NONE), - (_t.FP32, _t.INT4, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.BLOCKWISE), + (_t.FP32, _t.INT4, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.BLOCKWISE_32), + (_t.FP32, _t.INT4, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.BLOCKWISE_64), + (_t.FP32, _t.INT4, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.BLOCKWISE_128), ] diff --git a/ai_edge_torch/generative/test/test_quantize.py b/ai_edge_torch/generative/test/test_quantize.py index c2539bdd..4bd8ca13 100644 --- a/ai_edge_torch/generative/test/test_quantize.py +++ b/ai_edge_torch/generative/test/test_quantize.py @@ -79,18 +79,18 @@ def test_verify_invalid_recipes( Dtype.INT4, Mode.DYNAMIC_RANGE, Algorithm.MIN_MAX, - Granularity.BLOCKWISE, - 32, + Granularity.BLOCKWISE_32, + ), + ( + Dtype.FP32, + Dtype.INT4, + Mode.DYNAMIC_RANGE, + Algorithm.MIN_MAX, + Granularity.BLOCKWISE_128, ), ]) def test_verify_valid_recipes( - self, - activation, - weight, - mode, - algo, - granularity, - block_size=None, + self, activation, weight, mode, algo, granularity ): quant_recipe.LayerQuantRecipe( activation, weight, mode, algo, granularity @@ -108,21 +108,21 @@ def setUp(self): def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig: return quant_config.QuantConfig( generative_recipe=quant_recipe.GenerativeQuantRecipe( - attention=quant_recipe_utils.create_layer_quant_int8_dynamic(), + attention=quant_recipe_utils.create_layer_quant_dynamic(), ) ) def _feedforward_int8_dynamic_recipe() -> quant_config.QuantConfig: return quant_config.QuantConfig( generative_recipe=quant_recipe.GenerativeQuantRecipe( - feedforward=quant_recipe_utils.create_layer_quant_int8_dynamic(), + feedforward=quant_recipe_utils.create_layer_quant_dynamic(), ) 
     )

   @parameterized.parameters([
       (quant_recipes.full_fp16_recipe()),
-      (quant_recipes.full_int8_dynamic_recipe()),
-      (quant_recipes.full_int8_weight_only_recipe()),
+      (quant_recipes.full_dynamic_recipe()),
+      (quant_recipes.full_weight_only_recipe()),
       (_attention_int8_dynamic_recipe()),
       (_feedforward_int8_dynamic_recipe()),
   ])
@@ -148,7 +148,7 @@ def test_quantize_convert_toy_weight_sharing(self):
     idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
     input_pos = torch.arange(0, 100, dtype=torch.int)

-    quant_config = quant_recipes.full_int8_dynamic_recipe()
+    quant_config = quant_recipes.full_dynamic_recipe()
     quantized_model = ai_edge_torch.convert(
         pytorch_model, (idx, input_pos), quant_config=quant_config
     )
@@ -164,7 +164,9 @@ def test_quantize_convert_toy_blockwise(self):
     pytorch_model = toy_model.ToySingleLayerModel(config)
     idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
     input_pos = torch.arange(0, 100, dtype=torch.int)
-    quant_config = quant_recipes.all_supported_int4_dynamic_block_recipe(32)
+    quant_config = quant_recipes.full_dynamic_recipe(
+        weight_dtype=Dtype.INT4, granularity=Granularity.BLOCKWISE_32
+    )
     quantized_model = ai_edge_torch.convert(
         pytorch_model, (idx, input_pos), quant_config=quant_config
     )
@@ -175,17 +177,6 @@ def test_quantize_convert_toy_blockwise(self):
         "Quantized model isn't smaller than F32 model.",
     )

-  def test_unsupported_block_size(self):
-    config = toy_model.get_model_config()
-    pytorch_model = toy_model.ToySingleLayerModel(config)
-    idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
-    input_pos = torch.arange(0, 100, dtype=torch.int)
-    self.assertRaises(
-        ValueError,
-        quant_recipes.all_supported_int4_dynamic_block_recipe,
-        36,
-    )
-
   def test_quantize_convert_compare_toy(self):
     self.skipTest("b/338288901")
     config = toy_model_with_kv_cache.get_model_config()
diff --git a/ai_edge_torch/generative/utilities/converter.py b/ai_edge_torch/generative/utilities/converter.py
index 14f34ce2..388b149c 100644
--- a/ai_edge_torch/generative/utilities/converter.py
+++ b/ai_edge_torch/generative/utilities/converter.py
@@ -25,6 +25,7 @@
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import lora as lora_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.quantize import quant_attrs
 from ai_edge_torch.generative.quantize import quant_recipes
 from ai_edge_torch.generative.utilities import export_config as export_config_lib
 from ai_edge_torch.generative.utilities import litertlm_builder
@@ -193,18 +194,22 @@ def get_quant_recipe_from_flag(
     case QuantizationName.NONE:
       return None
     case QuantizationName.DYNAMIC_INT8:
-      return quant_recipes.full_int8_dynamic_recipe(mcfg=model_config)
+      return quant_recipes.full_dynamic_recipe(mcfg=model_config)
     case QuantizationName.WEIGHT_ONLY_INT8:
-      return quant_recipes.full_int8_weight_only_recipe(mcfg=model_config)
+      return quant_recipes.full_weight_only_recipe(mcfg=model_config)
     case QuantizationName.FP16:
       return quant_recipes.full_fp16_recipe()
     case QuantizationName.DYNAMIC_INT4_BLOCK32:
-      return quant_recipes.all_supported_int4_dynamic_block_recipe(
-          32, mcfg=model_config
+      return quant_recipes.full_dynamic_recipe(
+          mcfg=model_config,
+          weight_dtype=quant_attrs.Dtype.INT4,
+          granularity=quant_attrs.Granularity.BLOCKWISE_32,
       )
     case QuantizationName.DYNAMIC_INT4_BLOCK128:
-      return quant_recipes.all_supported_int4_dynamic_block_recipe(
-          128, mcfg=model_config
+      return quant_recipes.full_dynamic_recipe(
+          mcfg=model_config,
+          weight_dtype=quant_attrs.Dtype.INT4,
+          granularity=quant_attrs.Granularity.BLOCKWISE_128,
       )
     case _:
       raise ValueError(f'Unsupported quantization flag: {quantize}')
diff --git a/ai_edge_torch/lowertools/translate_recipe.py b/ai_edge_torch/lowertools/translate_recipe.py
index 7a399bae..d67682f2 100644
--- a/ai_edge_torch/lowertools/translate_recipe.py
+++ b/ai_edge_torch/lowertools/translate_recipe.py
@@ -80,8 +80,12 @@ def _get_granularity(
     return _QuantGranularity.CHANNELWISE
   if granularity == quant_attrs.Granularity.NONE:
     return _QuantGranularity.TENSORWISE
-  if granularity == quant_attrs.Granularity.BLOCKWISE:
-    return _QuantGranularity.BLOCKWISE
+  if granularity == quant_attrs.Granularity.BLOCKWISE_32:
+    return _QuantGranularity.BLOCKWISE_32
+  if granularity == quant_attrs.Granularity.BLOCKWISE_64:
+    return _QuantGranularity.BLOCKWISE_64
+  if granularity == quant_attrs.Granularity.BLOCKWISE_128:
+    return _QuantGranularity.BLOCKWISE_128
   raise ValueError('Unimplemented granularity')

@@ -108,7 +112,6 @@ def _set_quant_config(
           symmetric=True,
           granularity=_get_granularity(layer_recipe.granularity),
           dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
-          block_size=layer_recipe.block_size,
       ),
       compute_precision=_get_compute_precision_from_mode(layer_recipe.mode),
       explicit_dequantize=_get_explicit_dequant_from_mode(
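Usage sketch (reviewer note, not part of the patch): with all_supported_int4_dynamic_block_recipe() and the block_size field removed, int4 blockwise quantization is requested through the generalized recipe helpers, and the block size now travels in the Granularity enum. A minimal example modeled on the updated example.py and test_quantize.py; the toy_model import path and the output path are illustrative assumptions:

import ai_edge_torch
import torch
from ai_edge_torch.generative.examples.test_models import toy_model  # path assumed, as used in test_quantize.py
from ai_edge_torch.generative.quantize import quant_attrs
from ai_edge_torch.generative.quantize import quant_recipes

# Int4 dynamic-range recipe with 32-element quantization blocks. The block
# size is encoded in the granularity (BLOCKWISE_32/64/128) rather than
# passed as a separate block_size argument.
quant_config = quant_recipes.full_dynamic_recipe(
    weight_dtype=quant_attrs.Dtype.INT4,
    granularity=quant_attrs.Granularity.BLOCKWISE_32,
)

# Convert a toy model with the recipe applied, mirroring
# test_quantize_convert_toy_blockwise above.
config = toy_model.get_model_config()
pytorch_model = toy_model.ToySingleLayerModel(config)
idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
input_pos = torch.arange(0, 100, dtype=torch.int)
edge_model = ai_edge_torch.convert(
    pytorch_model, (idx, input_pos), quant_config=quant_config
)
edge_model.export('/tmp/toy_int4_blockwise.tflite')  # illustrative output path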