diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py
index 0bbd1256e8..46c09c4f0a 100644
--- a/torchao/prototype/awq/example.py
+++ b/torchao/prototype/awq/example.py
@@ -17,6 +17,7 @@ from torchao.quantization import (
     quantize_,
 )
+from torchao.quantization.quant_api import Int4WeightOnlyConfig
 
 
 # adapted from: https://github.com/mit-han-lab/llm-awq/blob/main/awq/entry.py#L255
@@ -223,18 +224,7 @@ def quantize_and_eval(
     if quant.startswith("awq-int4wo"):
         group_size = int(quant.split("-")[2])
         print(f"running {quant} quantization with group size {group_size}")
-        # TODO: this is temporary, we'll be using Int4WeightOnlyConfig soon
-        from torchao.quantization import FbgemmConfig
-
-        # use_hqq = True
-        # base_config = Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq)
-        base_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, group_size],
-            preshuffle=False,
-        )
+        base_config = Int4WeightOnlyConfig(group_size=group_size, use_hqq=True)
         print(f"running {quant} prepare and calibrate")
         t0 = time.time()
         quant_config = AWQConfig(base_config, step="prepare")
@@ -267,17 +257,9 @@ def quantize_and_eval(
     elif quant.startswith("int4wo"):
         group_size = int(quant.split("-")[1])
         print(f"running {quant} quantization with group size {group_size}")
-        # TODO: enable after refactor: https://github.com/pytorch/ao/pull/2474
-        # use_hqq = "hqq" in quant
-        # base_config = Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq)
-        int4_weight_only_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, group_size],
-            preshuffle=False,
-        )
-        quantize_(model, int4_weight_only_config)
+        use_hqq = "hqq" in quant
+        base_config = Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq)
+        quantize_(model, base_config)
 
         if model_save_path is not None:
             print(f"Saving model to {model_save_path}")
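
For context, the sketch below shows how the two quantization branches touched by this patch are driven after the change: the plain int4wo-<group_size> path quantizes directly with Int4WeightOnlyConfig, while the awq-int4wo-<group_size> path wraps the same base config in AWQConfig for a prepare/calibrate/convert flow. The toy model, tensor shapes, calibration loop, and the "convert" step string are illustrative assumptions; only Int4WeightOnlyConfig(group_size=..., use_hqq=...), AWQConfig(base_config, step="prepare"), and quantize_ appear in the diff itself.

# Minimal usage sketch (not part of the patch); assumes a CUDA device and bf16 weights,
# as in the example script this diff modifies.
import torch

from torchao.prototype.awq import AWQConfig  # assumed import path for the prototype
from torchao.quantization import quantize_
from torchao.quantization.quant_api import Int4WeightOnlyConfig

group_size = 128

# int4wo-<group_size> branch: plain int4 weight-only quantization of every Linear.
model = torch.nn.Sequential(torch.nn.Linear(128, 128)).to(torch.bfloat16).cuda()
base_config = Int4WeightOnlyConfig(group_size=group_size, use_hqq=False)
quantize_(model, base_config)

# awq-int4wo-<group_size> branch: prepare (insert observers), run calibration
# batches, then convert. The "convert" step string is an assumption here.
awq_model = torch.nn.Sequential(torch.nn.Linear(128, 128)).to(torch.bfloat16).cuda()
awq_base = Int4WeightOnlyConfig(group_size=group_size, use_hqq=True)
quantize_(awq_model, AWQConfig(awq_base, step="prepare"))
for _ in range(4):  # a few representative batches stand in for the real calibration set
    awq_model(torch.randn(2, 128, dtype=torch.bfloat16, device="cuda"))
quantize_(awq_model, AWQConfig(awq_base, step="convert"))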