diff --git a/test/test_utils.py b/test/test_utils.py index 3ba2f32613..9213097276 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. import unittest +import warnings from unittest.mock import patch import torch @@ -12,7 +13,7 @@ from torchao.utils import TorchAOBaseTensor, torch_version_at_least -class TestTorchVersionAtLeast(unittest.TestCase): +class TestTorchVersion(unittest.TestCase): def test_torch_version_at_least(self): test_cases = [ ("2.5.0a0+git9f17037", "2.5.0", True), @@ -35,6 +36,55 @@ def test_torch_version_at_least(self): f"Failed for torch.__version__={torch_version}, comparing with {compare_version}", ) + def test_torch_version_deprecation(self): + """ + Test that TORCH_VERSION_AT_LEAST* and TORCH_VERSION_AFTER* + trigger deprecation warnings on use, not on import. + """ + # Reset deprecation warning state, otherwise we won't log warnings here + warnings.resetwarnings() + + # Importing and referencing should not trigger deprecation warning + with warnings.catch_warnings(record=True) as _warnings: + from torchao.utils import ( + TORCH_VERSION_AFTER_2_2, + TORCH_VERSION_AFTER_2_3, + TORCH_VERSION_AFTER_2_4, + TORCH_VERSION_AFTER_2_5, + TORCH_VERSION_AT_LEAST_2_2, + TORCH_VERSION_AT_LEAST_2_3, + TORCH_VERSION_AT_LEAST_2_4, + TORCH_VERSION_AT_LEAST_2_5, + TORCH_VERSION_AT_LEAST_2_6, + TORCH_VERSION_AT_LEAST_2_7, + TORCH_VERSION_AT_LEAST_2_8, + ) + + deprecated_api_to_name = [ + (TORCH_VERSION_AT_LEAST_2_8, "TORCH_VERSION_AT_LEAST_2_8"), + (TORCH_VERSION_AT_LEAST_2_7, "TORCH_VERSION_AT_LEAST_2_7"), + (TORCH_VERSION_AT_LEAST_2_6, "TORCH_VERSION_AT_LEAST_2_6"), + (TORCH_VERSION_AT_LEAST_2_5, "TORCH_VERSION_AT_LEAST_2_5"), + (TORCH_VERSION_AT_LEAST_2_4, "TORCH_VERSION_AT_LEAST_2_4"), + (TORCH_VERSION_AT_LEAST_2_3, "TORCH_VERSION_AT_LEAST_2_3"), + (TORCH_VERSION_AT_LEAST_2_2, "TORCH_VERSION_AT_LEAST_2_2"), + (TORCH_VERSION_AFTER_2_5, "TORCH_VERSION_AFTER_2_5"), + (TORCH_VERSION_AFTER_2_4, "TORCH_VERSION_AFTER_2_4"), + (TORCH_VERSION_AFTER_2_3, "TORCH_VERSION_AFTER_2_3"), + (TORCH_VERSION_AFTER_2_2, "TORCH_VERSION_AFTER_2_2"), + ] + self.assertEqual(len(_warnings), 0) + + # Accessing the boolean value should trigger deprecation warning + with warnings.catch_warnings(record=True) as _warnings: + for api, name in deprecated_api_to_name: + num_warnings_before = len(_warnings) + if api: + pass + regex = f"{name} is deprecated and will be removed" + self.assertEqual(len(_warnings), num_warnings_before + 1) + self.assertIn(regex, str(_warnings[-1].message)) + class TestTorchAOBaseTensor(unittest.TestCase): def test_print_arg_types(self): diff --git a/torchao/utils.py b/torchao/utils.py index fb82b9f005..ea939bdd9a 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -8,6 +8,7 @@ import itertools import re import time +import warnings from functools import reduce from importlib.metadata import version from math import gcd @@ -377,13 +378,59 @@ def torch_version_at_least(min_version): return is_fbcode() or compare_versions(torch.__version__, min_version) >= 0 -TORCH_VERSION_AT_LEAST_2_8 = torch_version_at_least("2.8.0") -TORCH_VERSION_AT_LEAST_2_7 = torch_version_at_least("2.7.0") -TORCH_VERSION_AT_LEAST_2_6 = torch_version_at_least("2.6.0") -TORCH_VERSION_AT_LEAST_2_5 = torch_version_at_least("2.5.0") -TORCH_VERSION_AT_LEAST_2_4 = torch_version_at_least("2.4.0") -TORCH_VERSION_AT_LEAST_2_3 = torch_version_at_least("2.3.0") -TORCH_VERSION_AT_LEAST_2_2 = torch_version_at_least("2.2.0") +def _deprecated_torch_version_at_least(version_str: str) -> str: + """ + Wrapper for existing TORCH_VERSION_AT_LEAST* variables that will log + a deprecation warning if the variable is used. + """ + version_str_var_name = "_".join(version_str.split(".")[:2]) + deprecation_msg = f"TORCH_VERSION_AT_LEAST_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0" + return _BoolDeprecationWrapper( + torch_version_at_least(version_str), + deprecation_msg, + ) + + +def _deprecated_torch_version_after(version_str: str) -> str: + """ + Wrapper for existing TORCH_VERSION_AFTER* variables that will log + a deprecation warning if the variable is used. + """ + bool_value = is_fbcode() or version("torch") >= version_str + version_str_var_name = "_".join(version_str.split(".")[:2]) + deprecation_msg = f"TORCH_VERSION_AFTER_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0" + return _BoolDeprecationWrapper(bool_value, deprecation_msg) + + +class _BoolDeprecationWrapper: + """ + A deprecation wrapper that logs a warning when the given bool value is accessed. + """ + + def __init__(self, bool_value: bool, msg: str): + self.bool_value = bool_value + self.msg = msg + + def __bool__(self): + warnings.warn(self.msg) + return self.bool_value + + def __eq__(self, other): + return bool(self) == bool(other) + + +# Deprecated, use `torch_version_at_least` directly instead +TORCH_VERSION_AT_LEAST_2_8 = _deprecated_torch_version_at_least("2.8.0") +TORCH_VERSION_AT_LEAST_2_7 = _deprecated_torch_version_at_least("2.7.0") +TORCH_VERSION_AT_LEAST_2_6 = _deprecated_torch_version_at_least("2.6.0") +TORCH_VERSION_AT_LEAST_2_5 = _deprecated_torch_version_at_least("2.5.0") +TORCH_VERSION_AT_LEAST_2_4 = _deprecated_torch_version_at_least("2.4.0") +TORCH_VERSION_AT_LEAST_2_3 = _deprecated_torch_version_at_least("2.3.0") +TORCH_VERSION_AT_LEAST_2_2 = _deprecated_torch_version_at_least("2.2.0") +TORCH_VERSION_AFTER_2_5 = _deprecated_torch_version_after("2.5.0.dev") +TORCH_VERSION_AFTER_2_4 = _deprecated_torch_version_after("2.4.0.dev") +TORCH_VERSION_AFTER_2_3 = _deprecated_torch_version_after("2.3.0.dev") +TORCH_VERSION_AFTER_2_2 = _deprecated_torch_version_after("2.2.0.dev") """ @@ -766,11 +813,6 @@ def fill_defaults(args, n, defaults_tail): return r -## Deprecated, will be deleted in the future -def _torch_version_at_least(min_version): - return is_fbcode() or version("torch") >= min_version - - # Supported AMD GPU Models and their LLVM gfx Codes: # # | AMD GPU Model | LLVM gfx Code | @@ -857,12 +899,6 @@ def ceil_div(a, b): return (a + b - 1) // b -TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev") -TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev") -TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev") -TORCH_VERSION_AFTER_2_2 = _torch_version_at_least("2.2.0.dev") - - def is_package_at_least(package_name: str, min_version: str): package_exists = importlib.util.find_spec(package_name) is not None if not package_exists: