Skip to content

Commit daa06ff

Browse files
zichuan-wei authored and copybara-github committed
fix quantization_config error when running convert with new quantization flag
PiperOrigin-RevId: 757886360
1 parent 8e53f94 commit daa06ff

File tree

2 files changed

+29
-11
lines changed

2 files changed

+29
-11
lines changed

ai_edge_torch/generative/quantize/quant_recipes.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,43 +27,56 @@
2727
)
2828
"""
2929

30+
from typing import Optional
31+
from ai_edge_torch.generative.layers import model_config
3032
from ai_edge_torch.generative.quantize import quant_recipe
3133
from ai_edge_torch.generative.quantize import quant_recipe_utils
3234
from ai_edge_torch.quantize import quant_config
3335

3436

35-
def full_int8_dynamic_recipe() -> quant_config.QuantConfig:
37+
def full_int8_dynamic_recipe(
38+
mcfg: Optional[model_config.ModelConfig] = None,
39+
) -> quant_config.QuantConfig:
3640
return quant_config.QuantConfig(
3741
generative_recipe=quant_recipe.GenerativeQuantRecipe(
3842
default=quant_recipe_utils.create_layer_quant_int8_dynamic(),
43+
_model_config=mcfg,
3944
)
4045
)
4146

4247

43-
def full_int8_weight_only_recipe() -> quant_config.QuantConfig:
48+
def full_int8_weight_only_recipe(
49+
mcfg: Optional[model_config.ModelConfig] = None,
50+
) -> quant_config.QuantConfig:
4451
return quant_config.QuantConfig(
4552
generative_recipe=quant_recipe.GenerativeQuantRecipe(
4653
default=quant_recipe_utils.create_layer_quant_int8_weight_only(),
54+
_model_config=mcfg,
4755
)
4856
)
4957

5058

51-
def full_fp16_recipe() -> quant_config.QuantConfig:
59+
def full_fp16_recipe(
60+
mcfg: Optional[model_config.ModelConfig] = None,
61+
) -> quant_config.QuantConfig:
5262
return quant_config.QuantConfig(
5363
generative_recipe=quant_recipe.GenerativeQuantRecipe(
54-
default=quant_recipe_utils.create_layer_quant_fp16()
64+
default=quant_recipe_utils.create_layer_quant_fp16(),
65+
_model_config=mcfg,
5566
)
5667
)
5768

5869

5970
def all_supported_int4_dynamic_block_recipe(
6071
block_size: int,
72+
mcfg: Optional[model_config.ModelConfig] = None,
6173
) -> quant_config.QuantConfig:
6274
return quant_config.QuantConfig(
6375
generative_recipe=quant_recipe.GenerativeQuantRecipe(
6476
default=quant_recipe_utils.create_layer_quant_int4_dynamic_block(
6577
block_size
6678
),
6779
embedding=quant_recipe_utils.create_layer_quant_int8_dynamic(),
80+
_model_config=mcfg,
6881
)
6982
)

ai_edge_torch/generative/utilities/converter.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import ai_edge_torch.generative.layers.model_config as cfg
2727
from ai_edge_torch.generative.quantize import quant_recipes
2828
from ai_edge_torch.generative.utilities import export_config
29+
from ai_edge_torch.quantize import quant_config as qcfg
2930
import torch
3031

3132
ExportConfig = export_config.ExportConfig
@@ -123,7 +124,8 @@ def define_conversion_flags(
123124

124125
def get_quant_recipe_from_flag(
125126
quantize: str,
126-
) -> Optional[quant_recipes.QuantizationRecipe]:
127+
model_config: cfg.ModelConfig,
128+
) -> Optional[qcfg.QuantConfig]:
127129
"""Processes the quantization flag and returns the corresponding recipe.
128130
129131
Args:
@@ -139,15 +141,19 @@ def get_quant_recipe_from_flag(
139141
case QuantizationName.NONE:
140142
return None
141143
case QuantizationName.DYNAMIC_INT8:
142-
return quant_recipes.full_int8_dynamic_recipe()
144+
return quant_recipes.full_int8_dynamic_recipe(mcfg=model_config)
143145
case QuantizationName.WEIGHT_ONLY_INT8:
144-
return quant_recipes.full_int8_weight_only_recipe()
146+
return quant_recipes.full_int8_weight_only_recipe(mcfg=model_config)
145147
case QuantizationName.FP16:
146148
return quant_recipes.full_fp16_recipe()
147149
case QuantizationName.DYNAMIC_INT4_BLOCK32:
148-
return quant_recipes.full_int4_dynamic_block_recipe(32)
150+
return quant_recipes.all_supported_int4_dynamic_block_recipe(
151+
32, mcfg=model_config
152+
)
149153
case QuantizationName.DYNAMIC_INT4_BLOCK128:
150-
return quant_recipes.full_int4_dynamic_block_recipe(128)
154+
return quant_recipes.all_supported_int4_dynamic_block_recipe(
155+
128, mcfg=model_config
156+
)
151157
case _:
152158
raise ValueError(f'Unsupported quantization flag: {quantize}')
153159

@@ -351,8 +357,7 @@ def _export_helper(
351357
kv_layout=export_config.kvcache_layout,
352358
)
353359

354-
quant_config = get_quant_recipe_from_flag(quantize)
355-
quant_config._model_config = config
360+
quant_config = get_quant_recipe_from_flag(quantize, config)
356361

357362
# For export, we create a module that captures any non-exportable,
358363
# arguments, e.g. the generation config object.

0 commit comments

Comments (0)