diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py
index 5538fa513d..42e266205e 100644
--- a/test/prototype/test_awq.py
+++ b/test/prototype/test_awq.py
@@ -5,14 +5,15 @@
 # LICENSE file in the root directory of this source tree.
 import copy
 import tempfile
-import unittest
 
 import torch
+from parameterized import parameterized
 from torch.testing._internal.common_utils import (
     TestCase,
     run_tests,
 )
 
+from torchao.dtypes import Int4CPULayout
 from torchao.prototype.awq import AWQConfig, AWQStep
 from torchao.quantization import FbgemmConfig, Int4WeightOnlyConfig, quantize_
 from torchao.utils import (
@@ -45,15 +46,15 @@ def forward(self, x):
         return x
 
 
-@unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
-@unittest.skipIf(
-    not _is_fbgemm_genai_gpu_available(),
-    reason="need to install fbgemm_gpu_genai package",
-)
-@unittest.skipIf(
-    not TORCH_VERSION_AT_LEAST_2_6,
-    reason="torch.int4 needs torch 2.6+, can remove after we are not using FbgemmConfig",
-)
+devices = ["cpu"]
+if (
+    torch.cuda.is_available()
+    and _is_fbgemm_genai_gpu_available()
+    and TORCH_VERSION_AT_LEAST_2_6
+):
+    devices.append("cuda")
+
+
 class TestAWQ(TestCase):
     def test_awq_config(self):
         base_config = Int4WeightOnlyConfig()
@@ -68,8 +69,8 @@ def test_awq_config(self):
         with self.assertRaisesRegex(ValueError, "is not one of"):
             AWQConfig(base_config, step="not_supported")
 
-    def test_awq_functionality(self):
-        device = "cuda"
+    @parameterized.expand([(device,) for device in devices])
+    def test_awq_functionality(self, device):
         dataset_size = 100
         l1, l2, l3 = 512, 256, 128
         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
@@ -80,13 +81,21 @@ def test_awq_functionality(self):
         m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
 
         # baseline quantization
-        base_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, group_size],
-            preshuffle=False,
-        )
+        if device == "cuda":
+            base_config = FbgemmConfig(
+                input_dtype=torch.bfloat16,
+                weight_dtype=torch.int4,
+                output_dtype=torch.bfloat16,
+                block_size=[1, group_size],
+                preshuffle=False,
+            )
+        elif device == "cpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, layout=Int4CPULayout(), set_inductor_config=False
+            )
+            torch.manual_seed(1234)
+        else:
+            assert False, "Unsupported device: {}".format(device)
         m_baseline = copy.deepcopy(m)
         quantize_(m_baseline, base_config)
 
@@ -117,8 +126,8 @@ def test_awq_functionality(self):
         loss_base = (ref_out - baseline_out).pow(2).mean().item()
         assert loss_awq < loss_base
 
-    def test_awq_loading(self):
-        device = "cuda"
+    @parameterized.expand([(device,) for device in devices])
+    def test_awq_loading(self, device):
         dataset_size = 100
         l1, l2, l3 = 512, 256, 128
         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
@@ -136,13 +145,20 @@ def test_awq_loading(self, device):
         calibration_data = dataset[:n_calibration_examples]
 
         # calibrate
-        base_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, group_size],
-            preshuffle=False,
-        )
+        if device == "cuda":
+            base_config = FbgemmConfig(
+                input_dtype=torch.bfloat16,
+                weight_dtype=torch.int4,
+                output_dtype=torch.bfloat16,
+                block_size=[1, group_size],
+                preshuffle=False,
+            )
+        elif device == "cpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, layout=Int4CPULayout(), set_inductor_config=False
+            )
+        else:
+            assert False, "Unsupported device: {}".format(device)
"Unsupported device: {}".format(device) quant_config = AWQConfig(base_config, step=AWQStep.PREPARE) quantize_(m, quant_config) @@ -171,14 +187,14 @@ def test_awq_loading(self): assert awq_save_load_out is not None assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2) - def test_awq_loading_vllm(self): + @parameterized.expand([(device,) for device in devices]) + def test_awq_loading_vllm(self, device): """Simulate weight loading in vllm: * prepare model weight to the same format (awq weight) * use weight.copy_(state_dict["weight"]) to copy over the quantized weights from checkpoint There is also a slicing op that is ommitted here, overall e2e is tested in tests in vllm repo """ - device = "cuda" dataset_size = 100 l1, l2, l3 = 512, 256, 128 original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs @@ -196,13 +212,20 @@ def test_awq_loading_vllm(self): calibration_data = dataset[:n_calibration_examples] # calibrate - base_config = FbgemmConfig( - input_dtype=torch.bfloat16, - weight_dtype=torch.int4, - output_dtype=torch.bfloat16, - block_size=[1, group_size], - preshuffle=False, - ) + if device == "cuda": + base_config = FbgemmConfig( + input_dtype=torch.bfloat16, + weight_dtype=torch.int4, + output_dtype=torch.bfloat16, + block_size=[1, group_size], + preshuffle=False, + ) + elif device == "cpu": + base_config = Int4WeightOnlyConfig( + group_size=group_size, layout=Int4CPULayout(), set_inductor_config=False + ) + else: + assert False, "Unsupported device: {}".format(device) quant_config = AWQConfig(base_config, step=AWQStep.PREPARE) quantize_(m, quant_config) diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index da19bbc259..44665fd029 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -30,6 +30,17 @@ aten = torch.ops.aten +def _same_metadata(self: "Int4CPUAQTTensorImpl", src: "Int4CPUAQTTensorImpl") -> bool: + return ( + isinstance(self, Int4CPUAQTTensorImpl) + and isinstance(src, Int4CPUAQTTensorImpl) + and self.packed_weight.shape == src.packed_weight.shape + and self.scale_and_zero.shape == src.scale_and_zero.shape + and self.transposed == src.transposed + and type(self._layout) == type(src._layout) + ) + + @dataclass(frozen=True) class Int4CPULayout(Layout): """Layout class for int4 CPU layout for affine quantized tensor, used by tinygemm kernels `_weight_int4pack_mm_for_cpu`. 
@@ -208,6 +219,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 f"{cls.__name__} dispatch: attempting to run {func}, with dim={dim}, that is not supported"
             )
 
+        if func is aten.copy_.default:
+            self = args[0]
+            src = args[1]
+            if _same_metadata(self, src):
+                self_tensors = self.__tensor_flatten__()[0]
+                for tensor_name in self_tensors:
+                    getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+                return
+            raise ValueError(
+                f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
+            )
+
         raise NotImplementedError(
             f"{cls.__name__} dispatch: attempting to run {func}, this is not supported"
         )
diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py
index 0bbd1256e8..388c2fdc38 100644
--- a/torchao/prototype/awq/example.py
+++ b/torchao/prototype/awq/example.py
@@ -93,7 +93,9 @@ def wiki2_eval(
 
 
 # adapted from Hicham Badri (@mobicham)
-def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
+def benchmark(
+    model, tokenizer, max_length, tasks=None, evaluation_limit=None, device="cuda"
+):
     import lm_eval
     import numpy as np
 
@@ -126,21 +128,33 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
         for task in [("truthfulqa_mc2", 0)]:
             tag, fewshot = task
             results[tag] = lm_eval.evaluator.simple_evaluate(
-                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+                model_eval,
+                tasks=[tag],
+                num_fewshot=fewshot,
+                batch_size=eval_batch_size,
+                limit=evaluation_limit,
             )["results"]
             print(tag, results[tag])
     if "winogrande" in tasks:
         for task in [("winogrande", 5)]:
             tag, fewshot = task
             results[tag] = lm_eval.evaluator.simple_evaluate(
-                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+                model_eval,
+                tasks=[tag],
+                num_fewshot=fewshot,
+                batch_size=eval_batch_size,
+                limit=evaluation_limit,
             )["results"]
             print(tag, results[tag])
     if "arc_challenge" in tasks:
         for task in [("arc_challenge", 25)]:
             tag, fewshot = task
             results[tag] = lm_eval.evaluator.simple_evaluate(
-                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+                model_eval,
+                tasks=[tag],
+                num_fewshot=fewshot,
+                batch_size=eval_batch_size,
+                limit=evaluation_limit,
             )["results"]
             print(tag, results[tag])
 
@@ -149,14 +163,22 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
         for task in [("hellaswag", 10)]:
             tag, fewshot = task
             results[tag] = lm_eval.evaluator.simple_evaluate(
-                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+                model_eval,
+                tasks=[tag],
+                num_fewshot=fewshot,
+                batch_size=eval_batch_size,
+                limit=evaluation_limit,
             )["results"]
             print(tag, results[tag])
     if "gsm8k" in tasks:
         for task in [("gsm8k", 5)]:
             tag, fewshot = task
             results[tag] = lm_eval.evaluator.simple_evaluate(
-                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+                model_eval,
+                tasks=[tag],
+                num_fewshot=fewshot,
+                batch_size=eval_batch_size,
+                limit=evaluation_limit,
             )["results"]
             print(tag, results[tag])
     # ############################################
@@ -167,7 +189,11 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
         for task in [("mmlu", 5)]:
             tag, fewshot = task
             results_mmlu[tag] = lm_eval.evaluator.simple_evaluate(
-                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+                model_eval,
+                tasks=[tag],
+                num_fewshot=fewshot,
+                batch_size=eval_batch_size,
+                limit=evaluation_limit,
             )["results"]
             print(tag, results_mmlu[tag])
 
@@ -188,7 +214,11 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
         for task in [("leaderboard_bbh", 3)]:
             tag, fewshot = task
             results[tag] = lm_eval.evaluator.simple_evaluate(
-                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+                model_eval,
+                tasks=[tag],
+                num_fewshot=fewshot,
+                batch_size=eval_batch_size,
+                limit=evaluation_limit,
             )["results"]
             print(tag, results[tag])
             results["bbh"] = results[tag]
@@ -202,7 +232,7 @@ def quantize_and_eval(
     tasks: list[str],
     max_seq_length: int,
     calibration_limit: int,
-    validation_size: int,
+    evaluation_limit: int,
     device: str,
     precision: torch.dtype,
     compile: bool,
@@ -223,18 +253,26 @@ def quantize_and_eval(
     if quant.startswith("awq-int4wo"):
         group_size = int(quant.split("-")[2])
         print(f"running {quant} quantization with group size {group_size}")
-        # TODO: this is temporary, we'll be using Int4WeightOnlyConfig soon
-        from torchao.quantization import FbgemmConfig
+        from torchao.dtypes import Int4CPULayout
+        from torchao.quantization import FbgemmConfig, Int4WeightOnlyConfig
 
         # use_hqq = True
         # base_config = Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq)
-        base_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, group_size],
-            preshuffle=False,
-        )
+        if device == "cuda":
+            # TODO: this is temporary, we'll be using Int4WeightOnlyConfig for CUDA soon
+            base_config = FbgemmConfig(
+                input_dtype=torch.bfloat16,
+                weight_dtype=torch.int4,
+                output_dtype=torch.bfloat16,
+                block_size=[1, group_size],
+                preshuffle=False,
+            )
+        elif device == "cpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, layout=Int4CPULayout(), set_inductor_config=False
+            )
+        else:
+            assert False, "Unsupported device: {}".format(device)
         print(f"running {quant} prepare and calibrate")
         t0 = time.time()
         quant_config = AWQConfig(base_config, step="prepare")
@@ -291,7 +329,14 @@ def quantize_and_eval(
     if compile:
         model = torch.compile(model)
 
-    return benchmark(model, tokenizer, max_seq_length, tasks=tasks, device=device)
+    return benchmark(
+        model,
+        tokenizer,
+        max_seq_length,
+        tasks=tasks,
+        evaluation_limit=evaluation_limit,
+        device=device,
+    )
 
 
 if __name__ == "__main__":
@@ -310,8 +355,8 @@ def quantize_and_eval(
         "--tasks",
         nargs="+",
         type=str,
-        help="Task to benchmark model on. Either PPL or QA",
-        default=["PPL"],
+        help="Task to benchmark model on. Here is the list of tasks you can use: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/README.md",
+        default=["hellaswag"],
     )
     parser.add_argument(
         "--calibration_limit",
@@ -320,7 +365,10 @@ def quantize_and_eval(
         help="Number of samples to use for calibration. Default is 10.",
     )
     parser.add_argument(
-        "--validation_size", type=int, default=1, help="Validation size. Default is 1."
+        "--evaluation_limit",
+        type=int,
+        default=None,
+        help="Number of samples to use for evaluation. Default is None (all).",
    )
    parser.add_argument(
        "--device",
@@ -368,7 +416,7 @@ def quantize_and_eval(
        args.tasks,
        args.max_seq_length,
        args.calibration_limit,
-        args.validation_size,
+        args.evaluation_limit,
        args.device,
        args.precision,
        args.compile,
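
For quick reference, the CPU path that the new parameterized tests and the updated example exercise follows the usual AWQ prepare, calibrate, convert sequence. The sketch below is illustrative only and not part of the patch; it assumes a toy bfloat16 module, random calibration inputs, and AWQStep.CONVERT as the final step from the existing AWQ prototype API.

    # Illustrative sketch only; not part of the patch above.
    # Assumes a toy bfloat16 model and random calibration data; AWQStep.CONVERT
    # is assumed to be the final step, as in the existing AWQ prototype API.
    import torch
    import torch.nn as nn

    from torchao.dtypes import Int4CPULayout
    from torchao.prototype.awq import AWQConfig, AWQStep
    from torchao.quantization import Int4WeightOnlyConfig, quantize_

    # toy model; any nn.Module with Linear layers is handled the same way
    model = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 128))
    model = model.eval().to(torch.bfloat16)

    # int4 weight-only base config using the CPU tinygemm layout
    base_config = Int4WeightOnlyConfig(
        group_size=128, layout=Int4CPULayout(), set_inductor_config=False
    )

    # 1) prepare: swap in observed linears that record activation statistics
    quantize_(model, AWQConfig(base_config, step=AWQStep.PREPARE))

    # 2) calibrate with a few example batches
    for _ in range(10):
        model(torch.randn(1, 512, dtype=torch.bfloat16))

    # 3) convert: apply the AWQ equalization scales and quantize weights to int4
    quantize_(model, AWQConfig(base_config, step=AWQStep.CONVERT))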