2626import  ai_edge_torch .generative .layers .model_config  as  cfg 
2727from  ai_edge_torch .generative .quantize  import  quant_recipes 
2828from  ai_edge_torch .generative .utilities  import  export_config 
29+ from  ai_edge_torch .quantize  import  quant_config  as  qcfg 
2930import  torch 
3031
3132ExportConfig  =  export_config .ExportConfig 
@@ -123,7 +124,8 @@ def define_conversion_flags(
123124
124125def  get_quant_recipe_from_flag (
125126    quantize : str ,
126- ) ->  Optional [quant_recipes .QuantizationRecipe ]:
127+     model_config : cfg .ModelConfig ,
128+ ) ->  Optional [qcfg .QuantConfig ]:
127129  """Processes the quantization flag and returns the corresponding recipe. 
128130
129131  Args: 
@@ -139,15 +141,19 @@ def get_quant_recipe_from_flag(
139141    case  QuantizationName .NONE :
140142      return  None 
141143    case  QuantizationName .DYNAMIC_INT8 :
142-       return  quant_recipes .full_int8_dynamic_recipe ()
144+       return  quant_recipes .full_int8_dynamic_recipe (mcfg = model_config )
143145    case  QuantizationName .WEIGHT_ONLY_INT8 :
144-       return  quant_recipes .full_int8_weight_only_recipe ()
146+       return  quant_recipes .full_int8_weight_only_recipe (mcfg = model_config )
145147    case  QuantizationName .FP16 :
146148      return  quant_recipes .full_fp16_recipe ()
147149    case  QuantizationName .DYNAMIC_INT4_BLOCK32 :
148-       return  quant_recipes .full_int4_dynamic_block_recipe (32 )
150+       return  quant_recipes .all_supported_int4_dynamic_block_recipe (
151+           32 , mcfg = model_config 
152+       )
149153    case  QuantizationName .DYNAMIC_INT4_BLOCK128 :
150-       return  quant_recipes .full_int4_dynamic_block_recipe (128 )
154+       return  quant_recipes .all_supported_int4_dynamic_block_recipe (
155+           128 , mcfg = model_config 
156+       )
151157    case  _:
152158      raise  ValueError (f'Unsupported quantization flag: { quantize }  )
153159
@@ -351,8 +357,7 @@ def _export_helper(
351357      kv_layout = export_config .kvcache_layout ,
352358  )
353359
354-   quant_config  =  get_quant_recipe_from_flag (quantize )
355-   quant_config ._model_config  =  config 
360+   quant_config  =  get_quant_recipe_from_flag (quantize , config )
356361
357362  # For export, we create a module that captures any non-exportable, 
358363  # arguments, e.g. the generation config object. 
0 commit comments