|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import copy |
| 8 | +import tempfile |
| 9 | +import unittest |
| 10 | + |
| 11 | +import torch |
| 12 | +import torch.nn as nn |
| 13 | +from parameterized import param, parameterized |
| 14 | +from torch import uint1, uint2, uint3, uint4 |
| 15 | +from torchao.prototype.quantization.codebook_groupwise.api import ( |
| 16 | + GroupwiseLutWeightConfig, |
| 17 | +) |
| 18 | +from torchao.quantization.quant_api import quantize_ |
| 19 | +from torchao.quantization.granularity import PerGroup |
| 20 | + |
| 21 | + |
| 22 | +class TestGroupwiseLowbitWeightLut(unittest.TestCase): |
| 23 | + """ |
| 24 | + Test suite for the GroupwiseLutWeight quantization scheme, updated for the |
| 25 | + new simplified API. |
| 26 | + """ |
| 27 | + |
| 28 | + TEST_CASES = [ |
| 29 | + param( |
| 30 | + weight_dtype=weight_dtype, |
| 31 | + lut_group_size=lut_group_size, |
| 32 | + scale_group_size=scale_group_size, |
| 33 | + model_dtype=model_dtype, |
| 34 | + has_bias=has_bias, |
| 35 | + has_scales=has_scales, |
| 36 | + ) |
| 37 | + for weight_dtype in [uint1, uint2, uint3, uint4] |
| 38 | + for lut_group_size, scale_group_size in [(256, 64), (256, 32)] |
| 39 | + for model_dtype in [torch.float32] |
| 40 | + for has_bias in [True, False] |
| 41 | + for has_scales in [True, False] |
| 42 | + ] |
| 43 | + |
| 44 | + # -------------------------------------------------------------------------- |
| 45 | + # Test 1: End-to-End Model Accuracy |
| 46 | + # -------------------------------------------------------------------------- |
| 47 | + @parameterized.expand(TEST_CASES) |
| 48 | + def test_e2e_accuracy_vs_reference( |
| 49 | + self, |
| 50 | + weight_dtype, |
| 51 | + lut_group_size, |
| 52 | + scale_group_size, |
| 53 | + model_dtype, |
| 54 | + has_bias, |
| 55 | + has_scales, |
| 56 | + ): |
| 57 | + """ |
| 58 | + Tests the numerical accuracy of the full quantized model against a reference. |
| 59 | + This now uses the `use_qdq_reference` flag instead of layout objects. |
| 60 | + """ |
| 61 | + m, k, n = 3, 64, 32 |
| 62 | + activations = torch.randn(m, k, dtype=model_dtype) |
| 63 | + model = nn.Sequential(nn.Linear(k, n, bias=has_bias)).to(dtype=model_dtype) |
| 64 | + |
| 65 | + lut_granularity = PerGroup(lut_group_size) |
| 66 | + scale_granularity = PerGroup(scale_group_size) if has_scales else None |
| 67 | + |
| 68 | + # --- Quantize using C++ ops --- |
| 69 | + quantized_model = copy.deepcopy(model) |
| 70 | + perf_config = GroupwiseLutWeightConfig( |
| 71 | + weight_dtype=weight_dtype, |
| 72 | + lut_granularity=lut_granularity, |
| 73 | + scale_granularity=scale_granularity, |
| 74 | + use_qdq_reference=False, # This creates the custom tensor |
| 75 | + ) |
| 76 | + quantize_(quantized_model, perf_config) |
| 77 | + with torch.no_grad(): |
| 78 | + actual_result = quantized_model(activations) |
| 79 | + |
| 80 | + # --- Quantize for Reference (using Python ops) --- |
| 81 | + reference_model = copy.deepcopy(model) |
| 82 | + ref_config = GroupwiseLutWeightConfig( |
| 83 | + weight_dtype=weight_dtype, |
| 84 | + lut_granularity=lut_granularity, |
| 85 | + scale_granularity=scale_granularity, |
| 86 | + use_qdq_reference=True, |
| 87 | + ) |
| 88 | + quantize_(reference_model, ref_config) |
| 89 | + with torch.no_grad(): |
| 90 | + expected_result = reference_model(activations) |
| 91 | + # Compare results |
| 92 | + self.assertTrue( |
| 93 | + torch.allclose(actual_result, expected_result, atol=1e-2, rtol=1e-2) |
| 94 | + ) |
| 95 | + |
| 96 | + def tearDown(self): |
| 97 | + """ |
| 98 | + Clear the TorchDynamo cache after each test case to prevent |
| 99 | + recompilation errors in parameterized tests. |
| 100 | + """ |
| 101 | + super().tearDown() |
| 102 | + torch._dynamo.reset() |
| 103 | + |
| 104 | + # -------------------------------------------------------------------------- |
| 105 | + # Test 2: Deployment Readiness (Updated for new API) |
| 106 | + # -------------------------------------------------------------------------- |
| 107 | + @parameterized.expand(TEST_CASES) |
| 108 | + def test_export_compile_aoti( |
| 109 | + self, |
| 110 | + weight_dtype, |
| 111 | + lut_group_size, |
| 112 | + scale_group_size, |
| 113 | + model_dtype, |
| 114 | + has_bias, |
| 115 | + has_scales, |
| 116 | + ): |
| 117 | + """ |
| 118 | + Tests that the quantized model can be exported and compiled. |
| 119 | + """ |
| 120 | + k, n = 64, 32 |
| 121 | + activations = torch.randn(2, k, dtype=model_dtype) |
| 122 | + model = ( |
| 123 | + nn.Sequential(nn.Linear(k, n, bias=has_bias)).to(dtype=model_dtype).eval() |
| 124 | + ) |
| 125 | + |
| 126 | + # Configure the quantization using the new API |
| 127 | + config = GroupwiseLutWeightConfig( |
| 128 | + weight_dtype=weight_dtype, |
| 129 | + lut_granularity=PerGroup(lut_group_size), |
| 130 | + scale_granularity=PerGroup(scale_group_size) if has_scales else None, |
| 131 | + use_qdq_reference=False, # Ensure we are testing the custom tensor |
| 132 | + ) |
| 133 | + quantize_(model, config) |
| 134 | + |
| 135 | + with torch.no_grad(): |
| 136 | + eager_results = model(activations) |
| 137 | + |
| 138 | + # Export and Compile |
| 139 | + exported_model = torch.export.export(model, (activations,)) |
| 140 | + compiled_model = torch.compile(model, fullgraph=True) |
| 141 | + |
| 142 | + with tempfile.TemporaryDirectory() as tmpdir, torch.no_grad(): |
| 143 | + # Check exported model |
| 144 | + exported_results = exported_model.module()(activations) |
| 145 | + self.assertTrue( |
| 146 | + torch.allclose(eager_results, exported_results, atol=1e-3, rtol=1e-3) |
| 147 | + ) |
| 148 | + |
| 149 | + # Check compiled model |
| 150 | + compiled_results = compiled_model(activations) |
| 151 | + self.assertTrue( |
| 152 | + torch.allclose(eager_results, compiled_results, atol=1e-3, rtol=1e-3) |
| 153 | + ) |
| 154 | + |
| 155 | + # Check AOTI compiled model using the packaging API |
| 156 | + package_path = f"{tmpdir}/model.pt2" |
| 157 | + torch._inductor.aoti_compile_and_package( |
| 158 | + exported_model, package_path=package_path |
| 159 | + ) |
| 160 | + aoti_model = torch._inductor.aoti_load_package(package_path) |
| 161 | + aoti_results = aoti_model(activations) |
| 162 | + self.assertTrue( |
| 163 | + torch.allclose(eager_results, aoti_results, atol=1e-3, rtol=1e-3) |
| 164 | + ) |
| 165 | + |
| 166 | + |
| 167 | +if __name__ == "__main__": |
| 168 | + unittest.main() |
0 commit comments