diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index 2453e7eaaf..0858076551 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -59,12 +59,6 @@ jobs: fail-fast: false matrix: include: - - name: CUDA 2.5.1 - runs-on: linux.g5.12xlarge.nvidia.gpu - torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121' - gpu-arch-type: "cuda" - gpu-arch-version: "12.6" - dev-requirements-overrides: "s/^pytest$/pytest==7.4.0/" - name: CUDA 2.6 runs-on: linux.g5.12xlarge.nvidia.gpu torch-spec: 'torch==2.6.0' @@ -77,13 +71,13 @@ jobs: gpu-arch-type: "cuda" gpu-arch-version: "12.6" dev-requirements-overrides: "" + - name: CUDA 2.8 + runs-on: linux.g5.12xlarge.nvidia.gpu + torch-spec: 'torch==2.8.0' + gpu-arch-type: "cuda" + gpu-arch-version: "12.6" + dev-requirements-overrides: "" - - name: CPU 2.5.1 - runs-on: linux.4xlarge - torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cpu' - gpu-arch-type: "cpu" - gpu-arch-version: "" - dev-requirements-overrides: "s/^pytest$/pytest==7.4.0/" - name: CPU 2.6 runs-on: linux.4xlarge torch-spec: 'torch==2.6.0 --index-url https://download.pytorch.org/whl/cpu' @@ -96,6 +90,12 @@ jobs: gpu-arch-type: "cpu" gpu-arch-version: "" dev-requirements-overrides: "" + - name: CPU 2.8 + runs-on: linux.4xlarge + torch-spec: 'torch==2.8.0 --index-url https://download.pytorch.org/whl/cpu' + gpu-arch-type: "cpu" + gpu-arch-version: "" + dev-requirements-overrides: "" uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: diff --git a/benchmarks/benchmark_aq.py b/benchmarks/benchmark_aq.py index cdc6f6fe5a..7dd732debc 100644 --- a/benchmarks/benchmark_aq.py +++ b/benchmarks/benchmark_aq.py @@ -20,46 +20,26 @@ Int4WeightOnlyQuantizedLinearWeight, Int8WeightOnlyQuantizedLinearWeight, ) -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - unwrap_tensor_subclass, -) def _int8wo_api(mod, **kwargs): - if TORCH_VERSION_AT_LEAST_2_4: - quantize_(mod, int8_weight_only(**kwargs), set_inductor_config=False) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) - else: - change_linear_weights_to_int8_woqtensors(mod, **kwargs) + quantize_(mod, int8_weight_only(**kwargs), set_inductor_config=False) def _int8da_int8w_api(mod, **kwargs): - if TORCH_VERSION_AT_LEAST_2_4: - quantize_( - mod, - int8_dynamic_activation_int8_weight(**kwargs), - set_inductor_config=False, - ) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) - else: - change_linear_weights_to_int8_dqtensors(mod, **kwargs) + quantize_( + mod, + int8_dynamic_activation_int8_weight(**kwargs), + set_inductor_config=False, + ) def _int4wo_api(mod, **kwargs): - if TORCH_VERSION_AT_LEAST_2_4: - kwargs_copy = kwargs.copy() - if "groupsize" in kwargs_copy: - kwargs_copy["group_size"] = kwargs_copy["groupsize"] - del kwargs_copy["groupsize"] - quantize_(mod, int4_weight_only(**kwargs_copy), set_inductor_config=False) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) - else: - change_linear_weights_to_int4_woqtensors(mod, **kwargs) + kwargs_copy = kwargs.copy() + if "groupsize" in kwargs_copy: + kwargs_copy["group_size"] = kwargs_copy["groupsize"] + del kwargs_copy["groupsize"] + quantize_(mod, int4_weight_only(**kwargs_copy), set_inductor_config=False) class ToyLinearModel(torch.nn.Module): @@ -195,13 +175,12 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None): ) -if __name__ == "__main__" and 
TORCH_VERSION_AT_LEAST_2_4 and torch.cuda.is_available(): +if __name__ == "__main__" and torch.cuda.is_available(): all_shapes = [ (20, 2048, 2048), ] print("_int8da_int8w_api") - from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors for M, N, K in all_shapes: _bench_quantized_tensor_subclass_perf( @@ -209,7 +188,6 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None): ) print("_int8wo_api") - from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors for M, N, K in all_shapes: _bench_quantized_tensor_subclass_perf( @@ -218,7 +196,6 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None): print("_int4wo_api") kwargs = {"groupsize": 32} - from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors for M, N, K in all_shapes: _bench_quantized_tensor_subclass_perf( diff --git a/docs/source/pretraining.rst b/docs/source/pretraining.rst index da9659b9a0..2f60719ec5 100644 --- a/docs/source/pretraining.rst +++ b/docs/source/pretraining.rst @@ -161,10 +161,6 @@ Below is a code snippet showing how to use it: from torchao.float8.float8_linear_utils import convert_to_float8_training from torchao.float8.float8_linear import Float8Linear from torchao.float8 import convert_to_float8_training - from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - - if not TORCH_VERSION_AT_LEAST_2_5: - raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") # create model and sample input m = nn.Sequential( diff --git a/docs/source/quick_start.rst b/docs/source/quick_start.rst index c2e7a542df..a95316af99 100644 --- a/docs/source/quick_start.rst +++ b/docs/source/quick_start.rst @@ -95,16 +95,10 @@ it is also much faster! .. code:: py from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, benchmark_model, unwrap_tensor_subclass, ) - # Temporary workaround for tensor subclass + torch.compile - # Only needed for torch version < 2.5 - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(model) - num_runs = 100 torch._dynamo.reset() example_inputs = (torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"),) diff --git a/scripts/quick_start.py b/scripts/quick_start.py index 55c17a8684..6b56412f03 100644 --- a/scripts/quick_start.py +++ b/scripts/quick_start.py @@ -8,11 +8,7 @@ import torch from torchao.quantization import Int4WeightOnlyConfig, quantize_ -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - benchmark_model, - unwrap_tensor_subclass, -) +from torchao.utils import benchmark_model # ================ # | Set up model | @@ -50,11 +46,6 @@ def forward(self, x): # | Benchmark | # ============= -# Temporary workaround for tensor subclass + torch.compile -# Only needed for torch version < 2.5 -if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(model) - num_runs = 100 torch._dynamo.reset() example_inputs = (torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"),) diff --git a/test/core/test_config.py b/test/core/test_config.py index fc752d989e..9574c3ec76 100644 --- a/test/core/test_config.py +++ b/test/core/test_config.py @@ -39,7 +39,6 @@ UIntXWeightOnlyConfig, ) from torchao.sparsity.sparse_api import BlockSparseWeightConfig, SemiSparseWeightConfig -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 # Define test configurations as fixtures configs = [ @@ -85,11 +84,9 @@ ), AWQConfig(Int4WeightOnlyConfig(group_size=128), step=AWQStep.PREPARE_FOR_LOADING), AWQConfig(Int4WeightOnlyConfig(group_size=128), step="prepare_for_loading"), + 
FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, [1, 1, 256]), ] -if TORCH_VERSION_AT_LEAST_2_6: - configs += [FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, [1, 1, 256])] - # Create ids for better test naming def get_config_ids(configs): diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index bd5ed0c3b5..e27796bb74 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -41,7 +41,6 @@ from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain from torchao.testing.utils import skip_if_no_cuda, skip_if_no_gemlite, skip_if_rocm from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, check_cpu_version, check_xpu_version, is_fbcode, @@ -151,11 +150,7 @@ def test_weights_only(self): with tempfile.NamedTemporaryFile() as f: torch.save(ql.state_dict(), f) f.seek(0) - # `weights_only=True` is enabled for torch 2.5+ - if TORCH_VERSION_AT_LEAST_2_5: - _ = torch.load(f, weights_only=True) - else: - _ = torch.load(f, weights_only=False) + _ = torch.load(f, weights_only=True) @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available") @common_utils.parametrize("apply_quant", get_quantization_functions(False, False)) diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 1dfed4dda8..d705b2cfe1 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -3,15 +3,6 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -import pytest - -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, -) - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - import copy import io import random diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py index c2eff77b07..fd5f43a470 100644 --- a/test/dtypes/test_affine_quantized_tensor_parallel.py +++ b/test/dtypes/test_affine_quantized_tensor_parallel.py @@ -24,7 +24,6 @@ ) from torchao.quantization.observer import PerRow, PerTensor from torchao.quantization.quant_api import quantize_ -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 if common_utils.SEED is None: common_utils.SEED = 1234 @@ -127,10 +126,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dn_dist(up_dist(input_dtensor)) - if not TORCH_VERSION_AT_LEAST_2_6: - # Need torch 2.6 to support compiled tensor parallelism - return - up_compiled = torch.compile(up_dist) y_up = up_compiled(input_dtensor) dn_compiled = torch.compile(dn_dist) diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index 237bc2bd92..9a99ba0802 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -33,7 +33,7 @@ quantize_, ) from torchao.testing.utils import skip_if_rocm -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode +from torchao.utils import is_fbcode _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) _Floatx_DTYPES = [(3, 2), (2, 2)] @@ -107,10 +107,6 @@ def test_to_copy_device(self, ebits, mbits): assert floatx_tensor_impl.device.type == "cpu" @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available") - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, - reason="quantization only works with torch.compile for 2.5+", - ) @parametrize("ebits,mbits", _Floatx_DTYPES) @parametrize("bias", [False, True]) 
@parametrize("dtype", [torch.half, torch.bfloat16]) diff --git a/test/dtypes/test_uint4.py b/test/dtypes/test_uint4.py index f7656ef19e..aa9eccc903 100644 --- a/test/dtypes/test_uint4.py +++ b/test/dtypes/test_uint4.py @@ -34,7 +34,6 @@ _replace_with_custom_fn_if_matches_filter, ) from torchao.testing.utils import skip_if_rocm -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 def _apply_weight_only_uint4_quant(model): @@ -243,16 +242,10 @@ def forward(self, x): # program capture m = copy.deepcopy(m_eager) - if TORCH_VERSION_AT_LEAST_2_5: - m = torch.export.texport_for_training( - m, - example_inputs, - ).module() - else: - m = torch._export.capture_pre_autograd_graph( - m, - example_inputs, - ).module() + m = torch.export.texport_for_training( + m, + example_inputs, + ).module() m = prepare_pt2e(m, quantizer) # Calibrate diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py index 35c722365d..dbc69b8ee9 100644 --- a/test/dtypes/test_uintx.py +++ b/test/dtypes/test_uintx.py @@ -14,24 +14,16 @@ dequantize_affine, quantize_affine, ) -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_5, -) -# torch.uintx dtypes are introduced in 2.3 -if TORCH_VERSION_AT_LEAST_2_3: - dtypes = ( - torch.uint1, - torch.uint2, - torch.uint3, - torch.uint4, - torch.uint5, - torch.uint6, - torch.uint7, - ) -else: - dtypes = () +dtypes = ( + torch.uint1, + torch.uint2, + torch.uint3, + torch.uint4, + torch.uint5, + torch.uint6, + torch.uint7, +) group_sizes = [32, 64, 128] devices = ["cpu", "cuda"] @@ -65,9 +57,6 @@ def forward(self, x): @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("group_size", group_sizes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build" -) def test_uintx_quant_on_cpu_then_move_to_cuda(dtype, group_size): scale = 512 fp16_mod_on_cpu = Linear16(scale, "cpu") @@ -86,9 +75,6 @@ def test_uintx_quant_on_cpu_then_move_to_cuda(dtype, group_size): @pytest.mark.parametrize("group_size", group_sizes) @pytest.mark.parametrize("device", devices) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build" -) def test_uintx_weight_only_model_quant(dtype, group_size, device): scale = 512 fp16 = Linear16(scale, device) @@ -103,9 +89,6 @@ def test_uintx_weight_only_model_quant(dtype, group_size, device): @pytest.mark.parametrize("group_size", group_sizes) @pytest.mark.parametrize("device", devices) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build" -) def test_uintx_weight_only_quant(dtype, group_size, device): input_float = torch.randn((1, 256), dtype=torch.float16, device=device) mapping_type = MappingType.SYMMETRIC @@ -140,9 +123,6 @@ def test_uintx_weight_only_quant(dtype, group_size, device): @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") -@pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype requires torch 2.3+" -) def test_uintx_target_dtype(dtype): from torchao.quantization.quant_api import uintx_weight_only @@ -154,10 +134,6 @@ def test_uintx_target_dtype(dtype): @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.skipif(not 
torch.cuda.is_available(), reason="Need CUDA available") -@pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_5, - reason="torch.compile without unwrap_tensor_subclass requires torch 2.5+", -) def test_uintx_target_dtype_compile(dtype): from torchao.quantization.quant_api import uintx_weight_only @@ -170,9 +146,6 @@ def test_uintx_target_dtype_compile(dtype): @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") -@pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype requires torch 2.3+" -) def test_uintx_model_size(dtype): from torchao.quantization.quant_api import uintx_weight_only from torchao.utils import get_model_size_in_bytes diff --git a/test/float8/test_base.py b/test/float8/test_base.py index c2b2c5488a..1f9ae19346 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -13,17 +13,6 @@ import torch import torch.nn as nn -from torchao.testing.utils import skip_if_rocm -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - is_sm_at_least_89, - is_sm_at_least_90, -) - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - - from torchao.float8.config import ( Float8LinearConfig, Float8LinearRecipeName, @@ -53,7 +42,13 @@ tensor_to_scale, ) from torchao.testing.training.test_utils import get_test_float8_linear_config -from torchao.utils import is_MI300, is_ROCM +from torchao.testing.utils import skip_if_rocm +from torchao.utils import ( + is_MI300, + is_ROCM, + is_sm_at_least_89, + is_sm_at_least_90, +) random.seed(0) torch.manual_seed(0) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index a196d87430..04f03bb0ee 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -10,16 +10,6 @@ from io import StringIO import pytest - -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - is_sm_at_least_89, - is_sm_at_least_90, -) - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - import torch import torch.nn as nn from torch._dynamo.test_case import TestCase as DynamoTestCase @@ -42,6 +32,10 @@ ScaledMMConfig, ) from torchao.testing.training.test_utils import get_test_float8_linear_config +from torchao.utils import ( + is_sm_at_least_89, + is_sm_at_least_90, +) def _test_compile_base( diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py index f357196785..7285d4bbc0 100644 --- a/test/float8/test_dtensor.py +++ b/test/float8/test_dtensor.py @@ -12,14 +12,7 @@ import os -import pytest import torch - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_tensor from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.testing._internal.distributed._tensor.common_dtensor import ( diff --git a/test/float8/test_float8_utils.py b/test/float8/test_float8_utils.py index 888c7aadb1..c253af55ea 100644 --- a/test/float8/test_float8_utils.py +++ b/test/float8/test_float8_utils.py @@ -10,10 +10,6 @@ from torchao.float8.float8_utils import _round_scale_down_to_power_of_2 from torchao.testing.utils import skip_if_rocm -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) # source for notable single-precision 
cases: diff --git a/test/float8/test_fsdp.py b/test/float8/test_fsdp.py index 3017c8b539..a25bd53509 100644 --- a/test/float8/test_fsdp.py +++ b/test/float8/test_fsdp.py @@ -16,13 +16,6 @@ import warnings import fire -import pytest - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - import torch import torch.distributed as dist import torch.multiprocessing as mp diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index ef87e5fcda..e7b3b8be91 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -10,13 +10,6 @@ from typing import Any, List, Optional import pytest - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89 - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - - import torch import torch._dynamo.testing import torch.distributed as dist @@ -47,6 +40,7 @@ check_parity_bf16_mp, check_parity_no_mp, ) +from torchao.utils import is_sm_at_least_89 if not is_sm_at_least_89(): pytest.skip("Unsupported CUDA device capability version", allow_module_level=True) diff --git a/test/float8/test_fsdp2_tp.py b/test/float8/test_fsdp2_tp.py index 8a735c5865..ea93d5949d 100644 --- a/test/float8/test_fsdp2_tp.py +++ b/test/float8/test_fsdp2_tp.py @@ -13,14 +13,7 @@ import copy import os -import pytest import torch - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - from torch.distributed._composable.fsdp import fully_shard from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.distributed.tensor.parallel import parallelize_module diff --git a/test/float8/test_fsdp_compile.py b/test/float8/test_fsdp_compile.py index a78a30925c..eb32c40aa3 100644 --- a/test/float8/test_fsdp_compile.py +++ b/test/float8/test_fsdp_compile.py @@ -12,13 +12,6 @@ import warnings import fire -import pytest - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - import torch import torch.distributed as dist import torch.multiprocessing as mp diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index db02444109..8da36cef8e 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -10,16 +10,6 @@ from typing import Optional import pytest - -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - is_sm_at_least_89, - is_sm_at_least_90, -) - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - import torch import torch.nn as nn import torch.nn.functional as F @@ -34,6 +24,10 @@ ) from torchao.float8.float8_utils import IS_ROCM, compute_error from torchao.testing.training.test_utils import get_test_float8_linear_config +from torchao.utils import ( + is_sm_at_least_89, + is_sm_at_least_90, +) torch.manual_seed(0) diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py index a6990549a3..728bf9378b 100644 --- a/test/hqq/test_hqq_affine.py +++ b/test/hqq/test_hqq_affine.py @@ -15,9 +15,6 @@ uintx_weight_only, ) from torchao.testing.utils import skip_if_rocm -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, -) cuda_available = torch.cuda.is_available() @@ 
-78,7 +75,6 @@ def _eval_hqq(dtype): @unittest.skipIf(not cuda_available, "Need CUDA available") -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "Need torch 2.3+") class TestHQQ(unittest.TestCase): def _test_hqq( self, dtype=None, ref_dequantize_error=None, ref_dot_product_error=None diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 5514228f4b..5c29f0b8ad 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -40,9 +40,7 @@ from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, _replace_with_custom_fn_if_matches_filter, - change_linear_weights_to_int4_woqtensors, change_linear_weights_to_int8_dqtensors, - change_linear_weights_to_int8_woqtensors, int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, @@ -79,10 +77,6 @@ ) from torchao.testing.utils import skip_if_rocm from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, TORCH_VERSION_AT_LEAST_2_7, benchmark_model, check_cpu_version, @@ -116,14 +110,7 @@ def _int8wo_api(mod): - if TORCH_VERSION_AT_LEAST_2_4: - quantize_(mod, int8_weight_only(set_inductor_config=False)) - if not TORCH_VERSION_AT_LEAST_2_5 or ( - not TORCH_VERSION_AT_LEAST_2_6 and torch._inductor.config.freezing - ): - unwrap_tensor_subclass(mod) - else: - change_linear_weights_to_int8_woqtensors(mod) + quantize_(mod, int8_weight_only(set_inductor_config=False)) def _int8wo_groupwise_api(mod): @@ -135,18 +122,13 @@ def _int8da_int8w_api( mod, act_mapping_type=MappingType.SYMMETRIC, ): - if TORCH_VERSION_AT_LEAST_2_4: - quantize_( - mod, - int8_dynamic_activation_int8_weight( - act_mapping_type=act_mapping_type, - set_inductor_config=False, - ), - ) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) - else: - change_linear_weights_to_int8_dqtensors(mod) + quantize_( + mod, + int8_dynamic_activation_int8_weight( + act_mapping_type=act_mapping_type, + set_inductor_config=False, + ), + ) def _int4wo_api(mod, use_hqq=False): @@ -163,18 +145,12 @@ def _int4wo_api(mod, use_hqq=False): mod, int4_weight_only(layout=Int4XPULayout()), set_inductor_config=False ) unwrap_tensor_subclass(mod) - elif TORCH_VERSION_AT_LEAST_2_4: - quantize_(mod, int4_weight_only(set_inductor_config=False)) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) else: - change_linear_weights_to_int4_woqtensors(mod) + quantize_(mod, int4_weight_only(set_inductor_config=False)) def _int8da_int4w_api(mod): quantize_(mod, int8_dynamic_activation_int4_weight(set_inductor_config=False)) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) # TODO: use this to reduce the number of tests @@ -393,7 +369,6 @@ def test_swap(self): assert torch.allclose(y_ref, y) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported") def test_weight_t_and_non_t_numerics_match(self): # verify that numerics match whether weight is stored # in transposed format (for cuBLAS) vs non-transposed format @@ -710,8 +685,6 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") def 
test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): if device == "cpu": @@ -730,8 +703,6 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype): if device == "cpu": @@ -789,9 +760,6 @@ def _test_lin_weight_subclass_impl( ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - TORCH_VERSION_AT_LEAST_2_4, "skip because there is some bug in inductor codegen" - ) def test_int8_dynamic_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( Int8DynamicallyQuantizedLinearWeight.from_float, @@ -808,9 +776,6 @@ def test_int8_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) def test_aq_int8_dynamic_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQInt8DynamicallyQuantizedLinearWeight.from_float, @@ -820,9 +785,6 @@ def test_aq_int8_dynamic_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) @unittest.skip( "This segfaults in CI cuda only, disable to unblock PR, we can investigate " "later if needed" @@ -836,9 +798,6 @@ def test_aq_int8_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQInt8WeightOnlyQuantizedLinearWeight2.from_float, @@ -848,9 +807,6 @@ def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQInt8WeightOnlyQuantizedLinearWeight3.from_float, @@ -860,9 +816,6 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") def test_aq_float8_weight_only_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( @@ -892,9 +845,6 @@ def test_autoquantizable_flatten_unflatten(self): for device, dtype in COMMON_DEVICE_DTYPE ] ) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") @unittest.skip("TODO this is not working correctly") def test_aq_float8_dynamic_quant_rowwise_scaling_subclass( @@ -919,9 +869,6 @@ def test_aq_float8_dynamic_quant_rowwise_scaling_subclass( ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") @unittest.skip("TODO this is not working correctly") def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype): @@ 
-933,8 +880,6 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") def test_int4_weight_only_quant_subclass(self, device, dtype): if device == "cpu": @@ -953,8 +898,6 @@ def test_int4_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") @unittest.skip("Skip to fix CI until we deprecate these APIs long term") def test_int4_weight_only_quant_subclass_grouped(self, device, dtype): @@ -1025,14 +968,8 @@ def _test_lin_weight_subclass_api_impl( ) ) ) + @unittest.skip("skip because there is some bug in inductor codegen") def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping): - if ( - not TORCH_VERSION_AT_LEAST_2_5 - and dtype in (torch.float16, torch.bfloat16) - and act_mapping is MappingType.ASYMMETRIC - and device == "cpu" - ): - self.skipTest("Inductor-CPU codegen issue fixed in torch 2.5") api = partial( _int8da_int8w_api, act_mapping_type=act_mapping, @@ -1042,12 +979,6 @@ def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_int8_weight_only_quant_subclass_api(self, device, dtype): - if ( - not TORCH_VERSION_AT_LEAST_2_6 - and dtype in (torch.float16, torch.bfloat16) - and device == "cpu" - ): - self.skipTest("Regression fixed after torch 2.6") undo_recommended_configs() self._test_lin_weight_subclass_api_impl( _int8wo_api, device, 40, test_dtype=dtype @@ -1055,9 +986,6 @@ def test_int8_weight_only_quant_subclass_api(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @torch._inductor.config.patch({"freezing": True}) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "freeze requires torch 2.4 and after." 
- ) @skip_if_rocm("Test flaky on ROCm, under investigation") def test_int8_weight_only_quant_with_freeze(self, device, dtype): torch._dynamo.reset() @@ -1066,8 +994,6 @@ def test_int8_weight_only_quant_with_freeze(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") def test_int4_weight_only_quant_subclass_api(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") @@ -1079,7 +1005,6 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "int4 hqq requires torch nightly.") def test_int4_weight_only_hqq_quant_subclass_api(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") @@ -1093,9 +1018,6 @@ def test_int4_weight_only_hqq_quant_subclass_api(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "gemlite tests needs torch 2.5 or greater" - ) @unittest.skipIf(not has_gemlite, "gemlite not available") def test_gemlite_layout(self, device, dtype): if dtype != torch.float16: @@ -1139,8 +1061,6 @@ def test_gemlite_layout(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): if dtype != torch.bfloat16: @@ -1162,16 +1082,9 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): def api(mod): kwargs_copy = kwargs.copy() - if TORCH_VERSION_AT_LEAST_2_4: - kwargs_copy["group_size"] = groupsize - del kwargs_copy["groupsize"] - quantize_(mod, int4_weight_only(**kwargs_copy)) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) - else: - kwargs_copy["inner_k_tiles"] = inner_k_tiles - del kwargs_copy["layout"] - change_linear_weights_to_int4_woqtensors(mod, **kwargs_copy) + kwargs_copy["group_size"] = groupsize + del kwargs_copy["groupsize"] + quantize_(mod, int4_weight_only(**kwargs_copy)) self._test_lin_weight_subclass_api_impl( api, @@ -1252,11 +1165,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype): self.skipTest("test requires SM capability of at least (8, 0).") from torch._inductor import config - mixed_mm_key, mixed_mm_val = ( - ("mixed_mm_choice", "triton") - if TORCH_VERSION_AT_LEAST_2_5 - else ("force_mixed_mm", True) - ) + mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") with config.patch( { @@ -1289,11 +1198,7 @@ def test_weight_only_quant_use_mixed_mm(self, device, dtype): torch.manual_seed(0) from torch._inductor import config - mixed_mm_key, mixed_mm_val = ( - ("mixed_mm_choice", "triton") - if TORCH_VERSION_AT_LEAST_2_5 - else ("force_mixed_mm", True) - ) + mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") with config.patch( { @@ -1395,18 +1300,10 @@ def test_save_load_dqtensors(self, device, dtype): @torch.no_grad() @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_save_load_int8woqtensors(self, device, dtype): - if ( - not TORCH_VERSION_AT_LEAST_2_6 - and dtype in (torch.float16, torch.bfloat16) - and device == "cpu" - ): - self.skipTest("Regression fixed after torch 2.6") undo_recommended_configs() 
self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch 2.3+.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now") @torch.no_grad() def test_save_load_int4woqtensors(self, device, dtype): if dtype != torch.bfloat16: @@ -1416,9 +1313,6 @@ def test_save_load_int4woqtensors(self, device, dtype): class TorchCompileUnitTest(unittest.TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_3, "fullgraph requires torch nightly." - ) def test_fullgraph(self): lin_fp16 = nn.Linear(32, 16, device="cuda", dtype=torch.float16) lin_smooth = SmoothFakeDynamicallyQuantizedLinear.from_float( @@ -1467,7 +1361,7 @@ def test_shape_logger(self): class SmoothquantIntegrationTest(unittest.TestCase): @torch.no_grad() @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported") + @unittest.skip("Seg fault?") def test_non_dynamically_quantizable_linear(self): if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): self.skipTest("test requires SM capability of at least (8, 0).") @@ -1562,7 +1456,6 @@ class TestAutoQuant(unittest.TestCase): ], ) ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "autoquant requires 2.3+.") def test_autoquant_one_input(self, device, dtype, m, k, n): undo_recommended_configs() print("(m, k, n): ", (m, k, n)) @@ -1604,7 +1497,6 @@ def test_autoquant_one_input(self, device, dtype, m, k, n): ], ) ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_compile(self, device, dtype, m1, m2, k, n): undo_recommended_configs() @@ -1626,9 +1518,6 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n): if m1 == 1 or m2 == 1: self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+") - # Skip certain shapes on older PyTorch versions - if (m1 == 1 or m2 == 1) and not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest(f"Shape {(m1, m2, k, n)} requires torch version > 2.4") # TODO remove this once https://github.com/pytorch/pytorch/issues/155838 is resolved if m1 == 1 or m2 == 1: self.skipTest(f"Shape {(m1, m2, k, n)} is flaky, skipping") @@ -1657,7 +1546,6 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n): self.assertTrue(sqnr >= 30) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_mha(self, device, dtype): if device != "cuda" or not torch.cuda.is_available(): self.skipTest(f"autoquant currently does not support {device}") @@ -1685,7 +1573,6 @@ def forward(self, x): assert len(_AUTOQUANT_CACHE) > 0 @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_manual(self, device, dtype): undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): @@ -1735,7 +1622,6 @@ def test_autoquant_manual(self, device, dtype): ], ) ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n): undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): @@ -1745,9 +1631,27 @@ def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n): self.skipTest("bfloat16 requires sm80+") if m1 == 1 or m2 == 1: 
self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+") - # This test fails on v0.4.0 and torch 2.4, so skipping for now. - if m1 == 1 or m2 == 1 and not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest(f"Shape {(m1, m2, k, n)} requires torch version > 2.4") + + # Note: This test was incorrectly written before with this skip condition: + # + # m1 == 1 or m2 == 1 and not TORCH_VERSION_AT_LEAST_2_5: + # + # This is actually equivalent to: + # + # m1 == 1 or (m2 == 1 and not TORCH_VERSION_AT_LEAST_2_5) + # + # which means we always skip the test as long as `m1 == 1`, regardless of + # the PyTorch version, which was not the intended behavior. Unfortunately, + # unskipping this test now leads to the following error when calling + # `aten._int_mm`: + # + # RuntimeError: self.size(0) needs to be greater than 16, but got 1 + # + # Therefore, we keep this skip condition around for now since it doesn't + # change the test behavior from before. For more details, please see + # https://github.com/pytorch/ao/pull/2720. + if m1 == 1: + self.skipTest(f"Shape {(m1, m2, k, n)} is not supported") class NeedsKwargs(torch.nn.Module): def __init__(self): @@ -1782,7 +1686,6 @@ def forward(self, x, y): ], ) ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "autoquant requires 2.3+.") def test_autoquant_double_access(self, device, dtype, m, k, n): undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): @@ -1835,9 +1738,6 @@ def test_autoquant_min_sqnr(self, device, dtype): self.assertTrue(sqnr >= 50, f"sqnr: {sqnr}") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "autoquant float option requires 2.4+." - ) def test_autoquant_hp_float(self): device = "cuda" dtype = torch.float32 @@ -1868,9 +1768,6 @@ def test_autoquant_hp_float(self): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+." - ) @unittest.skipIf(not has_gemlite, "gemlite not available") def test_autoquant_int4wo(self, device, dtype): if device == "cpu": @@ -1906,9 +1803,6 @@ def test_autoquant_int4wo(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+."
- ) @unittest.skipIf( True, "Skipping for now, do to lowering bug in inductor" ) # TODO unblock when fixed @@ -1948,7 +1842,6 @@ def test_autoquant_float8(self, device, dtype): self.assertGreater(compute_error(ref, out), 20) -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.") @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") @unittest.skip( "AOTI tests are failing right now, repro by commenting out the skip and run:" @@ -2011,7 +1904,6 @@ def forward(self, x): ) -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.") @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") class TestExport(unittest.TestCase): @parameterized.expand( @@ -2067,12 +1959,9 @@ def forward(self, x): # TODO: export changes numerics right now, this is because of functionalization according to Zhengxu # we can re-enable this after non-functional IR is enabled in export # model = torch.export.export(model, example_inputs).module() - if TORCH_VERSION_AT_LEAST_2_5: - model = torch.export.export_for_training( - model, example_inputs, strict=True - ).module() - else: - model = torch._export.capture_pre_autograd_graph(model, example_inputs) + model = torch.export.export_for_training( + model, example_inputs, strict=True + ).module() after_export = model(x) self.assertTrue(torch.equal(after_export, ref)) if api is _int8da_int4w_api: @@ -2111,7 +2000,6 @@ class TestUtils(unittest.TestCase): @parameterized.expand( list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)), ) - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") def test_get_model_size_aqt(self, api, test_device, test_dtype): if test_dtype != torch.bfloat16: self.skipTest(f"{api} in {test_dtype} is not supported yet") diff --git a/test/prototype/moe_training/test_kernels.py b/test/prototype/moe_training/test_kernels.py index b24b61be8c..a10f41e696 100644 --- a/test/prototype/moe_training/test_kernels.py +++ b/test/prototype/moe_training/test_kernels.py @@ -7,15 +7,9 @@ import pytest import torch -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - # We need to skip before doing any imports which would use triton, since -# triton won't be available on CPU builds and torch < 2.5 -if not ( - TORCH_VERSION_AT_LEAST_2_5 - and torch.cuda.is_available() - and torch.cuda.get_device_capability()[0] >= 9 -): +# triton won't be available on CPU builds +if not (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9): pytest.skip("Unsupported PyTorch version", allow_module_level=True) diff --git a/test/prototype/test_autoround.py b/test/prototype/test_autoround.py index 483704a28c..cf7f956a11 100644 --- a/test/prototype/test_autoround.py +++ b/test/prototype/test_autoround.py @@ -25,7 +25,6 @@ prepare_model_for_applying_auto_round_, ) from torchao.prototype.autoround.multi_tensor import MultiTensor -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 _AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) @@ -92,9 +91,6 @@ def _check_params_and_buffers_type(module, check_fun): class TestAutoRound(TestCase): @pytest.mark.skip("these tests are broken on main branch") - @pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_5, reason="Requires torch 2.5 or later" - ) @parametrize("device", _AVAILABLE_DEVICES) @torch.no_grad() def test_auto_round(self, device: str): @@ -136,9 +132,6 @@ def test_auto_round(self, device: str): assert after_quant is not None, "Quantized model forward pass failed" @pytest.mark.skip("these tests are broken on main 
branch") - @pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_5, reason="Requires torch 2.5 or later" - ) @parametrize("device", _AVAILABLE_DEVICES) @torch.no_grad() def test_wrap_model_with_multi_tensor(self, device: str): diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index 5538fa513d..181445470e 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -15,10 +15,7 @@ from torchao.prototype.awq import AWQConfig, AWQStep from torchao.quantization import FbgemmConfig, Int4WeightOnlyConfig, quantize_ -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_6, - _is_fbgemm_genai_gpu_available, -) +from torchao.utils import _is_fbgemm_genai_gpu_available class ToyLinearModel(torch.nn.Module): @@ -50,10 +47,6 @@ def forward(self, x): not _is_fbgemm_genai_gpu_available(), reason="need to install fbgemm_gpu_genai package", ) -@unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_6, - reason="torch.int4 needs torch 2.6+, can remove after we are not using FbgemmConfig", -) class TestAWQ(TestCase): def test_awq_config(self): base_config = Int4WeightOnlyConfig() diff --git a/test/prototype/test_codebook_coreml.py b/test/prototype/test_codebook_coreml.py index 69956c7729..a9519f7321 100644 --- a/test/prototype/test_codebook_coreml.py +++ b/test/prototype/test_codebook_coreml.py @@ -14,7 +14,7 @@ ) from torchao.quantization import quantize_ from torchao.quantization.utils import compute_error -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6, is_package_at_least +from torchao.utils import is_package_at_least @unittest.skipIf( @@ -75,7 +75,6 @@ def test_quantize_api(self): ) assert type(m[0].weight) == CodebookQuantizedTensor - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "requires 2.6+.") def test_export(self): m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(torch.float32) quantize_(m, CodebookWeightOnlyConfig(self.code_dtype, self.block_size)) diff --git a/test/prototype/test_parq.py b/test/prototype/test_parq.py index 6ceeb0d795..85a6e2b0c2 100644 --- a/test/prototype/test_parq.py +++ b/test/prototype/test_parq.py @@ -42,11 +42,7 @@ quantize_, ) from torchao.quantization.quant_primitives import MappingType -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_6, - check_cpu_version, -) +from torchao.utils import check_cpu_version _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -198,7 +194,6 @@ class TestUnifTorchaoQuantizer(common_utils.TestCase): def setUp(self): torch.manual_seed(123) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @common_utils.parametrize("group_size", [32, 256]) def test_int4_weight_only(self, group_size: int = 32): model = M(m=512, n=512).to(_DEVICE, dtype=torch.bfloat16) @@ -215,7 +210,6 @@ def test_int4_weight_only(self, group_size: int = 32): model, m_ref, Int4UnifTorchaoQuantizer(), b, group_size ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") @common_utils.parametrize("b", [2, 3, 4, 8]) @common_utils.parametrize("group_size", [32, 512]) def test_intx_weight_only(self, b: int = 2, group_size: int = 32): @@ -233,7 +227,6 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32): quantizer = UnifTorchaoQuantizer() compare_quantized_models(model, m_ref, quantizer, b, group_size) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(_DEVICE == "cpu", "Need GPU available") def test_int4_weight_only_e2e(self, group_size: int = 32): model = M(m=512, 
n=512).to(torch.bfloat16).to(_DEVICE) @@ -255,7 +248,6 @@ def test_int4_weight_only_e2e(self, group_size: int = 32): ) compare_parq_convert(model, m_ref, optimizer, config) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") @unittest.skipIf(_DEVICE == "cpu", "Need GPU available") @common_utils.parametrize("b", [2, 3, 4, 8]) def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32): @@ -305,7 +297,6 @@ def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32 torch.testing.assert_close(q, q_ref, atol=0, rtol=0) torch.testing.assert_close(Q, Q_ref, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") @common_utils.parametrize("b", [2, 3]) @common_utils.parametrize("group_size", [32, 512]) def test_intx_weight_only(self, b: int = 2, group_size: int = 32): @@ -327,7 +318,6 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32): compare_quantized_models(model, m_ref, quantizer, b, group_size) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") @unittest.skipIf(_DEVICE == "cpu", "Need GPU available") @common_utils.parametrize("b", [2, 3]) def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32): @@ -359,7 +349,6 @@ class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase): def setUp(self): torch.manual_seed(123) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") @common_utils.parametrize("b", [2, 3, 4, 8]) @common_utils.parametrize("model_dtype", [torch.float16, torch.float32]) @common_utils.parametrize("group_size", [32, 128]) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index c9d51389d1..836e2c302e 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -3,15 +3,9 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -import pytest - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6 - -if not TORCH_VERSION_AT_LEAST_2_4: - pytest.skip("Requires torch>=2.4", allow_module_level=True) - import copy +import pytest import torch import torch.distributed as dist import torch.nn.functional as F @@ -312,21 +306,19 @@ def test_fsdp2_correctness(self): (bitnet_training(), mp_policy, 1e-5), ] - # FSDP2 mixed-precision requires https://github.com/pytorch/pytorch/pull/136129 - if TORCH_VERSION_AT_LEAST_2_6: - # It's complicated (though possible) to simulate FSDP BF16 mixed-precision for base_model. - # We would need to cast all params to BF16 in forward and backward pass, while keeping - # the params in FP32 for optim step. - # torch.autocast() will only do this for F.linear() layer (and its backward). - # To keep it simple, we just use a larger tolerance here. - bf16_mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16) - - extra_args = [ - (int8_weight_only_quantized_training(), bf16_mp_policy, 1e-2), - (int8_mixed_precision_training(), bf16_mp_policy, 1e-2), - (bitnet_training(), bf16_mp_policy, 1e-2), - ] - test_args.extend(extra_args) + # It's complicated (though possible) to simulate FSDP BF16 mixed-precision for base_model. + # We would need to cast all params to BF16 in forward and backward pass, while keeping + # the params in FP32 for optim step. + # torch.autocast() will only do this for F.linear() layer (and its backward). 
+ # To keep it simple, we just use a larger tolerance here. + bf16_mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16) + + extra_args = [ + (int8_weight_only_quantized_training(), bf16_mp_policy, 1e-2), + (int8_mixed_precision_training(), bf16_mp_policy, 1e-2), + (bitnet_training(), bf16_mp_policy, 1e-2), + ] + test_args.extend(extra_args) self.run_subtests({"args": test_args}, self._run_subtest) diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index 568b2d964f..85893f2241 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -22,9 +22,6 @@ dequantize_per_channel, dynamically_quantize_per_channel, ) -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, -) class ToyLinearModel(torch.nn.Module): @@ -56,9 +53,8 @@ class TestSmoothQuant(unittest.TestCase): @classmethod def setUpClass(cls): """Set up class-level configuration for tests.""" - if TORCH_VERSION_AT_LEAST_2_5: - # This test case will trigger recompilation many times, so set a large cache_size_limit here - torch._dynamo.config.cache_size_limit = 128 + # This test case will trigger recompilation many times, so set a large cache_size_limit here + torch._dynamo.config.cache_size_limit = 128 @unittest.skip("This test is broken on recent PyTorch, TODO(#1639): fix it") @common_utils.parametrize("bias", [True, False]) @@ -96,8 +92,7 @@ def forward(self, x): quantize_(m, SmoothQuantConfig(), is_observed_linear) # Apply compilation if supported - if TORCH_VERSION_AT_LEAST_2_5: - m = torch.compile(m, fullgraph=True) + m = torch.compile(m, fullgraph=True) # Step 2: Inference quantized model with torch.inference_mode(): @@ -213,8 +208,7 @@ def test_save_load_recipe(self, alpha, quant_mode, device, input_dtype): quantize_(m, SmoothQuantConfig(), is_observed_linear) # Apply compilation if supported - if TORCH_VERSION_AT_LEAST_2_5: - m = torch.compile(m, fullgraph=True) + m = torch.compile(m, fullgraph=True) # Step 2: Setup save/load model with recipe functionality insert_smooth_quant_observer_(m_save_load, alpha, quant_mode) @@ -231,8 +225,7 @@ def test_save_load_recipe(self, alpha, quant_mode, device, input_dtype): is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) quantize_(m_save_load, SmoothQuantConfig(), is_observed_linear) - if TORCH_VERSION_AT_LEAST_2_5: - m_save_load = torch.compile(m_save_load, fullgraph=True) + m_save_load = torch.compile(m_save_load, fullgraph=True) # Step 5: Validate outputs on full dataset with torch.inference_mode(): diff --git a/test/quantization/pt2e/test_arm_inductor_quantizer.py b/test/quantization/pt2e/test_arm_inductor_quantizer.py index 750e88d451..4c3b397382 100644 --- a/test/quantization/pt2e/test_arm_inductor_quantizer.py +++ b/test/quantization/pt2e/test_arm_inductor_quantizer.py @@ -6,12 +6,23 @@ # Owner(s): ["oncall: quantization"] import copy +import functools import itertools +import platform import unittest from enum import Enum import torch import torch.nn as nn +from torch.export import export_for_training +from torch.testing._internal.common_quantization import ( + NodeSpec as ns, +) +from torch.testing._internal.common_quantization import ( + QuantizationTestCase, + skipIfNoInductorSupport, +) +from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo import torchao.quantization.pt2e.quantizer.arm_inductor_quantizer as armiq from torchao.quantization.pt2e import ObserverBase @@ -26,22 +37,7 @@ from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import ( 
QUANT_ANNOTATION_KEY, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 - -if TORCH_VERSION_AT_LEAST_2_5: - from torch.export import export_for_training - -import functools -import platform - -from torch.testing._internal.common_quantization import ( - NodeSpec as ns, -) -from torch.testing._internal.common_quantization import ( - QuantizationTestCase, - skipIfNoInductorSupport, -) -from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo +from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 def skipIfNoArm(fn): diff --git a/test/quantization/pt2e/test_duplicate_dq.py b/test/quantization/pt2e/test_duplicate_dq.py index a1b43b4f3a..8430f605e1 100644 --- a/test/quantization/pt2e/test_duplicate_dq.py +++ b/test/quantization/pt2e/test_duplicate_dq.py @@ -11,6 +11,7 @@ from typing import Any import torch +from torch.export import export_for_training from torch.testing._internal.common_quantization import QuantizationTestCase from torch.testing._internal.common_utils import IS_WINDOWS, run_tests @@ -33,10 +34,7 @@ OP_TO_ANNOTATOR, QuantizationConfig, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 - -if TORCH_VERSION_AT_LEAST_2_5: - from torch.export import export_for_training +from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 class TestHelperModules: diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index 19f208a55c..0c1a1f23c9 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -19,6 +19,7 @@ per_channel_weight_observer_range_neg_127_to_127, weight_observer_range_neg_127_to_127, ) +from torch.export import export_for_training from torch.fx import Node from torch.testing._internal.common_quantization import ( NodeSpec as ns, @@ -66,11 +67,7 @@ QuantizationConfig, ) from torchao.testing.pt2e.utils import PT2EQuantizationTestCase -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 - -if TORCH_VERSION_AT_LEAST_2_5: - from torch.export import export_for_training - +from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 DEVICE_LIST = ["cpu"] + (["cuda"] if TEST_CUDA else []) diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py index d8a2c8df03..e0a51453a9 100644 --- a/test/quantization/pt2e/test_quantize_pt2e_qat.py +++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py @@ -18,6 +18,7 @@ default_symmetric_qnnpack_qat_qconfig, ) from torch.ao.quantization.quantize_fx import prepare_qat_fx +from torch.export import export_for_training from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_quantization import ( NodeSpec as ns, @@ -51,10 +52,7 @@ XNNPACKQuantizer, get_symmetric_quantization_config, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 - -if TORCH_VERSION_AT_LEAST_2_5: - from torch.export import export_for_training +from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 class PT2EQATTestCase(QuantizationTestCase): diff --git a/test/quantization/pt2e/test_representation.py b/test/quantization/pt2e/test_representation.py index 2123995a4b..abe79a08e3 100644 --- a/test/quantization/pt2e/test_representation.py +++ b/test/quantization/pt2e/test_representation.py @@ -11,6 +11,7 @@ import torch from torch._higher_order_ops.out_dtype import out_dtype # noqa: F401 +from torch.export import export_for_training from 
torch.testing._internal.common_quantization import ( NodeSpec as ns, ) @@ -27,10 +28,7 @@ XNNPACKQuantizer, get_symmetric_quantization_config, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 - -if TORCH_VERSION_AT_LEAST_2_5: - from torch.export import export_for_training +from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 @skipIfNoQNNPACK diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index ffaa4573d8..42439552c6 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -26,6 +26,7 @@ IS_FBCODE, IS_LINUX, IS_X86, + TEST_ACL, instantiate_parametrized_tests, parametrize, ) @@ -45,15 +46,7 @@ X86InductorQuantizer, ) from torchao.testing.utils import skip_if_rocm -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_6, - TORCH_VERSION_AT_LEAST_2_8, -) - -if TORCH_VERSION_AT_LEAST_2_6: - from torch.testing._internal.common_utils import TEST_ACL -else: - TEST_ACL = False +from torchao.utils import TORCH_VERSION_AT_LEAST_2_8 # The dict value is match_nodes(computation_op+unary_op) unary_list = { diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 4476b18697..9dc7da3571 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -12,6 +12,7 @@ import torch import torch.nn as nn +from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec as ns, ) @@ -35,10 +36,7 @@ QUANT_ANNOTATION_KEY, X86InductorQuantizer, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 - -if TORCH_VERSION_AT_LEAST_2_5: - from torch.export import export_for_training +from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 class NodePosType(Enum): diff --git a/test/quantization/test_gptq.py b/test/quantization/test_gptq.py index 98760f8cf6..163819bea7 100644 --- a/test/quantization/test_gptq.py +++ b/test/quantization/test_gptq.py @@ -12,9 +12,6 @@ from torchao._models.llama.tokenizer import get_tokenizer from torchao.quantization import Int4WeightOnlyConfig, quantize_ from torchao.quantization.utils import compute_error -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_4, -) torch.manual_seed(0) @@ -101,7 +98,6 @@ def test_gptq_quantizer_int4_weight_only(self): class TestMultiTensorFlow(TestCase): - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_multitensor_add_tensors(self): from torchao.quantization.GPTQ import MultiTensor @@ -114,7 +110,6 @@ def test_multitensor_add_tensors(self): self.assertTrue(torch.equal(mt.values[0], tensor1)) self.assertTrue(torch.equal(mt.values[1], tensor2)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_multitensor_pad_unpad(self): from torchao.quantization.GPTQ import MultiTensor @@ -126,7 +121,6 @@ def test_multitensor_pad_unpad(self): mt.unpad() self.assertEqual(mt.count, 1) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_multitensor_inplace_operation(self): from torchao.quantization.GPTQ import MultiTensor diff --git a/test/quantization/test_marlin_qqq.py 
b/test/quantization/test_marlin_qqq.py index 8fe21c6bd3..56b309b948 100644 --- a/test/quantization/test_marlin_qqq.py +++ b/test/quantization/test_marlin_qqq.py @@ -24,7 +24,6 @@ _choose_qparams_and_quantize_affine_qqq, ) from torchao.testing.utils import skip_if_rocm -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 @skip_if_rocm("ROCm enablement in progress") @@ -67,7 +66,6 @@ def test_marlin_qqq(self): "Results are not close" ) - @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+") @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") @skip_if_rocm("ROCm development in progress") def test_marlin_qqq_compile(self): diff --git a/test/quantization/test_moe_quant.py b/test/quantization/test_moe_quant.py index 425b881dba..fae4d8e41e 100644 --- a/test/quantization/test_moe_quant.py +++ b/test/quantization/test_moe_quant.py @@ -27,11 +27,7 @@ quantize_, ) from torchao.quantization.utils import compute_error -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, - is_sm_at_least_90, -) +from torchao.utils import is_sm_at_least_90 if torch.version.hip is not None: pytest.skip( @@ -116,8 +112,6 @@ def _test_impl_moe_quant( def test_int4wo_fake_dim(self, name, num_tokens, fullgraph): if not torch.cuda.is_available(): self.skipTest("Need CUDA available") - if not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest("Test only enabled for 2.5+") config = MoEQuantConfig( Int4WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE @@ -142,8 +136,6 @@ def test_int4wo_base(self, name, num_tokens, fullgraph): self.skipTest("Need CUDA available") if not is_sm_at_least_90(): self.skipTest("Requires CUDA capability >= 9.0") - if not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest("Test only enabled for 2.5+") config = MoEQuantConfig(Int4WeightOnlyConfig()) tensor_impl_class = TensorCoreTiledAQTTensorImpl @@ -164,8 +156,6 @@ def test_int4wo_base(self, name, num_tokens, fullgraph): def test_int8wo_fake_dim(self, name, num_tokens, fullgraph): if not torch.cuda.is_available(): self.skipTest("Need CUDA available") - if not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest("Test only enabled for 2.5+") config = MoEQuantConfig( Int8WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE @@ -188,8 +178,6 @@ def test_int8wo_fake_dim(self, name, num_tokens, fullgraph): def test_int8wo_base(self, name, num_tokens, fullgraph): if not torch.cuda.is_available(): self.skipTest("Need CUDA available") - if not TORCH_VERSION_AT_LEAST_2_6: - self.skipTest("Test only enabled for 2.6+") config = MoEQuantConfig(Int8WeightOnlyConfig()) tensor_impl_class = PlainAQTTensorImpl @@ -208,9 +196,6 @@ def test_int8wo_base(self, name, num_tokens, fullgraph): ] ) def test_int8wo_base_cpu(self, name, num_tokens, fullgraph): - if not TORCH_VERSION_AT_LEAST_2_6: - self.skipTest("Test only enabled for 2.6+") - config = MoEQuantConfig(Int8WeightOnlyConfig()) tensor_impl_class = PlainAQTTensorImpl @@ -230,8 +215,6 @@ def test_int8wo_base_cpu(self, name, num_tokens, fullgraph): def test_int8dq_fake_dim(self, name, num_tokens, fullgraph): if not torch.cuda.is_available(): self.skipTest("Need CUDA available") - if not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest("Test only enabled for 2.5+") config = MoEQuantConfig( Int8DynamicActivationInt8WeightConfig(), @@ -255,8 +238,6 @@ def test_int8dq_fake_dim(self, name, num_tokens, fullgraph): def test_int8dq_base(self, name, num_tokens, fullgraph): if not torch.cuda.is_available(): 
self.skipTest("Need CUDA available") - if not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest("Test only enabled for 2.5+") config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig()) base_class = LinearActivationQuantizedTensor diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index b1d4a097d0..d1ebc5cc88 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -83,11 +83,6 @@ get_groupwise_affine_qparams, groupwise_affine_quantize_tensor, ) -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_6, -) # TODO: put this in a common test utils file _CUDA_IS_AVAILABLE = torch.cuda.is_available() @@ -201,9 +196,6 @@ def forward(self, x): class TestQAT(unittest.TestCase): SEED = 123 - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantize_per_channel_group(self): n_bit = 4 (qmin, qmax) = _get_qmin_qmax(n_bit) @@ -248,9 +240,6 @@ def test_fake_quantize_per_channel_group(self): ) torch.testing.assert_close(out, out_ptq, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantize_per_token(self): (qmin, qmax) = _get_qmin_qmax(8) @@ -348,9 +337,6 @@ def _set_ptq_weight( else: raise ValueError("Unknown ptq_linear type: %s" % type(ptq_linear)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_8da4w_linear(self): from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear @@ -381,9 +367,6 @@ def test_qat_8da4w_linear(self): ptq_out = ptq_linear(x2) torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_8da4w_quantizer(self): from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer @@ -419,9 +402,6 @@ def test_qat_8da4w_quantizer(self): ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0 ) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_8da4w_quantizer_meta_weights(self): from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer @@ -433,9 +413,6 @@ def test_qat_8da4w_quantizer_meta_weights(self): qat_model = qat_quantizer.prepare(m) self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values())) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_8da4w_quantizer_disable_fake_quant(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward. @@ -494,9 +471,6 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): qat_out2 = qat_model2(*x2) torch.testing.assert_close(qat_out, qat_out2, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward. 
@@ -593,9 +567,6 @@ def _test_qat_quantized_gradients(self, quantizer): optimizer.step() current_step += 1 - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_8da4w_quantizer_gradients(self): from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer @@ -662,9 +633,6 @@ def test_qat_4w_primitives(self): self._assert_close_4w(qat_out, ptq_out) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_linear(self): from torchao.quantization.GPTQ import WeightOnlyInt4Linear @@ -700,18 +668,12 @@ def test_qat_4w_linear(self): ptq_out = ptq_linear(x2) self._assert_close_4w(qat_out, ptq_out) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_4w_quantizer_gradients(self): from torchao.quantization.qat import Int4WeightOnlyQATQuantizer quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8) self._test_qat_quantized_gradients(quantizer) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_quantizer(self): from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer @@ -797,9 +759,6 @@ def test_composable_qat_quantizer(self): values_list, ["quantizer1", "quantizer2", "quantizer1", "quantizer2"] ) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_4w_embedding(self): from torchao._executorch_ops import ( _quantized_decomposed_quantize_per_channel_group_wrapper, @@ -977,15 +936,14 @@ def test_fake_quantize_config_dtype(self): with self.assertRaisesRegex(ValueError, msg): IntxFakeQuantizeConfig(torch.float32, "per_token") # OK - if TORCH_VERSION_AT_LEAST_2_3: - IntxFakeQuantizeConfig(torch.uint1, "per_token") - IntxFakeQuantizeConfig(torch.uint2, "per_token") - IntxFakeQuantizeConfig(torch.uint3, "per_token") - IntxFakeQuantizeConfig(torch.uint4, "per_token") - IntxFakeQuantizeConfig(torch.uint5, "per_token") - IntxFakeQuantizeConfig(torch.uint6, "per_token") - IntxFakeQuantizeConfig(torch.uint7, "per_token") - IntxFakeQuantizeConfig(torch.uint8, "per_token") + IntxFakeQuantizeConfig(torch.uint1, "per_token") + IntxFakeQuantizeConfig(torch.uint2, "per_token") + IntxFakeQuantizeConfig(torch.uint3, "per_token") + IntxFakeQuantizeConfig(torch.uint4, "per_token") + IntxFakeQuantizeConfig(torch.uint5, "per_token") + IntxFakeQuantizeConfig(torch.uint6, "per_token") + IntxFakeQuantizeConfig(torch.uint7, "per_token") + IntxFakeQuantizeConfig(torch.uint8, "per_token") IntxFakeQuantizeConfig(TorchAODType.INT1, "per_token") IntxFakeQuantizeConfig(TorchAODType.INT2, "per_token") IntxFakeQuantizeConfig(TorchAODType.INT3, "per_token") @@ -1010,9 +968,6 @@ def test_fake_quantize_config_dynamic_and_range_learning(self): torch.int8, "per_channel", is_dynamic=True, range_learning=True ) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantized_linear_8da4w(self): """ Test that we can express int8 dynamic activations + int4 weights with `FakeQuantizedLinear`. 
@@ -1066,9 +1021,6 @@ def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: baseline_out = linear_forward_8da4w(x2, fq_linear.weight) torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantized_linear_4w(self): """ Test that we can express int4 weight only (tinygemm) with `FakeQuantizedLinear`. @@ -1115,9 +1067,6 @@ def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: baseline_out = linear_forward_4w(x2, fq_linear.weight) torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_replace_linear_8da4w(self): module = torch.nn.ModuleList( [ @@ -1137,9 +1086,6 @@ def test_replace_linear_8da4w(self): assert isinstance(module[0], Int8DynActInt4WeightQATLinear) assert isinstance(module[1], Int8DynActInt4WeightQATLinear) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_replace_linear_int4(self): module = torch.nn.ModuleList( [torch.nn.Linear(in_features=256, out_features=50, bias=True)] @@ -1172,9 +1118,6 @@ def test_replace_linear_int4(self): ) assert isinstance(module[0], Int4WeightOnlyQATLinear) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantized_embedding_4w(self): """ Test that we can express int4 per group symmetric weight only fake quantization @@ -1212,9 +1155,6 @@ def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: baseline_out = embedding_forward_4w(x2, fq_embedding.weight) torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_prototype_bc(self): """ Just to make sure we can import all the old prototype paths. @@ -1268,9 +1208,6 @@ def test_qat_prototype_bc(self): Int8DynActInt4WeightQATQuantizer, ) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_config_init(self): """ Test that the correct errors are thrown if `QATConfig` is not instantiated properly. 
@@ -1324,9 +1261,6 @@ def test_qat_config_init(self): ): QATConfig(fq_config, step="prepare") - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_quantize_api_prepare(self): """ Test that the following: @@ -1375,9 +1309,6 @@ def test_quantize_api_prepare(self): baseline_out = baseline_model(*x2) torch.testing.assert_close(out, baseline_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_quantize_api_errors(self): """ Test that we throw exceptions with helpful error messages if `quantize_` @@ -1397,9 +1328,6 @@ def test_quantize_api_errors(self): with self.assertRaisesRegex(ValueError, "does not have QAT support"): quantize_(m, qat_config, lambda m, _: isinstance(m, torch.nn.ReLU)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_quantize_api_e2e(self): """ Test that the following: @@ -1448,9 +1376,6 @@ def test_quantize_api_e2e(self): baseline_out = baseline_model(*x2) torch.testing.assert_close(out, baseline_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_6, "skipping when torch version is 2.6 or lower" - ) def test_fake_quantize_config_torch_intx(self): """ Test that `IntxFakeQuantizeConfig` works with torch.intx. @@ -1468,9 +1393,6 @@ def test_fake_quantize_config_torch_intx(self): out2 = linear2(*x2) torch.testing.assert_close(out1, out2, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_6, "skipping when torch version is 2.6 or lower" - ) def test_fake_quantizer_repr(self): """ Test that `repr(IntxFakeQuantizer(config))` exposes useful config details. @@ -1483,9 +1405,6 @@ def test_fake_quantizer_repr(self): self.assertTrue("PerGroup" in fake_quantizer_repr) self.assertTrue("MappingType.SYMMETRIC" in fake_quantizer_repr) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_linear_bias(self): """ Test that QAT supports linear bias. @@ -1501,9 +1420,6 @@ def test_qat_linear_bias(self): m(*example_inputs) @parameterized.expand([(torch.float32,), (torch.bfloat16,), (torch.float16,)]) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype): """ Test that the following produce the exact same numerics: @@ -1521,9 +1437,6 @@ def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype): torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0) @parameterized.expand([(torch.float32,), (torch.bfloat16,), (torch.float16,)]) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype): """ Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces @@ -1562,9 +1475,6 @@ def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype): ) self.assertEqual(len(non_inf_sqnr), 0, fail_message) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantize_config_eps(self): """ Test that users can set arbitrary eps value in `IntxFakeQuantizeConfig`. 
@@ -1591,9 +1501,6 @@ def test_fake_quantize_config_eps(self): actual_out = fake_quantizer(x) torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_8da4w_eps(self): """ Test that the 8da4w QAT flow uses the expected eps. @@ -1641,9 +1548,6 @@ def test_qat_8da4w_eps(self): torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0) @parameterized.expand([(True,), (False,)]) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantizer_range_learning(self, is_symmetric): """ Test that range learning requires `IntxFakeQuantizer`s to be initialized correctly. @@ -1685,9 +1589,6 @@ def test_fake_quantizer_range_learning(self, is_symmetric): fake_quantizer(*example_inputs) @parameterized.expand([(True,), (False,)]) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_range_learning(self, is_symmetric): """ Test end-to-end QAT flow with range learning. @@ -1780,9 +1681,6 @@ def test_float8_rowwise_fake_quantize(self): ).to_original_precision() torch.testing.assert_close(out, out_expected, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_6, "skipping when torch version is 2.6 or lower" - ) def test_qat_fp8a4w_quantizer(self): """ Test basic model training with `Float8ActInt4WeightQATQuantizer`. @@ -1817,9 +1715,6 @@ def test_qat_fp8a4w_quantizer(self): self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0) self.assertFalse(torch.equal(new_weight, prev_weight)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_legacy_quantize_api_e2e(self): """ Test that the following two APIs are numerically equivalent: @@ -1871,9 +1766,6 @@ def test_legacy_quantize_api_e2e(self): baseline_out = baseline_model(*x2) torch.testing.assert_close(out, baseline_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_api_deprecation(self): """ Test that the appropriate deprecation warning is logged exactly once per class. 
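Every skip decorator deleted from this test file follows the same version-gating pattern: a module-level boolean flag guards tests that rely on newer PyTorch features, and once the minimum supported PyTorch exceeds the guarded threshold the flag is always true, so the decorator never skips anything and becomes dead gating. A minimal, self-contained sketch of that pattern is shown below; the simplified version parse and the test name are illustrative stand-ins and are not the torchao.utils implementation.

import unittest

import torch

# Simplified stand-in for a TORCH_VERSION_AT_LEAST_2_4-style flag; real helpers
# parse nightly/dev version strings more carefully than this.
_major, _minor = (int(part) for part in torch.__version__.split(".")[:2])
TORCH_VERSION_AT_LEAST_2_4 = (_major, _minor) >= (2, 4)


class ExampleVersionGatedTest(unittest.TestCase):
    # When the installed PyTorch is >= 2.4 this condition is always false and
    # the decorator never skips, which is why the hunks above remove it.
    @unittest.skipIf(
        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
    )
    def test_feature_that_needs_torch_2_4(self):
        x = torch.randn(4, 4)
        self.assertEqual(x.shape, torch.Size([4, 4]))


if __name__ == "__main__":
    unittest.main()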
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index b9d99e7ac7..3b26cd25d6 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -66,10 +66,6 @@ from torchao.quantization.utils import compute_error from torchao.testing.utils import skip_if_rocm from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_89, is_sm_at_least_90, @@ -279,7 +275,6 @@ def api(model): torch.testing.assert_close(ref, res.cpu()) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "only works for torch 2.4+") def test_int8_wo_quant_save_load(self): m = ToyLinearModel().eval().cpu() @@ -308,9 +303,6 @@ def api(model): atol, rtol = (1e-2, 1e-2) if torch.version.hip else (None, None) torch.testing.assert_close(ref, res.cpu(), atol=atol, rtol=rtol) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch verion is 2.3 or lower" - ) def test_8da4w_quantizer(self): from torchao.quantization.linear_quant_modules import Int8DynActInt4WeightLinear from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer @@ -323,9 +315,6 @@ def test_8da4w_quantizer(self): assert isinstance(m.linear2, Int8DynActInt4WeightLinear) m(*example_inputs) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch verion is 2.3 or lower" - ) def test_8da4w_quantizer_linear_bias(self): from torchao.quantization.linear_quant_modules import Int8DynActInt4WeightLinear from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer @@ -444,7 +433,6 @@ def test_eval_wrapper_llama3(self): ) # TODO: move to a separate test file - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @common_utils.parametrize( "mapping_type", [MappingType.SYMMETRIC, MappingType.SYMMETRIC_NO_CLIPPING_ERR] ) @@ -484,8 +472,6 @@ def test_quantized_tensor_subclass_8da4w(self, mapping_type): ref = m_copy(*example_inputs) self.assertTrue(torch.equal(res, ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "Test currently doesn't work for 2.5+") @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available") def test_quantized_tensor_subclass_int4(self): for device in self.GPU_DEVICES: @@ -512,7 +498,6 @@ def test_quantized_tensor_subclass_int4(self): self.assertTrue(torch.equal(res, ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_tensor_subclass_int8_wo(self): m = ToyLinearModel().eval().to(torch.bfloat16) @@ -532,50 +517,6 @@ def test_quantized_tensor_subclass_int8_wo(self): self.assertTrue(torch.equal(res, ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.5 and below") - def test_quantized_tensor_subclass_int8_dyn_quant(self): - # use multiples of 1024 so that we don't need padding - m = ToyLinearModel(1024, 1024, 2048).eval().to(torch.bfloat16).to("cuda") - m_copy = copy.deepcopy(m) - # setting batch_size to 20 to be compatible with the kernel - example_inputs = m.example_inputs( - batch_size=20, dtype=torch.bfloat16, device="cuda" - ) 
- quantize_(m, int8_dynamic_activation_int8_weight()) - - assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor) - assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor) - assert isinstance( - m.linear1.weight.original_weight_tensor, AffineQuantizedTensor - ) - assert isinstance( - m.linear2.weight.original_weight_tensor, AffineQuantizedTensor - ) - - # reference - _ref_change_linear_weights_to_int8_dqtensors(m_copy) - - res = m(*example_inputs) - ref = m_copy(*example_inputs) - - self.assertTrue(torch.equal(res, ref)) - - # workaround for export path - from torchao.utils import unwrap_tensor_subclass - - m_unwrapped = unwrap_tensor_subclass(m) - - m = torch.export.export(m_unwrapped, example_inputs, strict=True).module() - exported_model_res = m(*example_inputs) - - self.assertTrue(torch.equal(exported_model_res, ref)) - - # make sure it compiles - torch._export.aot_compile(m_unwrapped, example_inputs) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_tensor_subclass_save_load(self): m = ToyLinearModel().eval().to(torch.bfloat16) @@ -594,7 +535,6 @@ def test_quantized_tensor_subclass_save_load(self): res = m_copy(*example_inputs) self.assertEqual(res, ref) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_int8wo_quantized_model_to_device(self): m = ToyLinearModel().eval().to(torch.bfloat16) @@ -608,25 +548,6 @@ def test_int8wo_quantized_model_to_device(self): cuda_res = m(*example_inputs_cuda) self.assertEqual(cuda_res.cpu(), ref) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "Test currently doesn't work for 2.5+") - def test_int4wo_quantized_model_to_device(self): - # TODO: change initial model to "cpu" - devices = ["cuda", "cuda:0"] - for device in devices: - m = ToyLinearModel().eval().to(torch.bfloat16).to(device) - example_inputs = m.example_inputs(dtype=torch.bfloat16, device=device) - - quantize_(m, int4_weight_only()) - ref = m(*example_inputs) - - example_inputs_cuda = (example_inputs[0].to(device),) - m.to(device=device) - cuda_res = m(*example_inputs_cuda) - self.assertEqual(cuda_res.cpu(), ref) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_tensor_subclass_save_load_map_location(self): m = ToyLinearModel().eval().to(dtype=torch.bfloat16, device="cuda") @@ -648,7 +569,6 @@ def test_quantized_tensor_subclass_save_load_map_location(self): res = m_copy(*example_inputs) self.assertEqual(res, ref) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_model_streaming(self): def reset_memory(): @@ -671,7 +591,6 @@ def reset_memory(): assert param.is_cuda self.assertLess(memory_streaming, memory_baseline) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) @common_utils.parametrize("x_dim", [2, 3]) @common_utils.parametrize("use_hqq", [True, False]) @@ -698,7 +617,6 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq): assert 
"aten.mm.default" not in code[0] # TODO(#1690): move to new config names - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize( "config", @@ -795,7 +713,6 @@ def test_module_fqn_to_config_module_name(self): assert isinstance(model.linear2.weight, AffineQuantizedTensor) assert isinstance(model.linear2.weight._layout, PlainLayout) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Need torch 2.6+") def test_module_fqn_to_config_embedding_linear(self): weight_dtype = torch.int8 granularity = PerGroup(8) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 12027243a8..f3d265e14a 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -29,10 +29,6 @@ groupwise_affine_quantize_tensor_from_qparams, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, check_cpu_version, check_xpu_version, is_fbcode, @@ -132,11 +128,10 @@ def _groupwise_affine_quantize_tensor_from_qparams( .reshape_as(w) ) - if TORCH_VERSION_AT_LEAST_2_5: - if (not (check_cpu_version(w.device))) and (not (check_xpu_version(w.device))): - w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8) - if check_xpu_version(w.device): - w_int4x8 = (w_int4x8[::, 1::2] << 4 | w_int4x8[::, ::2]).to(torch.uint8) + if (not (check_cpu_version(w.device))) and (not (check_xpu_version(w.device))): + w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8) + if check_xpu_version(w.device): + w_int4x8 = (w_int4x8[::, 1::2] << 4 | w_int4x8[::, ::2]).to(torch.uint8) return w_int4x8 @@ -175,9 +170,6 @@ def _groupwise_affine_dequantize_tensor_from_qparams( class TestQuantPrimitives(unittest.TestCase): SEED = 123 - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower" - ) def test_get_group_qparams_symmetric(self): """ Test that `get_group_qparams_symmetric` produces the exact same scales as @@ -264,34 +256,21 @@ def test_choose_qparams_group_sym_no_clipping_err(self): self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zp_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower" - ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_choose_qparams_token_asym(self): input = torch.randn(10, 10) mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 block_size = (1, 10) - if TORCH_VERSION_AT_LEAST_2_6: - scale, zero_point = choose_qparams_affine( - input, - mapping_type, - block_size, - dtype, - eps=torch.finfo(torch.float32).eps, - scale_dtype=torch.float64, - zero_point_dtype=torch.int64, - ) - else: - scale, zero_point = choose_qparams_affine( - input, - mapping_type, - block_size, - dtype, - eps=torch.finfo(torch.float32).eps, - ) - + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + eps=torch.finfo(torch.float32).eps, + scale_dtype=torch.float64, + zero_point_dtype=torch.int64, + ) scale_ref, zp_ref = ( torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric( input, dtype @@ -347,9 +326,6 @@ def test_choose_qparams_tensor_sym(self): self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zp_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 
2.4 or lower" - ) def test_quantize_activation_per_token_abs_max(self): input = torch.randn(10, 10) quantized_ref, scale_ref = _quantize_activation_per_token_absmax(input) @@ -380,17 +356,11 @@ def test_quantize_activation_per_token_abs_max(self): self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(scale, scale_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_quantize_activation_per_token_abs_max_zero_input(self): input = torch.zeros(10, 10) # make sure it still works quantized_ref, scale_ref = _quantize_activation_per_token_absmax(input) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_quantize_activation_per_token_abs_max_dtype(self): input = torch.zeros(10, 10, dtype=torch.bfloat16) quantized_ref, scale_ref = _quantize_activation_per_token_absmax(input) @@ -404,9 +374,6 @@ def test_quantize_activation_per_token_abs_max_dtype(self): quantized_ref, scale_ref = _quantize_activation_per_token_absmax(input) self.assertTrue(scale_ref.dtype, torch.float32) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_group_sym(self): input = torch.randn(10, 10) @@ -449,9 +416,6 @@ def test_quantize_dequantize_group_sym(self): self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_channel_asym(self): input = torch.randn(10, 10) @@ -493,9 +457,6 @@ def test_quantize_dequantize_channel_asym(self): self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_tensor_asym(self): input = torch.randn(10, 10) @@ -535,9 +496,6 @@ def test_quantize_dequantize_tensor_asym(self): self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_channel_asym_4d(self): input = torch.randn(3, 3, 10, 10) @@ -578,9 +536,6 @@ def test_quantize_dequantize_channel_asym_4d(self): self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower" - ) def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self): input = torch.randn(3, 3, 10, 10) mapping_type = MappingType.ASYMMETRIC @@ -726,32 +681,22 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self): for zero_point_domain in [ZeroPointDomain.FLOAT, ZeroPointDomain.INT]: if zero_point_domain == ZeroPointDomain.INT: zeros = torch.randint(0, 15, (10, 2), dtype=torch.int32) - if TORCH_VERSION_AT_LEAST_2_5: - input_tmp = input - if (not (check_cpu_version(input.device))) and ( - not (check_xpu_version(input.device)) - ): - input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8) - if 
check_xpu_version(input.device): - input_tmp = (input[::, 1::2] << 4 | input[::, ::2]).to(torch.uint8) - w_bf16 = groupwise_affine_dequantize_tensor_from_qparams( - input_tmp, scales, zeros, n_bit, groupsize, zero_point_domain - ) - else: - if zero_point_domain == ZeroPointDomain.INT: - continue - w_bf16 = groupwise_affine_dequantize_tensor_from_qparams( - input, scales, zeros, n_bit, groupsize - ) + input_tmp = input + if (not (check_cpu_version(input.device))) and ( + not (check_xpu_version(input.device)) + ): + input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8) + if check_xpu_version(input.device): + input_tmp = (input[::, 1::2] << 4 | input[::, ::2]).to(torch.uint8) + w_bf16 = groupwise_affine_dequantize_tensor_from_qparams( + input_tmp, scales, zeros, n_bit, groupsize, zero_point_domain + ) w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams( input, scales, zeros, n_bit, groupsize, zero_point_domain ) self.assertTrue(torch.equal(w_bf16, w_bf16_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantize_affine(self): input = torch.randn(10, 10) @@ -785,9 +730,6 @@ def test_fake_quantize_affine(self): ) torch.testing.assert_close(dequantized, fake_quantized) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantize_affine_cachemask(self): input = torch.randn(10, 10) diff --git a/test/sparsity/test_fast_sparse_training.py b/test/sparsity/test_fast_sparse_training.py index 804a585dd8..424306f897 100644 --- a/test/sparsity/test_fast_sparse_training.py +++ b/test/sparsity/test_fast_sparse_training.py @@ -15,7 +15,7 @@ swap_linear_with_semi_sparse_linear, swap_semi_sparse_linear_with_linear, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_fbcode +from torchao.utils import is_fbcode class ToyModel(nn.Module): @@ -32,7 +32,6 @@ def forward(self, x): class TestRuntimeSemiStructuredSparsity(TestCase): - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(is_fbcode(), "broken in fbcode") @unittest.skip("Temporarily skipping to unpin nightlies") @@ -81,7 +80,6 @@ def test_runtime_weight_sparsification(self): for name, mod in model_c.named_modules(): assert not isinstance(mod, SemiSparseLinear) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(is_fbcode(), "broken in fbcode") @unittest.skip("Temporarily skipping to unpin nightlies") diff --git a/test/sparsity/test_marlin.py b/test/sparsity/test_marlin.py index 783de6c6ae..3cf310d71f 100644 --- a/test/sparsity/test_marlin.py +++ b/test/sparsity/test_marlin.py @@ -20,7 +20,6 @@ from torchao.sparsity.marlin import inject_24, pack_to_marlin_24, unpack_from_marlin_24 from torchao.sparsity.sparse_api import apply_fake_sparsity from torchao.testing.utils import skip_if_rocm -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 class SparseMarlin24(TestCase): @@ -58,7 +57,6 @@ def test_quant_sparse_marlin_layout_eager(self): "Results are not close" ) - @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+") @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") @skip_if_rocm("ROCm enablement in progress") def test_quant_sparse_marlin_layout_compile(self): diff --git a/test/sparsity/test_sparse_api.py 
b/test/sparsity/test_sparse_api.py index 5e3086c411..30a063bf79 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -18,12 +18,6 @@ quantize_, ) from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_ -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, -) logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO @@ -31,7 +25,6 @@ class TestSemiStructuredSparse(common_utils.TestCase): - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skip("Temporarily skipping to unpin nightlies") def test_sparse(self): @@ -59,7 +52,6 @@ def test_sparse(self): class TestQuantSemiSparse(common_utils.TestCase): - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "pytorch 2.5+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("compile", [False]) @unittest.skip("Temporarily skip to unbreak CI") @@ -97,7 +89,6 @@ def test_quant_semi_sparse(self, compile): torch.testing.assert_close(dense_result, sparse_result, rtol=1e-2, atol=1e-2) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "pytorch 2.5+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("compile", [True, False]) def test_sparse_marlin(self, compile): @@ -132,10 +123,6 @@ def test_sparse_marlin(self, compile): class TestBlockSparseWeight(common_utils.TestCase): - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, - "pytorch 2.4+ feature due to need for custom op support", - ) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("compile", [True, False]) @common_utils.parametrize("input_shape", [1, 1024]) @@ -170,7 +157,6 @@ def test_sparse(self, compile, input_shape): class TestQuantBlockSparseWeight(common_utils.TestCase): - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "pytorch 2.6+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("compile", [True, False]) def test_sparse(self, compile): diff --git a/test/test_low_bit_optim.py b/test/test_low_bit_optim.py index 00c30b919a..64df37ac88 100644 --- a/test/test_low_bit_optim.py +++ b/test/test_low_bit_optim.py @@ -41,7 +41,6 @@ from torchao.optim.subclass_fp8 import OptimStateFp8 from torchao.testing.utils import skip_if_rocm from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7, get_available_devices, ) @@ -222,8 +221,6 @@ def test_param_groups(self, optim_name, device): @parametrize("device", _DEVICES) def test_subclass_slice(self, subclass, shape, device): if subclass == OptimStateFp8: - if device == "cpu" and len(shape) > 1 and not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("fill_cpu not implemented for Float8_e4m3fn for torch<2.5") if device == "cuda" and torch.cuda.get_device_capability() < (8, 9): pytest.skip("FP8 CUDA requires compute capability >= 8.9") @@ -469,9 +466,6 @@ class TestFSDP2(FSDPTest): def world_size(self) -> int: return _FSDP_WORLD_SIZE - @pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_5, reason="PyTorch>=2.5 is required." 
- ) @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE) @skip_if_rocm("ROCm enablement in progress") def test_fsdp2(self): @@ -587,9 +581,6 @@ def _test_fsdp2(self, args): v2 = v2.dequantize() self.assertEqual(v1, v2) - @pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_5, reason="PyTorch>=2.5 is required." - ) @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE) @skip_if_rocm("ROCm enablement in progress") def test_uneven_shard(self): diff --git a/test/test_ops.py b/test/test_ops.py index faec689a69..bc9fe0e4f9 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -28,7 +28,6 @@ ) from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24 from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7, compute_max_diff, ) @@ -281,25 +280,21 @@ def make_test_id(param): @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") -# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id) def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles): N, K = shape assert K % (inner_k_tiles * kTileSizeK) == 0 and N % kTileSizeN == 0 t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") - if TORCH_VERSION_AT_LEAST_2_5: - t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) + t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles) unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, inner_k_tiles) - if TORCH_VERSION_AT_LEAST_2_5: - unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8) + unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8) assert torch.equal(t, unpacked) # TODO: Fix "test_aot_dispatch_dynamic" test failure @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") -# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id) def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles): test_utils = [ @@ -308,13 +303,10 @@ def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles): "test_faketensor", ] - # TODO: Figure out why test fails unless torch >= 2.5 - if TORCH_VERSION_AT_LEAST_2_5: - test_utils.append("test_aot_dispatch_dynamic") + test_utils.append("test_aot_dispatch_dynamic") t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") - if TORCH_VERSION_AT_LEAST_2_5: - t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) + t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles) opcheck( @@ -345,7 +337,6 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16): @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") -# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize( "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str ) @@ -413,7 +404,6 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant( # This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") -# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize( "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str ) @@ 
-438,8 +428,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant( # Unpack and dequantize unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles) - if TORCH_VERSION_AT_LEAST_2_5: - unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8) + unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8) dq_ao = groupwise_affine_dequantize_tensor_from_qparams( unpacked, scales, zeros, n_bit=4, groupsize=group_size @@ -479,7 +468,6 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant( @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") -# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize( "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str ) @@ -488,8 +476,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size device = "cuda" q = torch.randint(0, 16, shape, dtype=torch.int, device=device) - if TORCH_VERSION_AT_LEAST_2_5: - q = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8) + q = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8) packed_w = torch._convert_weight_to_int4pack(q, inner_k_tiles) q_groups = k // group_size scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device) @@ -501,9 +488,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size "test_autograd_registration", "test_faketensor", ] - # TODO: Figure out why test fails unless torch >= 2.5 - if TORCH_VERSION_AT_LEAST_2_5: - test_utils.append("test_aot_dispatch_dynamic") + test_utils.append("test_aot_dispatch_dynamic") opcheck( torch.ops.torchao.dequantize_tensor_core_tiled_layout, (packed_w, scales_and_zeros, group_size, inner_k_tiles), @@ -766,9 +751,7 @@ def test_swizzle_mm(): "test_faketensor", ] - # TODO: Figure out why test fails unless torch >= 2.5 - if TORCH_VERSION_AT_LEAST_2_5: - test_utils.append("test_aot_dispatch_dynamic") + test_utils.append("test_aot_dispatch_dynamic") mat1 = torch.randint(0, 16, dtype=torch.float, size=(16, 32), device="cuda") mat2 = torch.randint(0, 16, dtype=torch.float, size=(32, 16), device="cuda") diff --git a/torchao/_executorch_ops.py b/torchao/_executorch_ops.py index 4b761ad725..5d680bcf82 100644 --- a/torchao/_executorch_ops.py +++ b/torchao/_executorch_ops.py @@ -12,37 +12,17 @@ def _quantized_decomposed_quantize_per_channel_group_wrapper(*args, **kwargs): """ Wrapper around torch.ops.quantized_decomposed.quantize_per_channel_group to mitigate availability issue until it can be supplanted by new quantize_affine function. - - torch.ops.quantized_decomposed.quantize_per_channel_group is only available - in PyTorch 2.3+ and recently changed signatures. """ - from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - - if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.quantize_per_channel_group( - *args, **kwargs - ) - raise ImportError( - "Need torch.ops.quantized_decomposed.quantize_per_channel_group, which is only available with PyTorch 2.3 or later." - ) + return torch.ops.quantized_decomposed.quantize_per_channel_group(*args, **kwargs) def _quantized_decomposed_choose_qparams_per_token_asymmetric_wrapper(*args, **kwargs): """ Wrapper around torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric to mitigate availability issue until it can be supplanted by new choose_qparams_affine function. 
- - torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric is only available - in PyTorch 2.3+ and recently changed signatures. """ - from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - - if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric( - *args, **kwargs - ) - raise ImportError( - "Need torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric, which is only available with PyTorch 2.3 or later." + return torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric( + *args, **kwargs ) @@ -50,50 +30,21 @@ def _quantized_decomposed_dequantize_per_channel_group_wrapper(*args, **kwargs): """ Wrapper around torch.ops.quantized_decomposed.dequantize_per_channel_group to mitigate availability issue until it can be supplanted by new choose_qparams_affine function. - - torch.ops.quantized_decomposed.dequantize_per_channel_group is only available - in PyTorch 2.3+ and recently changed signatures. """ - from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - - if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.dequantize_per_channel_group( - *args, **kwargs - ) - raise ImportError( - "Need torch.ops.quantized_decomposed.dequantize_per_channel_group, which is only available with PyTorch 2.3 or later." - ) + return torch.ops.quantized_decomposed.dequantize_per_channel_group(*args, **kwargs) def _quantized_decomposed_quantize_per_token_wrapper(*args, **kwargs): """ Wrapper around torch.ops.quantized_decomposed.quantize_per_token to mitigate availability issue until it can be supplanted by new choose_qparams_affine function. - - torch.ops.quantized_decomposed.quantize_per_token is only available - in PyTorch 2.3+ and recently changed signatures. """ - from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - - if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.quantize_per_token(*args, **kwargs) - raise ImportError( - "Need torch.ops.quantized_decomposed.quantize_per_token, which is only available with PyTorch 2.3 or later." - ) + return torch.ops.quantized_decomposed.quantize_per_token(*args, **kwargs) def _quantized_decomposed_dequantize_per_token_wrapper(*args, **kwargs): """ Wrapper around torch.ops.quantized_decomposed.dequantize_per_token to mitigate availability issue until it can be supplanted by new choose_qparams_affine function. - - torch.ops.quantized_decomposed.dequantize_per_token is only available - in PyTorch 2.3+ and recently changed signatures. """ - from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - - if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.dequantize_per_token(*args, **kwargs) - raise ImportError( - "Need torch.ops.quantized_decomposed.dequantize_per_token, which is only available with PyTorch 2.3 or later." 
- ) + return torch.ops.quantized_decomposed.dequantize_per_token(*args, **kwargs) diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index cc4e439a49..57b67ab16e 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -28,7 +28,6 @@ quantize_, uintx_weight_only, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass def run_evaluation( @@ -151,9 +150,6 @@ def run_evaluation( model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length) quantizer.quantize(model, *inputs) model = model.to(device) - else: - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(model) if "float8wo" in quantization: quantize_(model, float8_weight_only()) if "float8dq" in quantization: @@ -239,11 +235,6 @@ def run_evaluation( ) elif quantization.startswith("awq-uintx"): from torchao._models._eval import TransformerEvalWrapper - from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - - if not TORCH_VERSION_AT_LEAST_2_3: - print("Awq requires torch2.3+") - exit() from torchao.prototype.awq import ( AWQObservedLinear, awq_uintx, diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 8f02e83a99..0a18e41c39 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -20,11 +20,7 @@ write_json_result_ossci, ) from torchao.quantization.quant_primitives import MappingType -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, - get_model_size_in_bytes, -) +from torchao.utils import get_model_size_in_bytes torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False torch.backends.cuda.enable_cudnn_sdp(True) @@ -356,7 +352,6 @@ def ffn_or_attn_only(mod, fqn): uintx_weight_only, ) from torchao.quantization.granularity import PerRow, PerTensor - from torchao.utils import unwrap_tensor_subclass if "spinquant" in quantization: from torchao.prototype.spinquant import apply_spinquant @@ -505,11 +500,6 @@ def ffn_or_attn_only(mod, fqn): ) elif quantization.startswith("awq"): from torchao._models._eval import TransformerEvalWrapper - from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - - if not TORCH_VERSION_AT_LEAST_2_3: - print("Awq requires torch2.3+") - exit() from torchao.prototype.awq import ( AWQObservedLinear, awq_uintx, @@ -567,9 +557,6 @@ def ffn_or_attn_only(mod, fqn): group_size = int(_quant_args[2]) quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq)) elif "int8_dynamic_activation_intx_weight" in quantization: - assert TORCH_VERSION_AT_LEAST_2_6, ( - "int8_dynamic_activation_intx_weight requires torch2.6+" - ) assert precision == torch.float32, ( "int8_dynamic_activation_intx_weight requires using precision=torch.float32" ) @@ -829,10 +816,6 @@ def ffn_or_attn_only(mod, fqn): model, codebook_weight_only(dtype=torch.uint4, scale_block_size=64) ) - else: - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(model) - # standalone sparsity elif sparsity: from torchao.sparsity import semi_sparse_weight, sparsify_ diff --git a/torchao/_models/sam/eval_combo.py b/torchao/_models/sam/eval_combo.py index a0410fb734..97bb04ef8b 100644 --- a/torchao/_models/sam/eval_combo.py +++ b/torchao/_models/sam/eval_combo.py @@ -28,7 +28,6 @@ quantize_, ) from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_ -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass torch._dynamo.config.cache_size_limit = 50000 @@ -364,10 +363,6 @@ def mlp_only(mod, name): if compress == 
"int8_dynamic_quant": quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight()) - if not TORCH_VERSION_AT_LEAST_2_5: - predictor.model.image_encoder = unwrap_tensor_subclass( - predictor.model.image_encoder - ) elif compress == "sparse_mlp_only": def mlp_only(mod, name): @@ -395,10 +390,6 @@ def mlp_only(mod, name): mlp_lin1_only, ) sparsify_(predictor.model.image_encoder, semi_sparse_weight(), mlp_lin2_only) - if not TORCH_VERSION_AT_LEAST_2_5: - predictor.model.image_encoder = unwrap_tensor_subclass( - predictor.model.image_encoder - ) elif compress == "int4_weight_only_sparse": # apply sparsify first to set qparams apply_fake_sparsity(predictor.model.image_encoder, filter_fn=mlp_only) @@ -415,10 +406,6 @@ def mlp_only(mod, name): mlp_lin1_only, ) sparsify_(predictor.model.image_encoder, semi_sparse_weight(), mlp_lin2_only) - if not TORCH_VERSION_AT_LEAST_2_5: - predictor.model.image_encoder = unwrap_tensor_subclass( - predictor.model.image_encoder - ) elif compress is not None and "autoquant_v2" in compress: example_input = torch.randn( diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index f4386e43ad..63e0dcc562 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -35,10 +35,7 @@ dequantize_affine, quantize_affine, ) -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor logger = logging.getLogger(__name__) aten = torch.ops.aten @@ -613,6 +610,5 @@ def _apply_fn_to_data(self, fn): # experimental will be merged in to floatx to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([AffineQuantizedTensor]) +# Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([AffineQuantizedTensor]) diff --git a/torchao/dtypes/fbgemm_fp8_tensor.py b/torchao/dtypes/fbgemm_fp8_tensor.py index 85f83bcb50..6f007c9339 100644 --- a/torchao/dtypes/fbgemm_fp8_tensor.py +++ b/torchao/dtypes/fbgemm_fp8_tensor.py @@ -11,7 +11,6 @@ from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor, fill_defaults, ) @@ -265,6 +264,5 @@ def _(func, types, args, kwargs): to_fbgemm_fp8 = FbgemmFp8Tensor.from_float -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with FbgemmFp8Tensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([FbgemmFp8Tensor]) +# Allow a model with FbgemmFp8Tensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([FbgemmFp8Tensor]) diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 4764e8b69b..5542a9de58 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -15,8 +15,6 @@ from torch._prims_common import make_contiguous_strides_for from torch.distributed.device_mesh import DeviceMesh -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - aten = torch.ops.aten c10d_functional = torch.ops.c10d_functional @@ -1156,6 +1154,5 @@ def nf4_constructor( ) -if TORCH_VERSION_AT_LEAST_2_5: - torch.serialization.add_safe_globals([NF4Tensor]) - torch.serialization.add_safe_globals([NF4Tensor]) +torch.serialization.add_safe_globals([NF4Tensor]) 
+torch.serialization.add_safe_globals([NF4Tensor]) diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index da19bbc259..cd09eec452 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -21,11 +21,7 @@ ZeroPointDomain, _quantize_affine_tinygemm, ) -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, - fill_defaults, -) +from torchao.utils import fill_defaults aten = torch.ops.aten @@ -114,29 +110,13 @@ def from_plain( ): assert isinstance(_layout, Int4CPULayout) - if TORCH_VERSION_AT_LEAST_2_6: - assert int_data.dtype == torch.int32, ( - "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype" - ) - packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu( - int_data, - 1, # TODO:remove - ) - elif TORCH_VERSION_AT_LEAST_2_5: - int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) - assert int_data.dtype == torch.uint8, ( - "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" - ) - packed_weight = torch.ops.aten._convert_weight_to_int4pack( - int_data, _layout.inner_k_tiles - ) - else: - assert int_data.dtype == torch.int32, ( - "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" - ) - packed_weight = torch.ops.aten._convert_weight_to_int4pack( - int_data, _layout.inner_k_tiles - ) + assert int_data.dtype == torch.int32, ( + "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype" + ) + packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + int_data, + 1, # TODO:remove + ) scale = scale.reshape(int_data.shape[0], -1) zero_point = zero_point.reshape(int_data.shape[0], -1) @@ -284,8 +264,7 @@ def _is_float(dtype): def _linear_fp_act_uint4_weight_cpu_check(input_tensor, weight_tensor, bias): return ( - TORCH_VERSION_AT_LEAST_2_6 - and is_device(input_tensor.device.type, "cpu") + is_device(input_tensor.device.type, "cpu") and is_device(weight_tensor.device.type, "cpu") and (bias is None or is_device(bias.device.type, "cpu")) and not is_traceable_wrapper_subclass(input_tensor) @@ -300,9 +279,6 @@ def _linear_fp_act_uint4_weight_cpu_check(input_tensor, weight_tensor, bias): def _linear_fp_act_uint4_weight_cpu_impl(input_tensor, weight_tensor, bias): - assert TORCH_VERSION_AT_LEAST_2_6, ( - f"Requires PyTorch version at least 2.6, but got: {torch.__version__}" - ) assert is_device(input_tensor.device.type, "cpu"), ( f"For CPU device only but got: {input_tensor.device}" ) diff --git a/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py b/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py index dc7b073f32..fb75f3380b 100644 --- a/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py +++ b/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py @@ -19,7 +19,6 @@ _DTYPE_TO_QVALUE_BOUNDS, ZeroPointDomain, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) @@ -170,9 +169,6 @@ def from_plain( if layout.target != Target.ATEN: _check_torchao_ops_loaded() else: - assert TORCH_VERSION_AT_LEAST_2_6, ( - "aten target is requires torch version > 2.6.0" - ) assert torch.backends.kleidiai.is_available(), ( "ATEN target requires torch.backends.kleidiai.is_available()" ) @@ -378,7 +374,6 @@ def _impl_2d_aten(input_tensor, weight_tensor): ) if target == Target.ATEN: - assert 
TORCH_VERSION_AT_LEAST_2_6 == 1, "Target.ATEN requires torch >= 2.6.0" _impl_2d = _impl_2d_aten else: _impl_2d = _impl_2d_non_aten @@ -420,11 +415,6 @@ def make_packed_linear_int8_dynamic_activation_intx_weight_tensor( Constructs an AffineQuantizedTensor with PackedLinearInt8DynamicActivationIntxWeightLayout from plain data. """ - # TORCH_VERSION_AT_LEAST_2_6 is needed for torch.intx with x < 8 - assert TORCH_VERSION_AT_LEAST_2_6, ( - "Using PackedLinearInt8DynamicActivationIntxWeightLayout requires torch version > 2.6.0" - ) - layout = PackedLinearInt8DynamicActivationIntxWeightLayout(target=target) bit_width = _DTYPE_TO_BIT_WIDTH[data_dtype] diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index 591d9a9be1..992294b766 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -24,7 +24,6 @@ _quantize_affine_tinygemm, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, fill_defaults, find_multiple, ) @@ -274,14 +273,9 @@ def from_plain( ) def quant_2d(int_data_2d): - if TORCH_VERSION_AT_LEAST_2_5: - int_data_2d = (int_data_2d[::, ::2] << 4 | int_data_2d[::, 1::2]).to( - torch.uint8 - ) - else: - assert int_data_2d.dtype == torch.int32, ( - "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" - ) + int_data_2d = (int_data_2d[::, ::2] << 4 | int_data_2d[::, 1::2]).to( + torch.uint8 + ) return torch.ops.aten._convert_weight_to_int4pack( int_data_2d.contiguous(), _layout.inner_k_tiles ) diff --git a/torchao/dtypes/uintx/uintx_layout.py b/torchao/dtypes/uintx/uintx_layout.py index 96e5401de5..3180e9f2c9 100644 --- a/torchao/dtypes/uintx/uintx_layout.py +++ b/torchao/dtypes/uintx/uintx_layout.py @@ -14,7 +14,7 @@ from torchao.dtypes.utils import ( Layout, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TorchAOBaseTensor +from torchao.utils import TorchAOBaseTensor from .bitpacking import pack, unpack @@ -24,20 +24,17 @@ _DTYPE_TO_BIT_WIDTH = {} _BIT_WIDTH_TO_DTYPE = {} -if TORCH_VERSION_AT_LEAST_2_3: - _DTYPE_TO_BIT_WIDTH = { - torch.uint1: 1, - torch.uint2: 2, - torch.uint3: 3, - torch.uint4: 4, - torch.uint5: 5, - torch.uint6: 6, - torch.uint7: 7, - } - - _BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()} -else: - print("uintx feature requires torch 2.3+, please upgrade pytorch") +_DTYPE_TO_BIT_WIDTH = { + torch.uint1: 1, + torch.uint2: 2, + torch.uint3: 3, + torch.uint4: 4, + torch.uint5: 5, + torch.uint6: 6, + torch.uint7: 7, +} + +_BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()} class UintxTensor(TorchAOBaseTensor): diff --git a/torchao/float8/README.md b/torchao/float8/README.md index 6b16b241c8..9c25f51a9a 100644 --- a/torchao/float8/README.md +++ b/torchao/float8/README.md @@ -28,10 +28,6 @@ import time import torch import torch.nn as nn from torchao.float8 import convert_to_float8_training, Float8LinearConfig -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") # create model and sample input M, K, N = 4096, 8192, 4096 @@ -239,10 +235,6 @@ import torch.nn.functional as F from torchao.float8.float8_linear_utils import convert_to_float8_training from torchao.float8.float8_linear import Float8Linear from torchao.float8 import convert_to_float8_training -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - raise 
AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") # create model and sample input m = nn.Sequential( diff --git a/torchao/float8/__init__.py b/torchao/float8/__init__.py index 170d0ddd81..04589312a2 100644 --- a/torchao/float8/__init__.py +++ b/torchao/float8/__init__.py @@ -1,4 +1,7 @@ # Lets define a few top level things here +# Needed to load Float8TrainingTensor with weights_only = True +from torch.serialization import add_safe_globals + from torchao.float8.config import ( CastConfig, Float8GemmConfig, @@ -19,22 +22,17 @@ from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp from torchao.float8.inference import Float8MMConfig from torchao.float8.types import FP8Granularity -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if TORCH_VERSION_AT_LEAST_2_5: - # Needed to load Float8TrainingTensor with weights_only = True - from torch.serialization import add_safe_globals - add_safe_globals( - [ - Float8TrainingTensor, - ScaledMMConfig, - GemmInputRole, - LinearMMConfig, - Float8MMConfig, - ScalingGranularity, - ] - ) +add_safe_globals( + [ + Float8TrainingTensor, + ScaledMMConfig, + GemmInputRole, + LinearMMConfig, + Float8MMConfig, + ScalingGranularity, + ] +) __all__ = [ # configuration diff --git a/torchao/kernel/bsr_triton_ops.py b/torchao/kernel/bsr_triton_ops.py index 18cfba9ad9..4d80c4c577 100644 --- a/torchao/kernel/bsr_triton_ops.py +++ b/torchao/kernel/bsr_triton_ops.py @@ -9,15 +9,7 @@ from typing import Optional import torch - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 - -if TORCH_VERSION_AT_LEAST_2_4: - from torch._dynamo.utils import warn_once -else: - import warnings - - warn_once = warnings.warn +from torch._dynamo.utils import warn_once from torch.sparse._triton_ops import ( broadcast_batch_dims, launch_kernel, diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 2f064b3f2f..292b67380d 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -7,18 +7,16 @@ import os import torch +from torch._dynamo import is_compiling as dynamo_is_compiling +from torch._higher_order_ops.out_dtype import out_dtype -from torchao.utils import TORCH_VERSION_AT_LEAST_2_2, check_cpu_version +from torchao.utils import check_cpu_version logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) try: - # Only works for torch2.2 or newer. - if TORCH_VERSION_AT_LEAST_2_2: - from torchao.kernel import intmm_triton - else: - intmm_triton = None + from torchao.kernel import intmm_triton except ImportError: logger.warning( "Warning: Detected no triton, on systems without Triton certain kernels will not work" @@ -28,85 +26,63 @@ AUTOTUNER_ENABLE = bool(int(os.getenv("TORCHAO_AUTOTUNER_ENABLE", 0))) -# torch._int_mm doesn't exist before 2.2 -if TORCH_VERSION_AT_LEAST_2_2: - from torch._dynamo import is_compiling as dynamo_is_compiling - from torch._higher_order_ops.out_dtype import out_dtype - - def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: - """ - Performs a safe integer matrix multiplication, considering different paths for - torch.compile, cublas, and fallback cases. - - Args: - input (torch.Tensor): The input tensor of shape [i, j]. - mat2 (torch.Tensor): The matrix to multiply with, of shape [j, k]. - - Returns: - torch.Tensor: The result of the matrix multiplication. - - Raises: - AssertionError: If the tensors are not on the same device. 
- """ - # torch.compile path - if dynamo_is_compiling() or "FakeTensor" in input.__repr__(): - if input.device.type == "cpu": - # Matmul in int32 is slow on CPU and not supported well by Inductor cpp backend - return out_dtype( - torch.ops.aten.mm.default, torch.int32, input.float(), mat2.float() - ) - return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) - - # error checking for cublas path - assert mat2.device == input.device, ( - f"need both tensors to be on the same device but got {mat2.device} and {input.device}" - ) - device_cpu = "cpu" in [mat2.device.type, input.device.type] - # with input.shape = [i,j] and mat2.shape = [j,k] - j_is_nonzero_multiple_of_8 = (input.shape[1] % 8 == 0) and (input.shape[1] > 0) - k_is_nonzero_multiple_of_8 = (mat2.shape[1] % 8 == 0) and (mat2.shape[1] > 0) - bad_dimensions_for_cublas = not ( - j_is_nonzero_multiple_of_8 and k_is_nonzero_multiple_of_8 - ) - if device_cpu or bad_dimensions_for_cublas: - # fallback path - return torch.matmul( - input.cpu().to(torch.int32), mat2.cpu().to(torch.int32) - ).to(input.device.type) - - # cublas paths - if not mat2.is_contiguous(): # silently gives incorrect result without this - mat2 = mat2.contiguous() - if (not input.is_contiguous()) and ( - input.shape[0] % 8 != 0 - ): # gives cryptic error without this - input = ( - input.contiguous() - ) # (it seems the transpose makes cublas check the above j constraint on i) - try: - return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) - except Exception: - # fallback path, would run on H100 for float8 dtypes - # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn' - return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to( - torch.int32 +def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: + """ + Performs a safe integer matrix multiplication, considering different paths for + torch.compile, cublas, and fallback cases. + + Args: + input (torch.Tensor): The input tensor of shape [i, j]. + mat2 (torch.Tensor): The matrix to multiply with, of shape [j, k]. + + Returns: + torch.Tensor: The result of the matrix multiplication. + + Raises: + AssertionError: If the tensors are not on the same device. + """ + # torch.compile path + if dynamo_is_compiling() or "FakeTensor" in input.__repr__(): + if input.device.type == "cpu": + # Matmul in int32 is slow on CPU and not supported well by Inductor cpp backend + return out_dtype( + torch.ops.aten.mm.default, torch.int32, input.float(), mat2.float() ) -else: + return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) - def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: - """ - Performs a fallback integer matrix multiplication for torch versions before 2.2. + # error checking for cublas path + assert mat2.device == input.device, ( + f"need both tensors to be on the same device but got {mat2.device} and {input.device}" + ) + device_cpu = "cpu" in [mat2.device.type, input.device.type] + # with input.shape = [i,j] and mat2.shape = [j,k] + j_is_nonzero_multiple_of_8 = (input.shape[1] % 8 == 0) and (input.shape[1] > 0) + k_is_nonzero_multiple_of_8 = (mat2.shape[1] % 8 == 0) and (mat2.shape[1] > 0) + bad_dimensions_for_cublas = not ( + j_is_nonzero_multiple_of_8 and k_is_nonzero_multiple_of_8 + ) - Args: - input (torch.Tensor): The input tensor of shape [i, j]. - mat2 (torch.Tensor): The matrix to multiply with, of shape [j, k]. 
+ if device_cpu or bad_dimensions_for_cublas: + # fallback path + return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to( + input.device.type + ) - Returns: - torch.Tensor: The result of the matrix multiplication in int32. - """ - # We can improve on this by writing Triton code that works for older versions of Triton - # that ship with 2.1 or 2.0. + # cublas paths + if not mat2.is_contiguous(): # silently gives incorrect result without this + mat2 = mat2.contiguous() + if (not input.is_contiguous()) and ( + input.shape[0] % 8 != 0 + ): # gives cryptic error without this + input = ( + input.contiguous() + ) # (it seems the transpose makes cublas check the above j constraint on i) + try: + return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) + except Exception: + # fallback path, would run on H100 for float8 dtypes + # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn' return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to( torch.int32 ) diff --git a/torchao/kernel/intmm_triton.py b/torchao/kernel/intmm_triton.py index 1a516a7163..6f657cdfd8 100644 --- a/torchao/kernel/intmm_triton.py +++ b/torchao/kernel/intmm_triton.py @@ -10,7 +10,6 @@ import triton.language as tl from torchao.kernel.autotuner import get_best_config_fn -from torchao.utils import TORCH_VERSION_AFTER_2_5 # TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE=EXHAUSTIVE to enable exhaustive option int8_mm_kernel_configs = sum( @@ -38,16 +37,15 @@ [], ) -if TORCH_VERSION_AFTER_2_5: - if torch._inductor.config.max_autotune_gemm_search_space == "EXHAUSTIVE": - int8_mm_kernel_configs = [ - (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) - for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( - [16, 32, 64, 128, 256], repeat=3 - ) - for num_stages in [1, 2, 3, 4, 5, 6, 7, 8] - for num_warps in [2, 4, 8] - ] +if torch._inductor.config.max_autotune_gemm_search_space == "EXHAUSTIVE": + int8_mm_kernel_configs = [ + (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) + for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( + [16, 32, 64, 128, 256], repeat=3 + ) + for num_stages in [1, 2, 3, 4, 5, 6, 7, 8] + for num_warps in [2, 4, 8] + ] # Baseline configs from pytorch/pytorch diff --git a/torchao/ops.py b/torchao/ops.py index babe5506c0..4b643cae98 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -9,8 +9,6 @@ import torch from torch import Tensor -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 - lib = torch.library.Library("torchao", "FRAGMENT") lib.define( "quant_llm_linear(int EXPONENT, int MANTISSA, Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor" @@ -74,20 +72,14 @@ def register_custom_op(name): def decorator(func): - if TORCH_VERSION_AT_LEAST_2_4: - return torch.library.register_fake(f"{name}")(func) - else: - return torch.library.impl_abstract(f"{name}")(func) + return torch.library.register_fake(f"{name}")(func) return decorator def register_custom_op_impl(name): def decorator(func): - if TORCH_VERSION_AT_LEAST_2_4: - return torch.library.custom_op(f"{name}", mutates_args=())(func) - else: - return torch.library.impl(f"{name}", "CUDA")(func) + return torch.library.custom_op(f"{name}", mutates_args=())(func) return decorator diff --git a/torchao/optim/cpu_offload.py b/torchao/optim/cpu_offload.py index cca55749db..53acd4057f 100644 --- a/torchao/optim/cpu_offload.py +++ b/torchao/optim/cpu_offload.py @@ -8,7 +8,7 @@ import torch from torch.optim.optimizer import Optimizer, ParamsT -from torchao.utils import 
TORCH_VERSION_AT_LEAST_2_4, get_available_devices +from torchao.utils import get_available_devices # NOTE: We make this inherit Optimizer so it works with PyTorch's built-in LR @@ -36,11 +36,7 @@ def __init__( kwargs: other keyword arguments to be passed to the base optimizer e.g. `lr`, `weight_decay`. """ # default to fused CPU AdamW - if ( - optimizer_class is torch.optim.AdamW - and TORCH_VERSION_AT_LEAST_2_4 - and "fused" not in kwargs - ): + if optimizer_class is torch.optim.AdamW and "fused" not in kwargs: kwargs.update(fused=True) param_groups = list(params) diff --git a/torchao/optim/subclass_4bit.py b/torchao/optim/subclass_4bit.py index bc5fd33414..82bb6a3788 100644 --- a/torchao/optim/subclass_4bit.py +++ b/torchao/optim/subclass_4bit.py @@ -7,13 +7,10 @@ import torch from torch import Tensor +from torch.serialization import add_safe_globals from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor from .quant_utils import ( create_dynamic_map, @@ -113,25 +110,6 @@ def __repr__(self): ) -# in pre-2.4, calling .to(device, dtype) will not dispatch aten._to_copy.default when -# dtype is the same but device is different. thus, we must override .to() method instead. -if not TORCH_VERSION_AT_LEAST_2_4: - - def _to(self, *args, **kwargs): - # ignore other args/kwargs - device = kwargs.pop("device", None) - return OptimState4bit( - self.codes.to(device), - self.scale.to(device), - self.qmap.to(device), - self.signed, - self.shape, - ) - - OptimState4bit.to = _to - del _to # make sure to not re-use - - @OptimState4bit.implements(aten.copy_.default) def _(func, types, args, kwargs): dst = args[0] @@ -268,7 +246,4 @@ def _(func, types, args, kwargs): return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape) -if TORCH_VERSION_AT_LEAST_2_5: - from torch.serialization import add_safe_globals - - add_safe_globals([OptimState4bit]) +add_safe_globals([OptimState4bit]) diff --git a/torchao/optim/subclass_8bit.py b/torchao/optim/subclass_8bit.py index d3f7634526..bbc6cfa958 100644 --- a/torchao/optim/subclass_8bit.py +++ b/torchao/optim/subclass_8bit.py @@ -7,13 +7,10 @@ import torch from torch import Tensor +from torch.serialization import add_safe_globals from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor from .quant_utils import ( create_dynamic_map, @@ -101,24 +98,6 @@ def __repr__(self): ) -# in pre-2.4, calling .to(device, dtype) will not dispatch aten._to_copy.default when -# dtype is the same but device is different. thus, we must override .to() method instead. 
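Above, `torchao/optim/cpu_offload.py` now passes `fused=True` to `torch.optim.AdamW` unconditionally instead of checking `TORCH_VERSION_AT_LEAST_2_4`. A minimal sketch of how that optimizer is used, assuming the class defined in that file is exported as `CPUOffloadOptimizer` from `torchao.optim` (the class name and import path are assumptions, they are not visible in this diff); it needs CUDA parameters since optimizer state is kept on CPU:

```python
import torch
import torch.nn as nn

# Assumed export; adjust if the class lives under a different name or path.
from torchao.optim import CPUOffloadOptimizer

model = nn.Linear(512, 512, device="cuda")
# lr/weight_decay are forwarded to the base optimizer; with this PR the
# fused CPU AdamW kernel is the default (fused=True added to kwargs).
opt = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, lr=1e-4)

x = torch.randn(8, 512, device="cuda")
model(x).sum().backward()
opt.step()
opt.zero_grad()
```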
-if not TORCH_VERSION_AT_LEAST_2_4: - - def _to(self, *args, **kwargs): - # ignore other args/kwargs - device = kwargs.pop("device", None) - return OptimState8bit( - self.codes.to(device), - self.scale.to(device), - self.qmap.to(device), - self.signed, - ) - - OptimState8bit.to = _to - del _to # make sure to not re-use - - @OptimState8bit.implements(aten.copy_.default) def _(func, types, args, kwargs): dst = args[0] @@ -237,7 +216,4 @@ def _(func, types, args, kwargs): ) -if TORCH_VERSION_AT_LEAST_2_5: - from torch.serialization import add_safe_globals - - add_safe_globals([OptimState8bit]) +add_safe_globals([OptimState8bit]) diff --git a/torchao/optim/subclass_fp8.py b/torchao/optim/subclass_fp8.py index 1ae670dd6d..e898932138 100644 --- a/torchao/optim/subclass_fp8.py +++ b/torchao/optim/subclass_fp8.py @@ -7,9 +7,10 @@ import torch from torch import Tensor +from torch.serialization import add_safe_globals from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor +from torchao.utils import TorchAOBaseTensor aten = torch.ops.aten c10d_functional = torch.ops.c10d_functional @@ -192,7 +193,4 @@ def _(func, types, args, kwargs): ) -if TORCH_VERSION_AT_LEAST_2_5: - from torch.serialization import add_safe_globals - - add_safe_globals([OptimStateFp8]) +add_safe_globals([OptimStateFp8]) diff --git a/torchao/prototype/autoround/eval_autoround.py b/torchao/prototype/autoround/eval_autoround.py index 16c1736843..04864e546a 100644 --- a/torchao/prototype/autoround/eval_autoround.py +++ b/torchao/prototype/autoround/eval_autoround.py @@ -12,7 +12,6 @@ import torchao import torchao.prototype.autoround.utils as ar_utils import torchao.quantization -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 logger = logging.getLogger(__name__) @@ -165,7 +164,7 @@ def main(args): bench_accuracy(model, tokenizer, tasks=args.tasks, msg=msg) -if __name__ == "__main__" and TORCH_VERSION_AT_LEAST_2_5 and torch.cuda.is_available(): +if __name__ == "__main__" and torch.cuda.is_available(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter ) diff --git a/torchao/prototype/float8nocompile/examples/example.py b/torchao/prototype/float8nocompile/examples/example.py index 97d42eee90..1351e2c938 100644 --- a/torchao/prototype/float8nocompile/examples/example.py +++ b/torchao/prototype/float8nocompile/examples/example.py @@ -9,10 +9,6 @@ from torchao.prototype.float8nocompile.float8nocompile_linear_utils import ( convert_to_float8_nocompile_training, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") # create model and sample input m = ( diff --git a/torchao/prototype/float8nocompile/test/fsdp_test.py b/torchao/prototype/float8nocompile/test/fsdp_test.py index 4e73fb9b97..375e48311d 100644 --- a/torchao/prototype/float8nocompile/test/fsdp_test.py +++ b/torchao/prototype/float8nocompile/test/fsdp_test.py @@ -22,10 +22,6 @@ from torchao.prototype.float8nocompile.float8nocompile_linear_utils import ( convert_to_float8_nocompile_training, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") class TestModel(nn.Module): diff --git a/torchao/prototype/float8nocompile/test/train_test.py b/torchao/prototype/float8nocompile/test/train_test.py index 3f2ee47cd7..aceca5b400 
100644 --- a/torchao/prototype/float8nocompile/test/train_test.py +++ b/torchao/prototype/float8nocompile/test/train_test.py @@ -11,10 +11,6 @@ from torchao.prototype.float8nocompile.float8nocompile_linear_utils import ( convert_to_float8_nocompile_training, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") class TestModel(nn.Module): diff --git a/torchao/prototype/hqq/hqq_tinygemm_linear.py b/torchao/prototype/hqq/hqq_tinygemm_linear.py index f15c9a8104..8f049b431b 100644 --- a/torchao/prototype/hqq/hqq_tinygemm_linear.py +++ b/torchao/prototype/hqq/hqq_tinygemm_linear.py @@ -17,7 +17,7 @@ from torch import Tensor, nn from torchao.dtypes.utils import is_device -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, check_cpu_version +from torchao.utils import check_cpu_version class HQQLinearTorchWeightOnlyInt4(torch.nn.Module): @@ -209,9 +209,8 @@ def hqq_quants_to_torch_quants( .reshape(shape) .contiguous() ) - if TORCH_VERSION_AT_LEAST_2_5: - if not is_device(W_q.device.type, "cpu"): - W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8) + if not is_device(W_q.device.type, "cpu"): + W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8) # group_dequantize_tensor_from_qparams # W_r = W_q*scales + min_val diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py index 133cedee74..96c4c6c73b 100644 --- a/torchao/prototype/mx_formats/inference_workflow.py +++ b/torchao/prototype/mx_formats/inference_workflow.py @@ -25,7 +25,6 @@ register_quantize_module_handler, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_100, ) @@ -213,16 +212,15 @@ def _nvfp4_inference_linear_transform( return module -if TORCH_VERSION_AT_LEAST_2_5: - torch.serialization.add_safe_globals( - [ - MXTensor, - NVFP4Tensor, - NVFP4MMConfig, - MXGemmKernelChoice, - _input_activation_quant_func_mxfp, - ] - ) +torch.serialization.add_safe_globals( + [ + MXTensor, + NVFP4Tensor, + NVFP4MMConfig, + MXGemmKernelChoice, + _input_activation_quant_func_mxfp, + ] +) import torch.nn as nn diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index f506681223..cabb61276a 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -17,7 +17,6 @@ _floatx_unpacked_to_f32, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_7, is_sm_at_least_100, ) @@ -25,7 +24,7 @@ # TODO(future): if needed, make the below work on previous PyTorch versions, # just need to hunt down the previous location of `libdevice`. An assert # at the callsite prevents usage of this on unsupported versions. 
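The hunks below in `kernels.py` (and the earlier ones in `torchao/ops.py` and `hadamard_utils.py`) now register custom ops unconditionally through `torch.library.custom_op` plus `register_fake`, since both are available on every supported torch version. For reference, a minimal sketch of that registration pattern with a hypothetical op name (the real ops in this file are `ao::triton_f6_*_to_scaled_bf16` and `ao::pack_uint6`):

```python
import torch

@torch.library.custom_op("ao_example::scale_to_bf16", mutates_args=())
def scale_to_bf16(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Real implementation: runs on concrete tensors.
    return (x * scale).to(torch.bfloat16)

@scale_to_bf16.register_fake
def _(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Fake (meta) implementation: only describes the output shape/dtype so the
    # op can be traced by torch.compile / torch.export without running kernels.
    return torch.empty_like(x, dtype=torch.bfloat16)

y = scale_to_bf16(torch.randn(4, 4), 2.0)
```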
-if TORCH_VERSION_AT_LEAST_2_4 and has_triton(): +if has_triton(): from torch._inductor.runtime.triton_helpers import libdevice from torchao.prototype.mx_formats.constants import ( @@ -752,7 +751,6 @@ def triton_f4_to_scaled_bf16( Output: a tensor of bfloat16 values, multiplied by the encoded scale """ s_e8m0 = s_e8m0.view(torch.uint8) - assert TORCH_VERSION_AT_LEAST_2_4, "unsupported" new_shape = (*x.shape[:-1], x.shape[-1] * 2) output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) assert x.is_contiguous() @@ -855,119 +853,104 @@ def triton_f6_e3m2_to_bf16(x: torch.Tensor) -> torch.Tensor: return output -if TORCH_VERSION_AT_LEAST_2_4: - - @torch.library.custom_op("ao::triton_f6_e2m3_to_scaled_bf16", mutates_args=()) - def triton_f6_e2m3_to_scaled_bf16( - x: torch.Tensor, - s_e8m0: torch.Tensor, - mx_block_size: int, - ) -> torch.Tensor: - """ - Input: a tensor of packed fp6 values, and a scale in e8m0 format. The block - size is currently assumed to be 32. - Output: a tensor of bfloat16 values, multiplied by the encoded scale - """ - s_e8m0 = s_e8m0.view(torch.uint8) +@torch.library.custom_op("ao::triton_f6_e2m3_to_scaled_bf16", mutates_args=()) +def triton_f6_e2m3_to_scaled_bf16( + x: torch.Tensor, + s_e8m0: torch.Tensor, + mx_block_size: int, +) -> torch.Tensor: + """ + Input: a tensor of packed fp6 values, and a scale in e8m0 format. The block + size is currently assumed to be 32. + Output: a tensor of bfloat16 values, multiplied by the encoded scale + """ + s_e8m0 = s_e8m0.view(torch.uint8) - packed_mx_block_size = 3 * mx_block_size // 4 + packed_mx_block_size = 3 * mx_block_size // 4 - x = x.view(-1, packed_mx_block_size) - new_shape = (x.numel() // packed_mx_block_size, mx_block_size) + x = x.view(-1, packed_mx_block_size) + new_shape = (x.numel() // packed_mx_block_size, mx_block_size) - output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) + output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) - assert x.is_contiguous() - assert x.is_cuda and output.is_cuda + assert x.is_contiguous() + assert x.is_cuda and output.is_cuda - n_mx_blocks = x.shape[0] - grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) - triton_f6_to_scaled_bf16_kernel[grid]( - x, - s_e8m0, - output, - n_mx_blocks, - mx_block_size, - packed_mx_block_size, - sign_mask_f6=SIGN_MASK_F6_E2M3, - mbits_f6=MBITS_F6_E2M3, - f6_exp_bias=F6_E2M3_EXP_BIAS, - mbits_f32=MBITS_F32, - f32_exp_bias=F32_EXP_BIAS, - e8m0_exponent_bias=E8M0_EXPONENT_BIAS, - e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL, - ) - return output + n_mx_blocks = x.shape[0] + grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) + triton_f6_to_scaled_bf16_kernel[grid]( + x, + s_e8m0, + output, + n_mx_blocks, + mx_block_size, + packed_mx_block_size, + sign_mask_f6=SIGN_MASK_F6_E2M3, + mbits_f6=MBITS_F6_E2M3, + f6_exp_bias=F6_E2M3_EXP_BIAS, + mbits_f32=MBITS_F32, + f32_exp_bias=F32_EXP_BIAS, + e8m0_exponent_bias=E8M0_EXPONENT_BIAS, + e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL, + ) + return output - @torch.library.custom_op("ao::triton_f6_e3m2_to_scaled_bf16", mutates_args=()) - def triton_f6_e3m2_to_scaled_bf16( - x: torch.Tensor, - s_e8m0: torch.Tensor, - mx_block_size: int, - ) -> torch.Tensor: - """ - Input: a tensor of packed fp6 values, and a scale in e8m0 format. The block - size is currently assumed to be 32. 
- Output: a tensor of bfloat16 values, multiplied by the encoded scale - """ - s_e8m0 = s_e8m0.view(torch.uint8) - packed_mx_block_size = 3 * mx_block_size // 4 +@torch.library.custom_op("ao::triton_f6_e3m2_to_scaled_bf16", mutates_args=()) +def triton_f6_e3m2_to_scaled_bf16( + x: torch.Tensor, + s_e8m0: torch.Tensor, + mx_block_size: int, +) -> torch.Tensor: + """ + Input: a tensor of packed fp6 values, and a scale in e8m0 format. The block + size is currently assumed to be 32. + Output: a tensor of bfloat16 values, multiplied by the encoded scale + """ + s_e8m0 = s_e8m0.view(torch.uint8) - x = x.view(-1, packed_mx_block_size) - new_shape = (x.numel() // packed_mx_block_size, mx_block_size) + packed_mx_block_size = 3 * mx_block_size // 4 - output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) + x = x.view(-1, packed_mx_block_size) + new_shape = (x.numel() // packed_mx_block_size, mx_block_size) - assert x.is_contiguous() - assert x.is_cuda and output.is_cuda + output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) - n_mx_blocks = x.numel() // packed_mx_block_size - grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) - triton_f6_to_scaled_bf16_kernel[grid]( - x, - s_e8m0, - output, - n_mx_blocks, - mx_block_size, - packed_mx_block_size, - sign_mask_f6=SIGN_MASK_F6_E3M2, - mbits_f6=MBITS_F6_E3M2, - f6_exp_bias=F6_E3M2_EXP_BIAS, - mbits_f32=MBITS_F32, - f32_exp_bias=F32_EXP_BIAS, - e8m0_exponent_bias=E8M0_EXPONENT_BIAS, - e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL, - ) - return output + assert x.is_contiguous() + assert x.is_cuda and output.is_cuda - @triton_f6_e3m2_to_scaled_bf16.register_fake - def _(x, s_e8m0, mx_block_size): - _padded_mx_block_size = 3 * mx_block_size // 4 - out_shape = (x.numel() // _padded_mx_block_size, mx_block_size) - return torch.empty(*out_shape, device=x.device, dtype=torch.bfloat16) + n_mx_blocks = x.numel() // packed_mx_block_size + grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) + triton_f6_to_scaled_bf16_kernel[grid]( + x, + s_e8m0, + output, + n_mx_blocks, + mx_block_size, + packed_mx_block_size, + sign_mask_f6=SIGN_MASK_F6_E3M2, + mbits_f6=MBITS_F6_E3M2, + f6_exp_bias=F6_E3M2_EXP_BIAS, + mbits_f32=MBITS_F32, + f32_exp_bias=F32_EXP_BIAS, + e8m0_exponent_bias=E8M0_EXPONENT_BIAS, + e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL, + ) + return output - @triton_f6_e2m3_to_scaled_bf16.register_fake - def _(x, s_e8m0, mx_block_size): - _padded_mx_block_size = 3 * mx_block_size // 4 - out_shape = (x.numel() // _padded_mx_block_size, mx_block_size) - return torch.empty(*out_shape, device=x.device, dtype=torch.bfloat16) -else: +@triton_f6_e3m2_to_scaled_bf16.register_fake +def _(x, s_e8m0, mx_block_size): + _padded_mx_block_size = 3 * mx_block_size // 4 + out_shape = (x.numel() // _padded_mx_block_size, mx_block_size) + return torch.empty(*out_shape, device=x.device, dtype=torch.bfloat16) - def triton_f6_e2m3_to_scaled_bf16( - x: torch.Tensor, - s_e8m0: torch.Tensor, - mx_block_size: int, - ) -> torch.Tensor: - raise AssertionError("unsupported without torch >= 2.4") - def triton_f6_e3m2_to_scaled_bf16( - x: torch.Tensor, - s_e8m0: torch.Tensor, - mx_block_size: int, - ) -> torch.Tensor: - raise AssertionError("unsupported without torch >= 2.4") +@triton_f6_e2m3_to_scaled_bf16.register_fake +def _(x, s_e8m0, mx_block_size): + _padded_mx_block_size = 3 * mx_block_size // 4 + out_shape = (x.numel() // _padded_mx_block_size, mx_block_size) + return torch.empty(*out_shape, device=x.device, 
dtype=torch.bfloat16) # pack/unpack code copy-pasted from @@ -1049,48 +1032,42 @@ def pack_uint6_pytorch(uint8_data: torch.Tensor) -> torch.Tensor: ).view(packed_shape) -if TORCH_VERSION_AT_LEAST_2_4: - - @torch.library.custom_op("ao::pack_uint6", mutates_args=()) - def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor: - # ensure input data is contiguous before passing to kernel - assert uint8_data.is_contiguous() +@torch.library.custom_op("ao::pack_uint6", mutates_args=()) +def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor: + # ensure input data is contiguous before passing to kernel + assert uint8_data.is_contiguous() - # tensor should already be of shape [..., mx_block_size] - mx_block_size = uint8_data.shape[-1] - assert mx_block_size % 4 == 0 + # tensor should already be of shape [..., mx_block_size] + mx_block_size = uint8_data.shape[-1] + assert mx_block_size % 4 == 0 - # effective mx block size since we're packing 2 fp4 into 1 uint8 - packed_mx_block_size = 3 * mx_block_size // 4 - packed_shape = [*uint8_data.shape[:-1], packed_mx_block_size] - n_mx_blocks = uint8_data.numel() // mx_block_size + # effective mx block size since we're packing 2 fp4 into 1 uint8 + packed_mx_block_size = 3 * mx_block_size // 4 + packed_shape = [*uint8_data.shape[:-1], packed_mx_block_size] + n_mx_blocks = uint8_data.numel() // mx_block_size - grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) + grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) - # contiguous uint8 container in which we can store the unpacked tensor - packed_uint8_data = torch.empty( - packed_shape, dtype=torch.uint8, device=uint8_data.device - ) + # contiguous uint8 container in which we can store the unpacked tensor + packed_uint8_data = torch.empty( + packed_shape, dtype=torch.uint8, device=uint8_data.device + ) - triton_pack_uint6_kernel[grid]( - uint8_data, - packed_uint8_data, - n_mx_blocks, - MX_BLOCK_SIZE=mx_block_size, - PACKED_MX_BLOCK_SIZE=packed_mx_block_size, - ) + triton_pack_uint6_kernel[grid]( + uint8_data, + packed_uint8_data, + n_mx_blocks, + MX_BLOCK_SIZE=mx_block_size, + PACKED_MX_BLOCK_SIZE=packed_mx_block_size, + ) - return packed_uint8_data + return packed_uint8_data - @pack_uint6.register_fake - def _(uint8_data): - out_shape = (*uint8_data.shape[:-1], 3 * uint8_data.shape[-1] // 4) - return torch.empty(*out_shape, device=uint8_data.device, dtype=torch.uint8) -else: - def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor: - # Dummy placeholder op for torch < 2.4 - raise AssertionError("fp6 packing unsupported without torch >= 2.4") +@pack_uint6.register_fake +def _(uint8_data): + out_shape = (*uint8_data.shape[:-1], 3 * uint8_data.shape[-1] // 4) + return torch.empty(*out_shape, device=uint8_data.device, dtype=torch.uint8) if TORCH_VERSION_AT_LEAST_2_7 and has_triton(): diff --git a/torchao/prototype/quantization/autoquant_v2.py b/torchao/prototype/quantization/autoquant_v2.py index 9ddfddda08..1240bbacd0 100644 --- a/torchao/prototype/quantization/autoquant_v2.py +++ b/torchao/prototype/quantization/autoquant_v2.py @@ -47,8 +47,6 @@ ) from torchao.quantization.utils import _quantize_activation_per_token_absmax from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor, is_sm_at_least_89, is_sm_at_least_90, @@ -469,6 +467,8 @@ def do_autoquant_bench(op, *args, **kwargs): """ runs benchmark op(*args, **kwargs) avoiding torch.compile overhead """ + from torch._inductor.runtime.benchmarking import benchmarker + rep 
= kwargs.pop("rep", 100) warmup = kwargs.pop("warmup", 25) with torch.no_grad(): @@ -483,24 +483,9 @@ def do_autoquant_bench(op, *args, **kwargs): graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): op(*args, **kwargs) - if TORCH_VERSION_AT_LEAST_2_5: - from torch._inductor.runtime.benchmarking import benchmarker - - res = benchmarker.benchmark_gpu( - lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" - ) - elif TORCH_VERSION_AT_LEAST_2_3: - from torch._inductor.runtime.runtime_utils import do_bench_gpu - - res = do_bench_gpu( - lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" - ) - else: - from torch._inductor.utils import do_bench - - res = do_bench( - lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" - ) + res = benchmarker.benchmark_gpu( + lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" + ) return res diff --git a/torchao/prototype/quantization/dynamic_activation_lut/int8_dynamic_activation_lut_tensor.py b/torchao/prototype/quantization/dynamic_activation_lut/int8_dynamic_activation_lut_tensor.py index c2e995e942..a15ea944fd 100644 --- a/torchao/prototype/quantization/dynamic_activation_lut/int8_dynamic_activation_lut_tensor.py +++ b/torchao/prototype/quantization/dynamic_activation_lut/int8_dynamic_activation_lut_tensor.py @@ -9,10 +9,7 @@ from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.quantization.quant_primitives import _DTYPE_TO_QVALUE_BOUNDS -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor aten = torch.ops.aten @@ -231,6 +228,5 @@ def _(func, types, args, kwargs): ) -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with Int8DynamicActivationLutTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([Int8DynamicActivationLutTensor]) +# Allow a model with Int8DynamicActivationLutTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Int8DynamicActivationLutTensor]) diff --git a/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py index c1272fceb6..f26083b90d 100644 --- a/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py +++ b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py @@ -14,10 +14,7 @@ _dequantize_gguf, _quantize_gguf, ) -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor _QK_K = 256 aten = torch.ops.aten @@ -267,6 +264,5 @@ def _(func, types, args, kwargs): return torch.nn.functional.linear(input_tensor, weight_tensor, bias) -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with GGUFQuantizedTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([GGUFQuantizedTensor]) +# Allow a model with GGUFQuantizedTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([GGUFQuantizedTensor]) diff --git a/torchao/prototype/spinquant/hadamard_utils.py b/torchao/prototype/spinquant/hadamard_utils.py index 515a38ad83..0b276a0d03 100644 --- a/torchao/prototype/spinquant/hadamard_utils.py +++ b/torchao/prototype/spinquant/hadamard_utils.py @@ -11,7 +11,6 @@ import torch -from torchao.ops import lib from torchao.prototype.spinquant._hadamard_matrices import ( get_had12, get_had20, @@ -26,7 +25,6 @@ get_had156, get_had172, ) -from torchao.utils import 
TORCH_VERSION_AT_LEAST_2_4 try: from fast_hadamard_transform import hadamard_transform as _fast_hadamard_transform @@ -50,21 +48,14 @@ def matmul_hadU(X, hadK, K): def register_custom_op_impl(name): def decorator(func): - if TORCH_VERSION_AT_LEAST_2_4: - return torch.library.custom_op(f"{name}", mutates_args=())(func) - else: - lib.define("hadamard_transform(Tensor x, float scale = 0.0) -> Tensor") - return torch.library.impl(f"{name}", "cuda")(func) + return torch.library.custom_op(f"{name}", mutates_args=())(func) return decorator def register_custom_op_abstract(name): def decorator(func): - if TORCH_VERSION_AT_LEAST_2_4: - return torch.library.register_fake(f"{name}")(func) - else: - return torch.library.impl_abstract(f"{name}")(func) + return torch.library.register_fake(f"{name}")(func) return decorator diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 47ecb9aabe..fa0293bf82 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -304,12 +304,6 @@ quantize_(m, Int4WeightOnlyConfig(group_size=group_size)) ## If different zero_point_domain needed # quantize_(m, Int4WeightOnlyConfig(group_size=group_size, zero_point_domain=ZeroPointDomain.FLOAT)) -# temporary workaround for tensor subclass + torch.compile -# NOTE: this is only need for torch version < 2.5+ -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 -from torchao.utils import unwrap_tensor_subclass -if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(m) # compile the model to improve performance m = torch.compile(m, mode='max-autotune') diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index cf3fbad6ad..5745f00e99 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -31,8 +31,6 @@ compute_error, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor, is_sm_at_least_89, is_sm_at_least_90, @@ -329,6 +327,8 @@ def do_autoquant_bench(op, *args, **kwargs): """ runs benchmark op(*args, **kwargs) avoiding torch.compile overhead """ + from torch._inductor.runtime.benchmarking import benchmarker + rep = kwargs.pop("rep", 100) warmup = kwargs.pop("warmup", 25) with torch.no_grad(): @@ -343,24 +343,9 @@ def do_autoquant_bench(op, *args, **kwargs): graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): op(*args, **kwargs) - if TORCH_VERSION_AT_LEAST_2_5: - from torch._inductor.runtime.benchmarking import benchmarker - - res = benchmarker.benchmark_gpu( - lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" - ) - elif TORCH_VERSION_AT_LEAST_2_3: - from torch._inductor.runtime.runtime_utils import do_bench_gpu - - res = do_bench_gpu( - lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" - ) - else: - from torch._inductor.utils import do_bench - - res = do_bench( - lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" - ) + res = benchmarker.benchmark_gpu( + lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" + ) return res @@ -1346,12 +1331,11 @@ def finalize_autoquant(): return model -if TORCH_VERSION_AT_LEAST_2_5: - torch.serialization.add_safe_globals(ALL_AUTOQUANT_CLASS_LIST) - torch.serialization.add_safe_globals( - [ - _to_float16, - _to_bfloat16, - _identity, - ] - ) +torch.serialization.add_safe_globals(ALL_AUTOQUANT_CLASS_LIST) +torch.serialization.add_safe_globals( + [ + _to_float16, + _to_bfloat16, + _identity, + ] +) diff --git 
a/torchao/quantization/linear_activation_quantized_tensor.py b/torchao/quantization/linear_activation_quantized_tensor.py index 658b172994..cbeb9cdb6f 100644 --- a/torchao/quantization/linear_activation_quantized_tensor.py +++ b/torchao/quantization/linear_activation_quantized_tensor.py @@ -8,10 +8,7 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor __all__ = [ "LinearActivationQuantizedTensor", @@ -290,6 +287,5 @@ def _(func, types, args, kwargs): to_linear_activation_quantized = LinearActivationQuantizedTensor.from_float # Converts a float tensor to LinearActivationQuantizedTensor for dynamic activation quantization -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([LinearActivationQuantizedTensor]) +# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([LinearActivationQuantizedTensor]) diff --git a/torchao/quantization/linear_activation_scale.py b/torchao/quantization/linear_activation_scale.py index 005bc8d32d..500228cf3c 100644 --- a/torchao/quantization/linear_activation_scale.py +++ b/torchao/quantization/linear_activation_scale.py @@ -6,10 +6,7 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor __all__ = [ "WeightTensorWithLinearActivationScaleMetadata", @@ -119,8 +116,5 @@ def _(func, types, args, kwargs): WeightTensorWithLinearActivationScaleMetadata.from_float ) -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals( - [WeightTensorWithLinearActivationScaleMetadata] - ) +# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([WeightTensorWithLinearActivationScaleMetadata]) diff --git a/torchao/quantization/linear_activation_weight_observed_tensor.py b/torchao/quantization/linear_activation_weight_observed_tensor.py index 029b89e54b..d17bc382db 100644 --- a/torchao/quantization/linear_activation_weight_observed_tensor.py +++ b/torchao/quantization/linear_activation_weight_observed_tensor.py @@ -9,10 +9,7 @@ from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.quantization.observer import AffineQuantizedObserverBase -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor __all__ = [ "LinearActivationWeightObservedTensor", @@ -153,6 +150,5 @@ def _(func, types, args, kwargs): ) -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([LinearActivationWeightObservedTensor]) +# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([LinearActivationWeightObservedTensor]) diff --git a/torchao/quantization/linear_quant_modules.py b/torchao/quantization/linear_quant_modules.py index 73e95036f1..de6755a55d 100644 --- a/torchao/quantization/linear_quant_modules.py +++ 
b/torchao/quantization/linear_quant_modules.py @@ -16,10 +16,7 @@ import torch.nn.functional as F from torchao.dtypes.utils import is_device -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_6, - find_multiple, -) +from torchao.utils import find_multiple from .quant_primitives import ( MappingType, @@ -60,7 +57,7 @@ def linear_forward_int4( ): origin_x_size = x.size() x = x.reshape(-1, origin_x_size[-1]) - if is_device(x.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + if is_device(x.device.type, "cpu"): c = torch.ops.aten._weight_int4pack_mm_for_cpu( x.to(precision), weight_int4pack, @@ -299,10 +296,7 @@ def _create_quantized_state_dict( self.precision, # dtype for scales_and_zeros ) # TODO: just get the device from mod.weight.device? - if ( - is_device(w_int4x8.device.type, "cpu") - and TORCH_VERSION_AT_LEAST_2_6 - ): + if is_device(w_int4x8.device.type, "cpu"): weight_int4pack = ( torch.ops.aten._convert_weight_to_int4pack_for_cpu( w_int4x8.to(self.device), self.inner_k_tiles diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py index 6084da6e8d..6d928a4477 100644 --- a/torchao/quantization/observer.py +++ b/torchao/quantization/observer.py @@ -11,7 +11,6 @@ import torch from torchao.quantization.quant_primitives import _fake_quantize_affine -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 from .granularity import ( Granularity, @@ -373,6 +372,5 @@ def calculate_qparams(self): ) -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([PerRow, PerTensor]) +# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([PerRow, PerTensor]) diff --git a/torchao/quantization/pt2e/_numeric_debugger.py b/torchao/quantization/pt2e/_numeric_debugger.py index 0346981391..5211e0f340 100644 --- a/torchao/quantization/pt2e/_numeric_debugger.py +++ b/torchao/quantization/pt2e/_numeric_debugger.py @@ -14,13 +14,9 @@ from torch.ao.ns.fx.utils import compute_sqnr from torch.export import ExportedProgram from torch.fx import GraphModule, Node +from torch.fx.traceback import NodeSource from torch.nn import functional as F -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 - -if TORCH_VERSION_AT_LEAST_2_6: - from torch.fx.traceback import NodeSource - from .graph_utils import bfs_trace_with_node_process NUMERIC_DEBUG_HANDLE_KEY = "numeric_debug_handle" @@ -262,12 +258,6 @@ def prepare_for_propagation_comparison(model: GraphModule) -> GraphModule: Returns: a model with output loggers for all unlifted nodes """ - if not TORCH_VERSION_AT_LEAST_2_6: - log.warning( - "prepare_for_propagation_comparison is only supported for PyTorch 2.6+" - ) - return model - # don't change the original model model = copy.deepcopy(model) for n in model.graph.nodes: diff --git a/torchao/quantization/pt2e/constant_fold.py b/torchao/quantization/pt2e/constant_fold.py index 27f82e6757..365eb0a77a 100644 --- a/torchao/quantization/pt2e/constant_fold.py +++ b/torchao/quantization/pt2e/constant_fold.py @@ -12,8 +12,6 @@ from torch._inductor.freezing_utils import maybe_set_is_frozen_param from torch.utils._ordered_set import OrderedSet -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - aten = torch.ops.aten # We would like to split modules into two subgraphs for runtime weight updates to work correctly. 
@@ -162,13 +160,9 @@ def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool: torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, torch.ops.quantized_decomposed.convert_element_type.no_fuse, + torch.ops.torchao.dequantize_affine, ] - if TORCH_VERSION_AT_LEAST_2_5: - DEQUANT_OPS += [ - torch.ops.torchao.dequantize_affine, - ] - if node.target in DEQUANT_OPS: # For the pattern fp32_weight -> q -> dq # We only folding fp32_weight -> q diff --git a/torchao/quantization/pt2e/convert.py b/torchao/quantization/pt2e/convert.py index 99516ac4c3..3728d7c252 100644 --- a/torchao/quantization/pt2e/convert.py +++ b/torchao/quantization/pt2e/convert.py @@ -69,14 +69,11 @@ from torch.fx import GraphModule from torch.fx.graph import Argument, Graph, Node from torch.fx.graph_module import _USER_PRESERVED_ATTRIBUTES_KEY +from torch.fx.traceback import NodeSource, NodeSourceAction from torch.nn.utils.parametrize import type_before_parametrizations from torchao.quantization.pt2e import FROM_NODE_KEY from torchao.quantization.pt2e.observer import _is_activation_post_process -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 - -if TORCH_VERSION_AT_LEAST_2_6: - from torch.fx.traceback import NodeSource, NodeSourceAction __all__ = [ "convert", @@ -188,8 +185,6 @@ def add_dequantize_op_kwargs(dequantize_op, input_node): def add_quantize_dequantize_node_info(qdq_node, original_node): # propagate from_node info from observer/fake_quant node to quantize/dequantize node - if not TORCH_VERSION_AT_LEAST_2_6: - return qdq_node.meta[FROM_NODE_KEY] = [ NodeSource( original_node, diff --git a/torchao/quantization/pt2e/observer.py b/torchao/quantization/pt2e/observer.py index 4115040669..60962f8d41 100644 --- a/torchao/quantization/pt2e/observer.py +++ b/torchao/quantization/pt2e/observer.py @@ -1877,13 +1877,6 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node): observer_node: the observer node to convert """ - from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - - if not TORCH_VERSION_AT_LEAST_2_5: - raise NotImplementedError( - "convert for AffineQuantization is not implemented for pytorch version earlier than 2.5, please upgrade your pytorch to 2.5+." - ) - from torchao.quantization.pt2e.utils import create_getattr_from_value with model.graph.inserting_before(observer_node): diff --git a/torchao/quantization/pt2e/prepare.py b/torchao/quantization/pt2e/prepare.py index d8f5b99fc5..a1d57062f2 100644 --- a/torchao/quantization/pt2e/prepare.py +++ b/torchao/quantization/pt2e/prepare.py @@ -38,7 +38,6 @@ SharedQuantizationSpec, ) from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 # TODO: make pt2e folder private? 
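The PT2E passes in the surrounding hunks (`constant_fold.py` above, `quantize_pt2e.py` and `port_metadata_pass.py` below) now always include torchao's affine quantization ops in their quantize/dequantize op lists instead of appending them behind a 2.5 check. As an illustration only (not code from this PR), a small helper that scans an exported graph for those ops:

```python
import torch
import torchao.quantization  # registers torch.ops.torchao.{quantize,dequantize}_affine

# The real passes also list the torch.ops.quantized_decomposed.* variants shown
# in the hunks; only the torchao ops are used here to keep the sketch small.
_QUANT_DEQUANT_OPS = {
    torch.ops.torchao.quantize_affine,
    torch.ops.torchao.dequantize_affine,
}

def find_quant_nodes(gm: torch.fx.GraphModule):
    """Return call_function nodes whose target is a known quant/dequant op."""
    hits = []
    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue
        # Targets are usually OpOverloads (e.g. `.default`); the entries above
        # are overload packets, so compare the packet as well.
        packet = getattr(node.target, "overloadpacket", None)
        if node.target in _QUANT_DEQUANT_OPS or packet in _QUANT_DEQUANT_OPS:
            hits.append(node)
    return hits
```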
__all__ = [ @@ -553,7 +552,6 @@ def _maybe_insert_output_observer_for_node( isinstance(node, Node) and isinstance(new_output, Node) and FROM_NODE_KEY in node.meta - and TORCH_VERSION_AT_LEAST_2_6 ): new_output.meta[FROM_NODE_KEY] = node.meta[FROM_NODE_KEY] return new_output diff --git a/torchao/quantization/pt2e/quantize_pt2e.py b/torchao/quantization/pt2e/quantize_pt2e.py index 5eb385b7de..e58dc8e3ee 100644 --- a/torchao/quantization/pt2e/quantize_pt2e.py +++ b/torchao/quantization/pt2e/quantize_pt2e.py @@ -6,7 +6,7 @@ import torch -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 if TORCH_VERSION_AT_LEAST_2_7: from .constant_fold import constant_fold @@ -217,14 +217,9 @@ def train_loop(model, train_data): torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.quantize_per_tensor.tensor, torch.ops.quantized_decomposed.quantize_per_channel.default, + torch.ops.torchao.quantize_affine, ] -# ops are only registered after 2.5 -if TORCH_VERSION_AT_LEAST_2_5: - _QUANT_OPS += [ - torch.ops.torchao.quantize_affine, - ] - def _quant_node_constraint(n: Node) -> bool: """If there is any pure ops between get_attr and quantize op they will be const propagated diff --git a/torchao/quantization/pt2e/quantizer/port_metadata_pass.py b/torchao/quantization/pt2e/quantizer/port_metadata_pass.py index bef93a19fc..5e7e9344ee 100644 --- a/torchao/quantization/pt2e/quantizer/port_metadata_pass.py +++ b/torchao/quantization/pt2e/quantizer/port_metadata_pass.py @@ -15,7 +15,6 @@ from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY from torchao.quantization.pt2e.utils import _filter_sym_size_users from torchao.quantization.quant_primitives import quant_lib # noqa: F401 -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 from .quantizer import QuantizationSpecBase from .utils import is_valid_annotation @@ -34,27 +33,23 @@ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.quantize_per_tensor.tensor, torch.ops.quantized_decomposed.quantize_per_channel.default, + torch.ops.torchao.quantize_affine, ] _DEQUANTIZE_OPS = [ torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.torchao.dequantize_affine, ] _CHOOSE_QPARAMS_OPS = [ torch.ops.quantized_decomposed.choose_qparams.tensor, torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor, + torch.ops.torchao.choose_qparams_affine, ] -# ops are only registered after 2.5 -if TORCH_VERSION_AT_LEAST_2_5: - _QUANTIZE_OPS += [torch.ops.torchao.quantize_affine] - _DEQUANTIZE_OPS += [torch.ops.torchao.dequantize_affine] - _CHOOSE_QPARAMS_OPS += [torch.ops.torchao.choose_qparams_affine] - - def _add_metadata(to_node: torch.fx.Node, from_node: torch.fx.Node) -> None: from_meta = from_node.meta for meta_name in _METADATA_TO_PORT: diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py index 59e759dab3..f94ec6f272 100644 --- a/torchao/quantization/qat/linear.py +++ b/torchao/quantization/qat/linear.py @@ -25,7 +25,6 @@ ) from torchao.quantization.unified import TwoStepQuantizer from torchao.quantization.utils import get_group_qparams_symmetric -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 from .fake_quantize_config import ( FakeQuantizeConfigBase, @@ -471,10 +470,7 @@ def _convert_qat_linear_4w(self, module: 
torch.nn.Module): n_bit, config.group_size, ) - if ( - is_device(q_weight.device.type, "cpu") - and TORCH_VERSION_AT_LEAST_2_6 - ): + if is_device(q_weight.device.type, "cpu"): q_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu( q_weight.to(child.weight.device), child.inner_k_tiles, diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 72efd18752..41f98baf06 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -86,9 +86,6 @@ to_weight_tensor_with_linear_activation_quantization_metadata, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, _is_fbgemm_genai_gpu_available, is_MI300, is_sm_at_least_89, @@ -182,16 +179,16 @@ def _in_features_greater_than_16(mod, *args): return hasattr(mod, "in_features") and mod.in_features > 16 +# TODO: delete def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs): """ Converts all linear weight tensors to the `Int8DynamicallyQuantizedLinearWeight` Tensor subclass, effectively applying the same form of quantization as apply_dynamic_quant while not modifying the linear modules. """ - if TORCH_VERSION_AT_LEAST_2_4: - raise ImportError( - "This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs" - ) + raise ImportError( + "This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs" + ) if filter_fn is None: filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16( @@ -207,6 +204,7 @@ def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs): ) +# TODO: delete def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs): """ Converts all linear weight tensors to the @@ -214,10 +212,9 @@ def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs): effectively applying the same form of quantization as apply_weight_only_int8_quant while not modifying the linear modules. 
""" - if TORCH_VERSION_AT_LEAST_2_4: - raise ImportError( - "This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs" - ) + raise ImportError( + "This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs" + ) _replace_with_custom_fn_if_matches_filter( model, @@ -228,6 +225,7 @@ def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs): ) +# TODO: delete def change_linear_weights_to_int4_woqtensors( model, groupsize=128, @@ -251,10 +249,9 @@ def change_linear_weights_to_int4_woqtensors( ZeroPointDomain.INT, ZeroPointDomain.NONE] `preserve_zero`: whether to preserve zero, default is False """ - if TORCH_VERSION_AT_LEAST_2_4: - raise ImportError( - "This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs" - ) + raise ImportError( + "This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs" + ) if filter_fn is None: filter_fn = _is_linear @@ -655,20 +652,15 @@ def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor: scale_dtype = torch.float32 eps = torch.finfo(torch.float32).eps zero_point_dtype = torch.int8 - if TORCH_VERSION_AT_LEAST_2_6: - return to_affine_quantized_intx( - x, - mapping_type, - _get_per_token_block_size(x), - target_dtype, - eps=eps, - scale_dtype=scale_dtype, - zero_point_dtype=zero_point_dtype, - ) - else: - return to_affine_quantized_intx( - x, mapping_type, _get_per_token_block_size(x), target_dtype - ) + return to_affine_quantized_intx( + x, + mapping_type, + _get_per_token_block_size(x), + target_dtype, + eps=eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + ) def _uint8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor: @@ -679,27 +671,17 @@ def _uint8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor: zero_point_dtype = torch.int32 quant_min = 0 quant_max = 255 - if TORCH_VERSION_AT_LEAST_2_6: - out = to_affine_quantized_intx( - x, - mapping_type, - _get_per_token_block_size(x), - target_dtype, - quant_min=quant_min, - quant_max=quant_max, - eps=eps, - scale_dtype=scale_dtype, - zero_point_dtype=zero_point_dtype, - ) - else: - out = to_affine_quantized_intx( - x, - mapping_type, - _get_per_token_block_size(x), - target_dtype, - quant_min=quant_min, - quant_max=quant_max, - ) + out = to_affine_quantized_intx( + x, + mapping_type, + _get_per_token_block_size(x), + target_dtype, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + ) return out @@ -832,7 +814,6 @@ class Int8DynamicActivationIntxWeightConfig(AOBaseConfig): args: weight_dtype: The dtype to use for weight quantization. Must be torch.intx, where 1 <= x <= 8. - torch.intx with x < 8 requires TORCH_VERSION_AT_LEAST_2_6 weight_granularity: The granularity to use for weight quantization. Must be PerGroup or PerAxis(axis=0). weight_mapping_type: The type of mapping to use for the weight quantization. Must be one of MappingType.ASYMMETRIC or MappingType.SYMMETRIC. 
MappingType.SYMMETRIC requires ZeroPointDomain.NONE @@ -854,9 +835,6 @@ class Int8DynamicActivationIntxWeightConfig(AOBaseConfig): layout: Layout = QDQLayout() def __post_init__(self): - assert TORCH_VERSION_AT_LEAST_2_6, ( - "Int8DynamicActivationIntxWeightConfig requires torch 2.6+" - ) assert self.weight_dtype in [getattr(torch, f"int{b}") for b in range(1, 9)], ( f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}" ) @@ -2046,7 +2024,6 @@ class IntxWeightOnlyConfig(AOBaseConfig): manner using the number of bits specified by weight_dtype. args: weight_dtype: The dtype to use for weight quantization. Must be torch.intx, where 1 <= x <= 8. - torch.intx with x < 8 requires TORCH_VERSION_AT_LEAST_2_6 granularity: The granularity to use for weight quantization. Must be PerGroup or PerAxis(0). mapping_type: The type of mapping to use for the weight quantization. Must be one of MappingType.ASYMMETRIC or MappingType.SYMMETRIC. @@ -2063,7 +2040,6 @@ class IntxWeightOnlyConfig(AOBaseConfig): layout: Layout = QDQLayout() def __post_init__(self): - assert TORCH_VERSION_AT_LEAST_2_6, "IntxWeightOnlyConfig requires torch 2.6+" assert self.weight_dtype in [getattr(torch, f"int{b}") for b in range(1, 9)], ( f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}" ) @@ -2287,16 +2263,15 @@ def _module_fqn_to_config_handler( return module -if TORCH_VERSION_AT_LEAST_2_5: - torch.serialization.add_safe_globals( - [ - _int8_asymm_per_token_quant, - _int8_symm_per_token_reduced_range_quant, - _input_activation_quant_func_fp8, - _int4_symm_cutlass_quant, - _int8_symm_cutlass_quant, - _float8_cutlass_quant, - _float8_cutlass_quant_sparse, - Target, - ] - ) +torch.serialization.add_safe_globals( + [ + _int8_asymm_per_token_quant, + _int8_symm_per_token_reduced_range_quant, + _input_activation_quant_func_fp8, + _int4_symm_cutlass_quant, + _int8_symm_cutlass_quant, + _float8_cutlass_quant, + _float8_cutlass_quant_sparse, + Target, + ] +) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index a91c3acd28..ebd2c7ecd8 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -16,9 +16,6 @@ _n_ones, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, _register_custom_op, _register_meta_op, ) @@ -107,8 +104,7 @@ class TorchAODType(Enum): INT7 = auto() -if TORCH_VERSION_AT_LEAST_2_5: - torch.serialization.add_safe_globals([MappingType, ZeroPointDomain]) +torch.serialization.add_safe_globals([MappingType, ZeroPointDomain]) FP8_TYPES = { torch.float8_e4m3fn, @@ -152,53 +148,49 @@ class TorchAODType(Enum): TorchAODType.INT7: (-(2**6), 2**6 - 1), } -# torch.uintX available only in PyTorch 2.3+ -if TORCH_VERSION_AT_LEAST_2_3: - _SUB_BYTE_UINT_BOUNDS = { - torch.uint1: (0, 2**1 - 1), - torch.uint2: (0, 2**2 - 1), - torch.uint3: (0, 2**3 - 1), - torch.uint4: (0, 2**4 - 1), - torch.uint5: (0, 2**5 - 1), - torch.uint6: (0, 2**6 - 1), - torch.uint7: (0, 2**7 - 1), +_SUB_BYTE_UINT_BOUNDS = { + torch.uint1: (0, 2**1 - 1), + torch.uint2: (0, 2**2 - 1), + torch.uint3: (0, 2**3 - 1), + torch.uint4: (0, 2**4 - 1), + torch.uint5: (0, 2**5 - 1), + torch.uint6: (0, 2**6 - 1), + torch.uint7: (0, 2**7 - 1), +} +_DTYPE_TO_BIT_WIDTH.update( + { + torch.uint1: 1, + torch.uint2: 2, + torch.uint3: 3, + torch.uint4: 4, + torch.uint5: 5, + torch.uint6: 6, + torch.uint7: 7, } - _DTYPE_TO_BIT_WIDTH.update( - { - torch.uint1: 1, - 
torch.uint2: 2, - torch.uint3: 3, - torch.uint4: 4, - torch.uint5: 5, - torch.uint6: 6, - torch.uint7: 7, - } - ) - -# torch.intX available only in PyTorch 2.6+ -if TORCH_VERSION_AT_LEAST_2_6: - _SUB_BYTE_INT_BOUNDS.update( - { - torch.int1: (-(2**0), 2**0 - 1), - torch.int2: (-(2**1), 2**1 - 1), - torch.int3: (-(2**2), 2**2 - 1), - torch.int4: (-(2**3), 2**3 - 1), - torch.int5: (-(2**4), 2**4 - 1), - torch.int6: (-(2**5), 2**5 - 1), - torch.int7: (-(2**6), 2**6 - 1), - } - ) - _DTYPE_TO_BIT_WIDTH.update( - { - torch.int1: 1, - torch.int2: 2, - torch.int3: 3, - torch.int4: 4, - torch.int5: 5, - torch.int6: 6, - torch.int7: 7, - } - ) +) + +_SUB_BYTE_INT_BOUNDS.update( + { + torch.int1: (-(2**0), 2**0 - 1), + torch.int2: (-(2**1), 2**1 - 1), + torch.int3: (-(2**2), 2**2 - 1), + torch.int4: (-(2**3), 2**3 - 1), + torch.int5: (-(2**4), 2**4 - 1), + torch.int6: (-(2**5), 2**5 - 1), + torch.int7: (-(2**6), 2**6 - 1), + } +) +_DTYPE_TO_BIT_WIDTH.update( + { + torch.int1: 1, + torch.int2: 2, + torch.int3: 3, + torch.int4: 4, + torch.int5: 5, + torch.int6: 6, + torch.int7: 7, + } +) _DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_UINT_BOUNDS) _DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_INT_BOUNDS) diff --git a/torchao/quantization/quantize_/common/kernel_preference.py b/torchao/quantization/quantize_/common/kernel_preference.py index 5430463543..c9b853f300 100644 --- a/torchao/quantization/quantize_/common/kernel_preference.py +++ b/torchao/quantization/quantize_/common/kernel_preference.py @@ -8,8 +8,6 @@ import torch -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - # can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum) # after python 3.10 is end of life (https://devguide.python.org/versions/) @@ -33,5 +31,4 @@ class KernelPreference(str, Enum): FBGEMM = "fbgemm" -if TORCH_VERSION_AT_LEAST_2_5: - torch.serialization.add_safe_globals([KernelPreference]) +torch.serialization.add_safe_globals([KernelPreference]) diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index b94dc36361..7726b2094c 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -35,7 +35,6 @@ _choose_quant_func_and_quantize_tensor, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor, _is_fbgemm_genai_gpu_available, fill_defaults, @@ -608,6 +607,5 @@ def _(func, types, args, kwargs): Float8Tensor.__module__ = "torchao.quantization" -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with Float8Tensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([Float8Tensor, QuantizeTensorToFloat8Kwargs]) +# Allow a model with Float8Tensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Float8Tensor, QuantizeTensorToFloat8Kwargs]) diff --git a/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py index 16595f370e..50cf261642 100644 --- a/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py @@ -11,7 +11,6 @@ import torch from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor, ) @@ -260,6 +259,5 @@ def _(func, types, args, kwargs): Int4PreshuffledTensor.__module__ = "torchao.quantization" -if TORCH_VERSION_AT_LEAST_2_5: - # Allow 
a model with Int4PreshuffledTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([Int4PreshuffledTensor]) +# Allow a model with Int4PreshuffledTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Int4PreshuffledTensor]) diff --git a/torchao/quantization/quantize_/workflows/int4/int4_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_tensor.py index ebf36dd644..1b2729fdd6 100644 --- a/torchao/quantization/quantize_/workflows/int4/int4_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_tensor.py @@ -10,7 +10,7 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor, fill_defaults +from torchao.utils import TorchAOBaseTensor, fill_defaults __all__ = [ "Int4Tensor", @@ -486,6 +486,5 @@ def _(func, types, args, kwargs): Int4Tensor.__module__ = "torchao.quantization" -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with Int4Tensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([Int4Tensor]) +# Allow a model with Int4Tensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Int4Tensor]) diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index a4097ecc25..d56fa0732d 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -25,7 +25,6 @@ quantize_affine, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, check_cpu_version, check_xpu_version, ) @@ -449,7 +448,7 @@ def groupwise_affine_quantize_tensor_from_qparams( quant_min, quant_max, ) - if TORCH_VERSION_AT_LEAST_2_5 and w.shape[-1] > 1: + if w.shape[-1] > 1: if (not (check_cpu_version(int_data.device))) and ( not (check_xpu_version(int_data.device)) ): @@ -470,10 +469,8 @@ def groupwise_affine_dequantize_tensor_from_qparams( assert groupsize > 1 assert w_int4x8.dim() == 2 # need to handle single column case so check for dtype/size from groupwise_affine_quantize_tensor_from_qparams path - if ( - TORCH_VERSION_AT_LEAST_2_5 - and (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1) - and not (check_cpu_version(w_int4x8.device)) + if (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1) and not ( + check_cpu_version(w_int4x8.device) ): data = w_int4x8.to(torch.int32) high_bits = data >> 4 diff --git a/torchao/quantization/weight_tensor_linear_activation_quantization.py b/torchao/quantization/weight_tensor_linear_activation_quantization.py index 6612213bc1..c0b0a893e4 100644 --- a/torchao/quantization/weight_tensor_linear_activation_quantization.py +++ b/torchao/quantization/weight_tensor_linear_activation_quantization.py @@ -8,10 +8,7 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor __all__ = [ "WeightTensorWithLinearActivationQuantizationMetadata", @@ -201,8 +198,7 @@ def _(func, types, args, kwargs): WeightTensorWithLinearActivationQuantizationMetadata.from_float ) -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals( - [WeightTensorWithLinearActivationQuantizationMetadata] - ) +# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals( + 
[WeightTensorWithLinearActivationQuantizationMetadata] +) diff --git a/torchao/sparsity/training/__init__.py b/torchao/sparsity/training/__init__.py index 3c4212101b..87ce3add4f 100644 --- a/torchao/sparsity/training/__init__.py +++ b/torchao/sparsity/training/__init__.py @@ -4,17 +4,15 @@ # LICENSE file in the root directory of this source tree. import torch +# load pointwise op support, which exists only for CUTLASS +from torch.sparse import SparseSemiStructuredTensorCUTLASS + from torchao.sparsity.training.autograd import semi_structured_sparsify from torchao.sparsity.training.pointwise_ops import CUTLASS_POINTWISE_OP_DISPATCH_TABLE -from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - -# load pointwise op support, which exists only for CUTLASS -if TORCH_VERSION_AT_LEAST_2_3: - from torch.sparse import SparseSemiStructuredTensorCUTLASS - SparseSemiStructuredTensorCUTLASS._load_dispatch_table( - CUTLASS_POINTWISE_OP_DISPATCH_TABLE - ) +SparseSemiStructuredTensorCUTLASS._load_dispatch_table( + CUTLASS_POINTWISE_OP_DISPATCH_TABLE +) __all__ = [ "SemiSparseLinear", diff --git a/torchao/sparsity/training/autograd.py b/torchao/sparsity/training/autograd.py index fafbd7c3c3..40c6c98083 100644 --- a/torchao/sparsity/training/autograd.py +++ b/torchao/sparsity/training/autograd.py @@ -6,18 +6,14 @@ from enum import Enum import torch -from torch.sparse import SparseSemiStructuredTensor - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - -if TORCH_VERSION_AT_LEAST_2_3: - from torch.sparse import ( - SparseSemiStructuredTensorCUSPARSELT, - SparseSemiStructuredTensorCUTLASS, - ) - - torch._dynamo.allow_in_graph(SparseSemiStructuredTensorCUSPARSELT) - torch._dynamo.allow_in_graph(SparseSemiStructuredTensorCUTLASS) +from torch.sparse import ( + SparseSemiStructuredTensor, + SparseSemiStructuredTensorCUSPARSELT, + SparseSemiStructuredTensorCUTLASS, +) + +torch._dynamo.allow_in_graph(SparseSemiStructuredTensorCUSPARSELT) +torch._dynamo.allow_in_graph(SparseSemiStructuredTensorCUTLASS) GRADIENT_TYPE = Enum("GRADIENT_TYPE", ["DENSE", "SPARSE", "STE"]) diff --git a/torchao/testing/pt2e/utils.py b/torchao/testing/pt2e/utils.py index c4773231a5..a41d3f597f 100644 --- a/torchao/testing/pt2e/utils.py +++ b/torchao/testing/pt2e/utils.py @@ -15,6 +15,7 @@ _convert_to_reference_decomposed_fx, prepare_fx, ) +from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec, QuantizationTestCase, @@ -29,16 +30,9 @@ prepare_pt2e, prepare_qat_pt2e, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 -if TORCH_VERSION_AT_LEAST_2_5: - from torch.export import export_for_training - -@unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, - "only works for torch 2.5+ since export_for_training is only supported after 2.5", -) class PT2EQuantizationTestCase(QuantizationTestCase): """ Base QuantizationTestCase for PT2 with some helper methods. 
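The hunk above drops the TORCH_VERSION_AT_LEAST_2_5 gate around export_for_training and the skipIf on the PT2E test case. For orientation, a minimal sketch of the export-based flow these test helpers exercise is shown below; the prepare_pt2e/convert_pt2e import path and the quantizer argument are assumptions about the surrounding torchao APIs, not code taken from this diff.

    import torch
    from torch.export import export_for_training

    # Assumed import location, mirroring torch.ao.quantization.quantize_pt2e.
    from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


    def pt2e_quantize(model: torch.nn.Module, example_inputs: tuple, quantizer):
        # No version guard needed: export_for_training exists on every torch
        # release this codebase still supports.
        exported = export_for_training(model, example_inputs).module()
        prepared = prepare_pt2e(exported, quantizer)  # insert observers
        prepared(*example_inputs)                     # one calibration pass
        return convert_pt2e(prepared)                 # lower to quant/dequant ops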
diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 38fc8b04ce..33def3f998 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -24,7 +24,6 @@ ) from torchao.testing.model_architectures import LlamaModelsLlama4Experts from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_6, DummyModule, get_compute_capability, ) @@ -420,10 +419,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dn_dist(up_dist(input_dtensor)) - if not TORCH_VERSION_AT_LEAST_2_6: - # Need torch 2.6 to support compiled tensor parallelism - return - up_compiled = torch.compile(up_dist) y_up = up_compiled(input_dtensor) dn_compiled = torch.compile(dn_dist) diff --git a/torchao/utils.py b/torchao/utils.py index 40ca9e3702..f72e60e3d1 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -141,9 +141,8 @@ def get_available_devices(): devices.append("cuda") elif torch.xpu.is_available(): devices.append("xpu") - if TORCH_VERSION_AT_LEAST_2_5: - if torch.mps.is_available(): - devices.append("mps") + if torch.mps.is_available(): + devices.append("mps") return devices @@ -216,37 +215,31 @@ def _the_op_that_needs_to_be_preserved(...) ) def decorator(fn): - if TORCH_VERSION_AT_LEAST_2_5: - from torch._library.infer_schema import infer_schema + from torch._library.infer_schema import infer_schema - assert not any(c in fn.__name__ for c in ".<>"), ( - f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}" - ) - op_name = fn.__name__ - if op_name[0] == "_": - op_name = op_name[1:] - schema = op_name + infer_schema(fn, mutates_args={}) - lib.define(schema) - lib.impl(op_name, fn, dispatch_key) - - lib_namespace = lib.ns - op = getattr(getattr(torch.ops, lib_namespace), op_name) - if inductor_decomposed: - register_decomposition([op])(fn) - return op - else: - return fn + assert not any(c in fn.__name__ for c in ".<>"), ( + f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}" + ) + op_name = fn.__name__ + if op_name[0] == "_": + op_name = op_name[1:] + schema = op_name + infer_schema(fn, mutates_args={}) + lib.define(schema) + lib.impl(op_name, fn, dispatch_key) + + lib_namespace = lib.ns + op = getattr(getattr(torch.ops, lib_namespace), op_name) + if inductor_decomposed: + register_decomposition([op])(fn) + return op return decorator def _register_meta_op(lib, op_name): def decorator(fn): - if TORCH_VERSION_AT_LEAST_2_5: - op = lib.impl(op_name, fn, "Meta") - return op - else: - return fn + op = lib.impl(op_name, fn, "Meta") + return op return decorator @@ -644,9 +637,8 @@ def decorator(tensor_impl_class): tensor_class._LAYOUT_CONSTRUCTOR_TABLE[layout_class] = ( tensor_impl_class.from_plain ) - if TORCH_VERSION_AT_LEAST_2_5: - # Allow serialization to work for models uses this tensor impl subclass - torch.serialization.add_safe_globals([layout_class, tensor_impl_class]) + # Allow serialization to work for models uses this tensor impl subclass + torch.serialization.add_safe_globals([layout_class, tensor_impl_class]) return tensor_impl_class return decorator diff --git a/tutorials/quantize_vit/run_vit_b_quant.py b/tutorials/quantize_vit/run_vit_b_quant.py index faaa9b1ae9..c326828219 100644 --- a/tutorials/quantize_vit/run_vit_b_quant.py +++ b/tutorials/quantize_vit/run_vit_b_quant.py @@ -37,12 +37,6 @@ torch._inductor.config.use_mixed_mm = True ## compilation configs end -# temporary workaround for the API to work with torch.compile -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass - -if not 
TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(model) - # temporary workaround to recover the perf with quantized model under torch.compile torch.backends.mha.set_fastpath_enabled(False)
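With the pre-2.5 unwrap_tensor_subclass workaround removed above, the tutorial's quantize-then-compile path becomes a straight line. A minimal sketch of that flow follows, assuming a CUDA device; the toy model and the Int8WeightOnlyConfig choice are illustrative stand-ins, not the tutorial's exact ViT setup.

    import torch
    from torchao.quantization import Int8WeightOnlyConfig, quantize_

    # Toy stand-in for the ViT model quantized in the tutorial.
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

    # Quantize in place; no unwrap_tensor_subclass() call is needed before compile.
    quantize_(model, Int8WeightOnlyConfig())

    model = torch.compile(model, mode="max-autotune")
    with torch.no_grad():
        out = model(torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"))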