Skip to content

Commit fc3f994

Browse files
committed
Deprecate old TORCH_VERSION variables
**Summary:** This commit deprecates the following variables in favor of simply calling `torch_version_at_least`. ``` TORCH_VERSION_AT_LEAST_2_8 TORCH_VERSION_AT_LEAST_2_7 TORCH_VERSION_AT_LEAST_2_6 TORCH_VERSION_AT_LEAST_2_5 TORCH_VERSION_AT_LEAST_2_4 TORCH_VERSION_AT_LEAST_2_3 TORCH_VERSION_AT_LEAST_2_2 TORCH_VERSION_AFTER_2_5 TORCH_VERSION_AFTER_2_4 TORCH_VERSION_AFTER_2_3 TORCH_VERSION_AFTER_2_2 ``` **Test Plan:** ``` python test/test_utils.py -k torch_version_deprecation ``` ghstack-source-id: 4fb3a75 Pull Request resolved: #2719
1 parent 6cfa477 commit fc3f994

File tree

2 files changed

+105
-19
lines changed

2 files changed

+105
-19
lines changed

test/test_utils.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
66
import unittest
7+
import warnings
78
from unittest.mock import patch
89

910
import torch
@@ -12,7 +13,7 @@
1213
from torchao.utils import TorchAOBaseTensor, torch_version_at_least
1314

1415

15-
class TestTorchVersionAtLeast(unittest.TestCase):
16+
class TestTorchVersion(unittest.TestCase):
1617
def test_torch_version_at_least(self):
1718
test_cases = [
1819
("2.5.0a0+git9f17037", "2.5.0", True),
@@ -35,6 +36,55 @@ def test_torch_version_at_least(self):
3536
f"Failed for torch.__version__={torch_version}, comparing with {compare_version}",
3637
)
3738

39+
def test_torch_version_deprecation(self):
40+
"""
41+
Test that TORCH_VERSION_AT_LEAST* and TORCH_VERSION_AFTER*
42+
trigger deprecation warnings on use, not on import.
43+
"""
44+
# Reset deprecation warning state, otherwise we won't log warnings here
45+
warnings.resetwarnings()
46+
47+
# Importing and referencing should not trigger deprecation warning
48+
with warnings.catch_warnings(record=True) as _warnings:
49+
from torchao.utils import (
50+
TORCH_VERSION_AFTER_2_2,
51+
TORCH_VERSION_AFTER_2_3,
52+
TORCH_VERSION_AFTER_2_4,
53+
TORCH_VERSION_AFTER_2_5,
54+
TORCH_VERSION_AT_LEAST_2_2,
55+
TORCH_VERSION_AT_LEAST_2_3,
56+
TORCH_VERSION_AT_LEAST_2_4,
57+
TORCH_VERSION_AT_LEAST_2_5,
58+
TORCH_VERSION_AT_LEAST_2_6,
59+
TORCH_VERSION_AT_LEAST_2_7,
60+
TORCH_VERSION_AT_LEAST_2_8,
61+
)
62+
63+
deprecated_api_to_name = [
64+
(TORCH_VERSION_AT_LEAST_2_8, "TORCH_VERSION_AT_LEAST_2_8"),
65+
(TORCH_VERSION_AT_LEAST_2_7, "TORCH_VERSION_AT_LEAST_2_7"),
66+
(TORCH_VERSION_AT_LEAST_2_6, "TORCH_VERSION_AT_LEAST_2_6"),
67+
(TORCH_VERSION_AT_LEAST_2_5, "TORCH_VERSION_AT_LEAST_2_5"),
68+
(TORCH_VERSION_AT_LEAST_2_4, "TORCH_VERSION_AT_LEAST_2_4"),
69+
(TORCH_VERSION_AT_LEAST_2_3, "TORCH_VERSION_AT_LEAST_2_3"),
70+
(TORCH_VERSION_AT_LEAST_2_2, "TORCH_VERSION_AT_LEAST_2_2"),
71+
(TORCH_VERSION_AFTER_2_5, "TORCH_VERSION_AFTER_2_5"),
72+
(TORCH_VERSION_AFTER_2_4, "TORCH_VERSION_AFTER_2_4"),
73+
(TORCH_VERSION_AFTER_2_3, "TORCH_VERSION_AFTER_2_3"),
74+
(TORCH_VERSION_AFTER_2_2, "TORCH_VERSION_AFTER_2_2"),
75+
]
76+
self.assertEqual(len(_warnings), 0)
77+
78+
# Accessing the boolean value should trigger deprecation warning
79+
with warnings.catch_warnings(record=True) as _warnings:
80+
for api, name in deprecated_api_to_name:
81+
num_warnings_before = len(_warnings)
82+
if api:
83+
pass
84+
regex = f"{name} is deprecated and will be removed"
85+
self.assertEqual(len(_warnings), num_warnings_before + 1)
86+
self.assertIn(regex, str(_warnings[-1].message))
87+
3888

3989
class TestTorchAOBaseTensor(unittest.TestCase):
4090
def test_print_arg_types(self):

