@@ -138,9 +138,7 @@ def convert_stable_diffusion_to_tflite(
   if not os.path.exists(output_dir):
     pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
 
-  quant_config = (
-      quant_recipes.full_int8_weight_only_recipe() if quantize else None
-  )
+  quant_config = quant_recipes.full_weight_only_recipe() if quantize else None
 
   # TODO(yichunk): convert to multi signature tflite model.
   # CLIP text encoder
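Note: with the weight dtype and granularity now passed as arguments, the recipe helpers drop the int8 prefix from their names. A minimal migration sketch for call sites like the one above (int8 per-channel stays the default, so existing callers just rename):

```python
from ai_edge_torch.generative.quantize import quant_recipes

# Before: quant_config = quant_recipes.full_int8_weight_only_recipe()
# After: same behavior under the new name; int8 weight-only, per-channel
# granularity by default.
quant_config = quant_recipes.full_weight_only_recipe()
```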
ai_edge_torch/generative/quantize/example.py (2 changes: 1 addition & 1 deletion)

@@ -33,7 +33,7 @@ def main():
   kv = kv_utils.KVCache.from_model_config(config)
 
   # Create a quantization recipe to be applied to the model
-  quant_config = quant_recipes.full_int8_dynamic_recipe()
+  quant_config = quant_recipes.full_dynamic_recipe()
   print(quant_config)
 
   # Convert with quantization
ai_edge_torch/generative/quantize/quant_attrs.py (9 changes: 8 additions & 1 deletion)

@@ -63,8 +63,15 @@ class Granularity(enum.Enum):
     NONE: Granularity not applicable to this quantization scheme.
     CHANNELWISE: Or per-channel quantization. Each channel of relevant tensors
       is quantized independently of one another.
+    BLOCKWISE_32: Blockwise quantization with block size 32.
+    BLOCKWISE_64: Blockwise quantization with block size 64.
+    BLOCKWISE_128: Blockwise quantization with block size 128.
+    BLOCKWISE_256: Blockwise quantization with block size 256.
   """
 
   NONE = enum.auto()
   CHANNELWISE = enum.auto()
-  BLOCKWISE = enum.auto()
+  BLOCKWISE_32 = enum.auto()
+  BLOCKWISE_64 = enum.auto()
+  BLOCKWISE_128 = enum.auto()
+  BLOCKWISE_256 = enum.auto()
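Note: the block size is now encoded in the enum member name rather than carried in a separate block_size field. A hypothetical helper (not part of this PR) showing how the numeric block size could be recovered from a member if a downstream consumer needed it:

```python
import enum


class Granularity(enum.Enum):
  # Mirrors the members added in this diff.
  NONE = enum.auto()
  CHANNELWISE = enum.auto()
  BLOCKWISE_32 = enum.auto()
  BLOCKWISE_64 = enum.auto()
  BLOCKWISE_128 = enum.auto()
  BLOCKWISE_256 = enum.auto()


def block_size_of(granularity: Granularity) -> int | None:
  """Returns the block size encoded in a BLOCKWISE_* name, else None."""
  _, sep, suffix = granularity.name.partition('BLOCKWISE_')
  return int(suffix) if sep else None


assert block_size_of(Granularity.BLOCKWISE_64) == 64
assert block_size_of(Granularity.CHANNELWISE) is None
```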
ai_edge_torch/generative/quantize/quant_recipe.py (13 changes: 0 additions & 13 deletions)

@@ -39,15 +39,13 @@ class LayerQuantRecipe:
     mode: Type of quantization.
     algorithm: Algorithm for calculating quantization parameters.
     granularity: Granularity of quantization.
-    block_size: Size of the block for blockwise quantization.
   """
 
   activation_dtype: quant_attrs.Dtype
   weight_dtype: quant_attrs.Dtype
   mode: quant_attrs.Mode
   algorithm: quant_attrs.Algorithm
   granularity: quant_attrs.Granularity
-  block_size: int = 0
 
   def __str__(self):
     base_str = (
@@ -56,7 +54,6 @@ def __str__(self):
         f'{self.mode.name}, '
         f'{self.algorithm.name}, '
         f'{self.granularity.name}, '
-        f'{self.block_size}'
     )
     return f'{base_str})'
 
@@ -77,16 +74,6 @@ def verify(self):
           and self.algorithm == supported[3]
          and self.granularity == supported[4]
       ):
-        if self.block_size > 0:
-          if (
-              self.block_size % 32 == 0
-              and self.granularity == quant_attrs.Granularity.BLOCKWISE
-          ):
-            is_valid = True
-            break
-          else:
-            is_valid = False
-            break
         is_valid = True
         break

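Note: verify() is now a plain membership test against get_supported_layer_schemes(); the block-size divisibility special case is gone. A sketch of a recipe that passes under the new scheme table (assuming verify() continues to reject unsupported combinations, as the tests further down exercise):

```python
from ai_edge_torch.generative.quantize import quant_attrs, quant_recipe

# INT4 dynamic-range is only listed with BLOCKWISE_* granularities, so this
# recipe verifies cleanly; Granularity.CHANNELWISE here would be rejected.
recipe = quant_recipe.LayerQuantRecipe(
    activation_dtype=quant_attrs.Dtype.FP32,
    weight_dtype=quant_attrs.Dtype.INT4,
    mode=quant_attrs.Mode.DYNAMIC_RANGE,
    algorithm=quant_attrs.Algorithm.MIN_MAX,
    granularity=quant_attrs.Granularity.BLOCKWISE_64,
)
recipe.verify()
```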
ai_edge_torch/generative/quantize/quant_recipe_utils.py (31 changes: 12 additions & 19 deletions)

@@ -32,23 +32,29 @@
 from ai_edge_torch.generative.quantize import quant_recipe
 
 
-def create_layer_quant_int8_dynamic() -> quant_recipe.LayerQuantRecipe:
+def create_layer_quant_dynamic(
+    weight_dtype: quant_attrs.Dtype = quant_attrs.Dtype.INT8,
+    granularity: quant_attrs.Granularity = quant_attrs.Granularity.CHANNELWISE,
+) -> quant_recipe.LayerQuantRecipe:
   return quant_recipe.LayerQuantRecipe(
       activation_dtype=quant_attrs.Dtype.FP32,
-      weight_dtype=quant_attrs.Dtype.INT8,
+      weight_dtype=weight_dtype,
       mode=quant_attrs.Mode.DYNAMIC_RANGE,
       algorithm=quant_attrs.Algorithm.MIN_MAX,
-      granularity=quant_attrs.Granularity.CHANNELWISE,
+      granularity=granularity,
   )
 
 
-def create_layer_quant_int8_weight_only() -> quant_recipe.LayerQuantRecipe:
+def create_layer_quant_weight_only(
+    weight_dtype: quant_attrs.Dtype = quant_attrs.Dtype.INT8,
+    granularity: quant_attrs.Granularity = quant_attrs.Granularity.CHANNELWISE,
+) -> quant_recipe.LayerQuantRecipe:
   return quant_recipe.LayerQuantRecipe(
       activation_dtype=quant_attrs.Dtype.FP32,
-      weight_dtype=quant_attrs.Dtype.INT8,
+      weight_dtype=weight_dtype,
       mode=quant_attrs.Mode.WEIGHT_ONLY,
       algorithm=quant_attrs.Algorithm.MIN_MAX,
-      granularity=quant_attrs.Granularity.CHANNELWISE,
+      granularity=granularity,
   )
 
 
@@ -60,16 +66,3 @@ def create_layer_quant_fp16() -> quant_recipe.LayerQuantRecipe:
       algorithm=quant_attrs.Algorithm.FLOAT_CAST,
       granularity=quant_attrs.Granularity.NONE,
   )
-
-
-def create_layer_quant_int4_dynamic_block(
-    block_size: int,
-) -> quant_recipe.LayerQuantRecipe:
-  return quant_recipe.LayerQuantRecipe(
-      activation_dtype=quant_attrs.Dtype.FP32,
-      weight_dtype=quant_attrs.Dtype.INT4,
-      mode=quant_attrs.Mode.DYNAMIC_RANGE,
-      algorithm=quant_attrs.Algorithm.MIN_MAX,
-      granularity=quant_attrs.Granularity.BLOCKWISE,
-      block_size=block_size,
-  )
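Note: the removed create_layer_quant_int4_dynamic_block() is subsumed by the parameterized dynamic helper. A minimal sketch of the equivalent call:

```python
from ai_edge_torch.generative.quantize import quant_attrs, quant_recipe_utils

# Old: create_layer_quant_int4_dynamic_block(block_size=128)
layer_recipe = quant_recipe_utils.create_layer_quant_dynamic(
    weight_dtype=quant_attrs.Dtype.INT4,
    granularity=quant_attrs.Granularity.BLOCKWISE_128,
)

# Defaults are unchanged: int8 dynamic-range with per-channel granularity.
default_recipe = quant_recipe_utils.create_layer_quant_dynamic()
```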
ai_edge_torch/generative/quantize/quant_recipes.py (31 changes: 13 additions & 18 deletions)

@@ -29,28 +29,37 @@
 from typing import Optional
 from ai_edge_torch.generative.layers import model_config
+from ai_edge_torch.generative.quantize import quant_attrs
 from ai_edge_torch.generative.quantize import quant_recipe
 from ai_edge_torch.generative.quantize import quant_recipe_utils
 from ai_edge_torch.quantize import quant_config
 
 
-def full_int8_dynamic_recipe(
+def full_dynamic_recipe(
     mcfg: Optional[model_config.ModelConfig] = None,
+    weight_dtype: quant_attrs.Dtype = quant_attrs.Dtype.INT8,
+    granularity: quant_attrs.Granularity = quant_attrs.Granularity.CHANNELWISE,
 ) -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
       generative_recipe=quant_recipe.GenerativeQuantRecipe(
-          default=quant_recipe_utils.create_layer_quant_int8_dynamic(),
+          default=quant_recipe_utils.create_layer_quant_dynamic(
+              weight_dtype, granularity
+          ),
          _model_config=mcfg,
       )
   )
 
 
-def full_int8_weight_only_recipe(
+def full_weight_only_recipe(
     mcfg: Optional[model_config.ModelConfig] = None,
+    weight_dtype: quant_attrs.Dtype = quant_attrs.Dtype.INT8,
+    granularity: quant_attrs.Granularity = quant_attrs.Granularity.CHANNELWISE,
 ) -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
       generative_recipe=quant_recipe.GenerativeQuantRecipe(
-          default=quant_recipe_utils.create_layer_quant_int8_weight_only(),
+          default=quant_recipe_utils.create_layer_quant_weight_only(
+              weight_dtype, granularity
+          ),
          _model_config=mcfg,
       )
   )
@@ -65,17 +74,3 @@ def full_fp16_recipe(
           _model_config=mcfg,
       )
   )
-
-
-def all_supported_int4_dynamic_block_recipe(
-    block_size: int,
-    mcfg: Optional[model_config.ModelConfig] = None,
-) -> quant_config.QuantConfig:
-  return quant_config.QuantConfig(
-      generative_recipe=quant_recipe.GenerativeQuantRecipe(
-          default=quant_recipe_utils.create_layer_quant_int4_dynamic_block(
-              block_size
-          ),
-          _model_config=mcfg,
-      )
-  )
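Note: likewise, callers of the removed all_supported_int4_dynamic_block_recipe(block_size) migrate to full_dynamic_recipe with an explicit dtype and granularity; a sketch:

```python
from ai_edge_torch.generative.quantize import quant_attrs, quant_recipes

# Old: quant_recipes.all_supported_int4_dynamic_block_recipe(32)
quant_config = quant_recipes.full_dynamic_recipe(
    weight_dtype=quant_attrs.Dtype.INT4,
    granularity=quant_attrs.Granularity.BLOCKWISE_32,
)
```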
ai_edge_torch/generative/quantize/supported_schemes.py (4 changes: 3 additions & 1 deletion)

@@ -29,5 +29,7 @@ def get_supported_layer_schemes():
       (_t.FP32, _t.INT8, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.CHANNELWISE),
       (_t.FP32, _t.INT8, _m.WEIGHT_ONLY, _a.MIN_MAX, _g.CHANNELWISE),
       (_t.FP32, _t.FP16, _m.WEIGHT_ONLY, _a.FLOAT_CAST, _g.NONE),
-      (_t.FP32, _t.INT4, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.BLOCKWISE),
+      (_t.FP32, _t.INT4, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.BLOCKWISE_32),
+      (_t.FP32, _t.INT4, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.BLOCKWISE_64),
+      (_t.FP32, _t.INT4, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.BLOCKWISE_128),
   ]
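Note: Granularity.BLOCKWISE_256 is added to the enum in quant_attrs.py but not to this table, so as of this change only block sizes 32, 64, and 128 are accepted by LayerQuantRecipe.verify() for INT4 dynamic-range quantization.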
ai_edge_torch/generative/test/test_quantize.py (43 changes: 17 additions & 26 deletions)

@@ -79,18 +79,18 @@ def test_verify_invalid_recipes(
           Dtype.INT4,
           Mode.DYNAMIC_RANGE,
           Algorithm.MIN_MAX,
-          Granularity.BLOCKWISE,
-          32,
+          Granularity.BLOCKWISE_32,
       ),
+      (
+          Dtype.FP32,
+          Dtype.INT4,
+          Mode.DYNAMIC_RANGE,
+          Algorithm.MIN_MAX,
+          Granularity.BLOCKWISE_128,
+      ),
   ])
   def test_verify_valid_recipes(
-      self,
-      activation,
-      weight,
-      mode,
-      algo,
-      granularity,
-      block_size=None,
+      self, activation, weight, mode, algo, granularity
   ):
     quant_recipe.LayerQuantRecipe(
         activation, weight, mode, algo, granularity
@@ -108,21 +108,21 @@ def setUp(self):
   def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
     return quant_config.QuantConfig(
         generative_recipe=quant_recipe.GenerativeQuantRecipe(
-            attention=quant_recipe_utils.create_layer_quant_int8_dynamic(),
+            attention=quant_recipe_utils.create_layer_quant_dynamic(),
        )
    )
 
   def _feedforward_int8_dynamic_recipe() -> quant_config.QuantConfig:
     return quant_config.QuantConfig(
         generative_recipe=quant_recipe.GenerativeQuantRecipe(
-            feedforward=quant_recipe_utils.create_layer_quant_int8_dynamic(),
+            feedforward=quant_recipe_utils.create_layer_quant_dynamic(),
         )
     )
 
   @parameterized.parameters([
       (quant_recipes.full_fp16_recipe()),
-      (quant_recipes.full_int8_dynamic_recipe()),
-      (quant_recipes.full_int8_weight_only_recipe()),
+      (quant_recipes.full_dynamic_recipe()),
+      (quant_recipes.full_weight_only_recipe()),
       (_attention_int8_dynamic_recipe()),
       (_feedforward_int8_dynamic_recipe()),
   ])
@@ -148,7 +148,7 @@ def test_quantize_convert_toy_weight_sharing(self):
     idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
     input_pos = torch.arange(0, 100, dtype=torch.int)
 
-    quant_config = quant_recipes.full_int8_dynamic_recipe()
+    quant_config = quant_recipes.full_dynamic_recipe()
     quantized_model = ai_edge_torch.convert(
         pytorch_model, (idx, input_pos), quant_config=quant_config
     )
@@ -164,7 +164,9 @@ def test_quantize_convert_toy_blockwise(self):
     pytorch_model = toy_model.ToySingleLayerModel(config)
     idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
     input_pos = torch.arange(0, 100, dtype=torch.int)
-    quant_config = quant_recipes.all_supported_int4_dynamic_block_recipe(32)
+    quant_config = quant_recipes.full_dynamic_recipe(
+        weight_dtype=Dtype.INT4, granularity=Granularity.BLOCKWISE_32
+    )
     quantized_model = ai_edge_torch.convert(
         pytorch_model, (idx, input_pos), quant_config=quant_config
     )
@@ -175,17 +177,6 @@
         "Quantized model isn't smaller than F32 model.",
     )
 
-  def test_unsupported_block_size(self):
-    config = toy_model.get_model_config()
-    pytorch_model = toy_model.ToySingleLayerModel(config)
-    idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
-    input_pos = torch.arange(0, 100, dtype=torch.int)
-    self.assertRaises(
-        ValueError,
-        quant_recipes.all_supported_int4_dynamic_block_recipe,
-        36,
-    )
-
   def test_quantize_convert_compare_toy(self):
     self.skipTest("b/338288901")
     config = toy_model_with_kv_cache.get_model_config()
ai_edge_torch/generative/utilities/converter.py (15 changes: 10 additions & 5 deletions)

@@ -25,6 +25,7 @@
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import lora as lora_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.quantize import quant_attrs
 from ai_edge_torch.generative.quantize import quant_recipes
 from ai_edge_torch.generative.utilities import export_config as export_config_lib
 from ai_edge_torch.generative.utilities import litertlm_builder
@@ -193,18 +194,22 @@ def get_quant_recipe_from_flag(
     case QuantizationName.NONE:
       return None
     case QuantizationName.DYNAMIC_INT8:
-      return quant_recipes.full_int8_dynamic_recipe(mcfg=model_config)
+      return quant_recipes.full_dynamic_recipe(mcfg=model_config)
     case QuantizationName.WEIGHT_ONLY_INT8:
-      return quant_recipes.full_int8_weight_only_recipe(mcfg=model_config)
+      return quant_recipes.full_weight_only_recipe(mcfg=model_config)
     case QuantizationName.FP16:
       return quant_recipes.full_fp16_recipe()
     case QuantizationName.DYNAMIC_INT4_BLOCK32:
-      return quant_recipes.all_supported_int4_dynamic_block_recipe(
-          32, mcfg=model_config
+      return quant_recipes.full_dynamic_recipe(
+          mcfg=model_config,
+          weight_dtype=quant_attrs.Dtype.INT4,
+          granularity=quant_attrs.Granularity.BLOCKWISE_32,
       )
     case QuantizationName.DYNAMIC_INT4_BLOCK128:
-      return quant_recipes.all_supported_int4_dynamic_block_recipe(
-          128, mcfg=model_config
+      return quant_recipes.full_dynamic_recipe(
+          mcfg=model_config,
+          weight_dtype=quant_attrs.Dtype.INT4,
+          granularity=quant_attrs.Granularity.BLOCKWISE_128,
       )
     case _:
       raise ValueError(f'Unsupported quantization flag: {quantize}')
ai_edge_torch/lowertools/translate_recipe.py (9 changes: 6 additions & 3 deletions)

@@ -80,8 +80,12 @@ def _get_granularity(
     return _QuantGranularity.CHANNELWISE
   if granularity == quant_attrs.Granularity.NONE:
     return _QuantGranularity.TENSORWISE
-  if granularity == quant_attrs.Granularity.BLOCKWISE:
-    return _QuantGranularity.BLOCKWISE
+  if granularity == quant_attrs.Granularity.BLOCKWISE_32:
+    return _QuantGranularity.BLOCKWISE_32
+  if granularity == quant_attrs.Granularity.BLOCKWISE_64:
+    return _QuantGranularity.BLOCKWISE_64
+  if granularity == quant_attrs.Granularity.BLOCKWISE_128:
+    return _QuantGranularity.BLOCKWISE_128
   raise ValueError('Unimplemented granularity')
 
 
@@ -108,7 +112,6 @@ def _set_quant_config(
           symmetric=True,
           granularity=_get_granularity(layer_recipe.granularity),
           dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
-          block_size=layer_recipe.block_size,
       ),
       compute_precision=_get_compute_precision_from_mode(layer_recipe.mode),
       explicit_dequantize=_get_explicit_dequant_from_mode(
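Note: since the recipe-side Granularity members now map one-to-one onto _QuantGranularity names, the if-chain in _get_granularity could equally be a table lookup. A hypothetical alternative (not part of this PR; assumes the quant_attrs import and _QuantGranularity alias already present in this module):

```python
_GRANULARITY_MAP = {
    quant_attrs.Granularity.CHANNELWISE: _QuantGranularity.CHANNELWISE,
    quant_attrs.Granularity.NONE: _QuantGranularity.TENSORWISE,
    quant_attrs.Granularity.BLOCKWISE_32: _QuantGranularity.BLOCKWISE_32,
    quant_attrs.Granularity.BLOCKWISE_64: _QuantGranularity.BLOCKWISE_64,
    quant_attrs.Granularity.BLOCKWISE_128: _QuantGranularity.BLOCKWISE_128,
}


def _get_granularity(granularity):
  # Same behavior as the if-chain above, including the error on unknown values.
  try:
    return _GRANULARITY_MAP[granularity]
  except KeyError:
    raise ValueError('Unimplemented granularity')
```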