4
4
# This source code is licensed under the BSD 3-Clause license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
import unittest
7
+ import warnings
7
8
from unittest .mock import patch
8
9
9
10
import torch
12
13
from torchao .utils import TorchAOBaseTensor , torch_version_at_least
13
14
14
15
15
- class TestTorchVersionAtLeast (unittest .TestCase ):
16
+ class TestTorchVersion (unittest .TestCase ):
16
17
def test_torch_version_at_least (self ):
17
18
test_cases = [
18
19
("2.5.0a0+git9f17037" , "2.5.0" , True ),
@@ -35,6 +36,55 @@ def test_torch_version_at_least(self):
35
36
f"Failed for torch.__version__={ torch_version } , comparing with { compare_version } " ,
36
37
)
37
38
39
+ def test_torch_version_deprecation (self ):
40
+ """
41
+ Test that TORCH_VERSION_AT_LEAST* and TORCH_VERSION_AFTER*
42
+ trigger deprecation warnings on use, not on import.
43
+ """
44
+ # Reset deprecation warning state, otherwise we won't log warnings here
45
+ warnings .resetwarnings ()
46
+
47
+ # Importing and referencing should not trigger deprecation warning
48
+ with warnings .catch_warnings (record = True ) as _warnings :
49
+ from torchao .utils import (
50
+ TORCH_VERSION_AFTER_2_2 ,
51
+ TORCH_VERSION_AFTER_2_3 ,
52
+ TORCH_VERSION_AFTER_2_4 ,
53
+ TORCH_VERSION_AFTER_2_5 ,
54
+ TORCH_VERSION_AT_LEAST_2_2 ,
55
+ TORCH_VERSION_AT_LEAST_2_3 ,
56
+ TORCH_VERSION_AT_LEAST_2_4 ,
57
+ TORCH_VERSION_AT_LEAST_2_5 ,
58
+ TORCH_VERSION_AT_LEAST_2_6 ,
59
+ TORCH_VERSION_AT_LEAST_2_7 ,
60
+ TORCH_VERSION_AT_LEAST_2_8 ,
61
+ )
62
+
63
+ deprecated_api_to_name = [
64
+ (TORCH_VERSION_AT_LEAST_2_8 , "TORCH_VERSION_AT_LEAST_2_8" ),
65
+ (TORCH_VERSION_AT_LEAST_2_7 , "TORCH_VERSION_AT_LEAST_2_7" ),
66
+ (TORCH_VERSION_AT_LEAST_2_6 , "TORCH_VERSION_AT_LEAST_2_6" ),
67
+ (TORCH_VERSION_AT_LEAST_2_5 , "TORCH_VERSION_AT_LEAST_2_5" ),
68
+ (TORCH_VERSION_AT_LEAST_2_4 , "TORCH_VERSION_AT_LEAST_2_4" ),
69
+ (TORCH_VERSION_AT_LEAST_2_3 , "TORCH_VERSION_AT_LEAST_2_3" ),
70
+ (TORCH_VERSION_AT_LEAST_2_2 , "TORCH_VERSION_AT_LEAST_2_2" ),
71
+ (TORCH_VERSION_AFTER_2_5 , "TORCH_VERSION_AFTER_2_5" ),
72
+ (TORCH_VERSION_AFTER_2_4 , "TORCH_VERSION_AFTER_2_4" ),
73
+ (TORCH_VERSION_AFTER_2_3 , "TORCH_VERSION_AFTER_2_3" ),
74
+ (TORCH_VERSION_AFTER_2_2 , "TORCH_VERSION_AFTER_2_2" ),
75
+ ]
76
+ self .assertEqual (len (_warnings ), 0 )
77
+
78
+ # Accessing the boolean value should trigger deprecation warning
79
+ with warnings .catch_warnings (record = True ) as _warnings :
80
+ for api , name in deprecated_api_to_name :
81
+ num_warnings_before = len (_warnings )
82
+ if api :
83
+ pass
84
+ regex = f"{ name } is deprecated and will be removed"
85
+ self .assertEqual (len (_warnings ), num_warnings_before + 1 )
86
+ self .assertIn (regex , str (_warnings [- 1 ].message ))
87
+
38
88
39
89
class TestTorchAOBaseTensor (unittest .TestCase ):
40
90
def test_print_arg_types (self ):
0 commit comments