torchao/utils.py

Lines changed: 54 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import itertools
99
import re
1010
import time
11+
import warnings
1112
from functools import reduce
1213
from importlib.metadata import version
1314
from math import gcd
@@ -377,13 +378,59 @@ def torch_version_at_least(min_version):
377378
return is_fbcode() or compare_versions(torch.__version__, min_version) >= 0
378379

379380

380-
TORCH_VERSION_AT_LEAST_2_8 = torch_version_at_least("2.8.0")
381-
TORCH_VERSION_AT_LEAST_2_7 = torch_version_at_least("2.7.0")
382-
TORCH_VERSION_AT_LEAST_2_6 = torch_version_at_least("2.6.0")
383-
TORCH_VERSION_AT_LEAST_2_5 = torch_version_at_least("2.5.0")
384-
TORCH_VERSION_AT_LEAST_2_4 = torch_version_at_least("2.4.0")
385-
TORCH_VERSION_AT_LEAST_2_3 = torch_version_at_least("2.3.0")
386-
TORCH_VERSION_AT_LEAST_2_2 = torch_version_at_least("2.2.0")
381+
def _deprecated_torch_version_at_least(version_str: str) -> str:
382+
"""
383+
Wrapper for existing TORCH_VERSION_AT_LEAST* variables that will log
384+
a deprecation warning if the variable is used.
385+
"""
386+
version_str_var_name = "_".join(version_str.split(".")[:2])
387+
deprecation_msg = f"TORCH_VERSION_AT_LEAST_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0"
388+
return _BoolDeprecationWrapper(
389+
torch_version_at_least(version_str),
390+
deprecation_msg,
391+
)
392+
393+
394+
def _deprecated_torch_version_after(version_str: str) -> str:
395+
"""
396+
Wrapper for existing TORCH_VERSION_AFTER* variables that will log
397+
a deprecation warning if the variable is used.
398+
"""
399+
bool_value = is_fbcode() or version("torch") >= version_str
400+
version_str_var_name = "_".join(version_str.split(".")[:2])
401+
deprecation_msg = f"TORCH_VERSION_AFTER_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0"
402+
return _BoolDeprecationWrapper(bool_value, deprecation_msg)
403+
404+
405+
class _BoolDeprecationWrapper:
406+
"""
407+
A deprecation wrapper that logs a warning when the given bool value is accessed.
408+
"""
409+
410+
def __init__(self, bool_value: bool, msg: str):
411+
self.bool_value = bool_value
412+
self.msg = msg
413+
414+
def __bool__(self):
415+
warnings.warn(self.msg)
416+
return self.bool_value
417+
418+
def __eq__(self, other):
419+
return bool(self) == bool(other)
420+
421+
422+
# Deprecated, use `torch_version_at_least` directly instead
423+
TORCH_VERSION_AT_LEAST_2_8 = _deprecated_torch_version_at_least("2.8.0")
424+
TORCH_VERSION_AT_LEAST_2_7 = _deprecated_torch_version_at_least("2.7.0")
425+
TORCH_VERSION_AT_LEAST_2_6 = _deprecated_torch_version_at_least("2.6.0")
426+
TORCH_VERSION_AT_LEAST_2_5 = _deprecated_torch_version_at_least("2.5.0")
427+
TORCH_VERSION_AT_LEAST_2_4 = _deprecated_torch_version_at_least("2.4.0")
428+
TORCH_VERSION_AT_LEAST_2_3 = _deprecated_torch_version_at_least("2.3.0")
429+
TORCH_VERSION_AT_LEAST_2_2 = _deprecated_torch_version_at_least("2.2.0")
430+
TORCH_VERSION_AFTER_2_5 = _deprecated_torch_version_after("2.5.0.dev")
431+
TORCH_VERSION_AFTER_2_4 = _deprecated_torch_version_after("2.4.0.dev")
432+
TORCH_VERSION_AFTER_2_3 = _deprecated_torch_version_after("2.3.0.dev")
433+
TORCH_VERSION_AFTER_2_2 = _deprecated_torch_version_after("2.2.0.dev")
387434

388435

389436
"""
@@ -766,11 +813,6 @@ def fill_defaults(args, n, defaults_tail):
766813
return r
767814

768815

769-
## Deprecated, will be deleted in the future
770-
def _torch_version_at_least(min_version):
771-
return is_fbcode() or version("torch") >= min_version
772-
773-
774816
# Supported AMD GPU Models and their LLVM gfx Codes:
775817
#
776818
# | AMD GPU Model | LLVM gfx Code |
@@ -857,12 +899,6 @@ def ceil_div(a, b):
857899
return (a + b - 1) // b
858900

859901

860-
TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev")
861-
TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev")
862-
TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev")
863-
TORCH_VERSION_AFTER_2_2 = _torch_version_at_least("2.2.0.dev")
864-
865-
866902
def is_package_at_least(package_name: str, min_version: str):
867903
package_exists = importlib.util.find_spec(package_name) is not None
868904
if not package_exists:

0 commit comments

Comments
 (0)