|
8 | 8 | import itertools
|
9 | 9 | import re
|
10 | 10 | import time
|
| 11 | +import warnings |
11 | 12 | from functools import reduce
|
12 | 13 | from importlib.metadata import version
|
13 | 14 | from math import gcd
|
@@ -377,13 +378,62 @@ def torch_version_at_least(min_version):
|
377 | 378 | return is_fbcode() or compare_versions(torch.__version__, min_version) >= 0
|
378 | 379 |
|
379 | 380 |
|
| 381 | +# Deprecated, will be deleted in the future |
| 382 | +def _torch_version_after(min_version): |
| 383 | + return is_fbcode() or version("torch") >= min_version |
| 384 | + |
| 385 | + |
| 386 | +def _get_old_torch_version_deprecation_msg(version_str: str) -> str: |
| 387 | + return f"TORCH_VERSION_AT_LEAST_{version_str} is deprecated and will be removed in torchao 0.14.0" |
| 388 | + |
| 389 | + |
| 390 | +def _get_torch_version_after_deprecation_msg(version_str: str) -> str: |
| 391 | + return f"TORCH_VERSION_AFTER_{version_str} is deprecated and will be removed in torchao 0.14.0" |
| 392 | + |
| 393 | + |
| 394 | +class _BoolDeprecationWrapper: |
| 395 | + """ |
| 396 | + A deprecation wrapper that logs a warning when the given bool value is accessed. |
| 397 | + """ |
| 398 | + |
| 399 | + def __init__(self, bool_value: bool, msg: str): |
| 400 | + self.bool_value = bool_value |
| 401 | + self.msg = msg |
| 402 | + |
| 403 | + def __bool__(self): |
| 404 | + warnings.warn(self.msg) |
| 405 | + return self.bool_value |
| 406 | + |
| 407 | + |
380 | 408 | TORCH_VERSION_AT_LEAST_2_8 = torch_version_at_least("2.8.0")
|
381 | 409 | TORCH_VERSION_AT_LEAST_2_7 = torch_version_at_least("2.7.0")
|
382 | 410 | TORCH_VERSION_AT_LEAST_2_6 = torch_version_at_least("2.6.0")
|
383 |
| -TORCH_VERSION_AT_LEAST_2_5 = torch_version_at_least("2.5.0") |
384 |
| -TORCH_VERSION_AT_LEAST_2_4 = torch_version_at_least("2.4.0") |
385 |
| -TORCH_VERSION_AT_LEAST_2_3 = torch_version_at_least("2.3.0") |
386 |
| -TORCH_VERSION_AT_LEAST_2_2 = torch_version_at_least("2.2.0") |
| 411 | + |
| 412 | +# Deprecated |
| 413 | +TORCH_VERSION_AT_LEAST_2_5 = _BoolDeprecationWrapper( |
| 414 | + torch_version_at_least("2.5.0"), _get_old_torch_version_deprecation_msg("2_5") |
| 415 | +) |
| 416 | +TORCH_VERSION_AT_LEAST_2_4 = _BoolDeprecationWrapper( |
| 417 | + torch_version_at_least("2.4.0"), _get_old_torch_version_deprecation_msg("2_4") |
| 418 | +) |
| 419 | +TORCH_VERSION_AT_LEAST_2_3 = _BoolDeprecationWrapper( |
| 420 | + torch_version_at_least("2.3.0"), _get_old_torch_version_deprecation_msg("2_3") |
| 421 | +) |
| 422 | +TORCH_VERSION_AT_LEAST_2_2 = _BoolDeprecationWrapper( |
| 423 | + torch_version_at_least("2.2.0"), _get_old_torch_version_deprecation_msg("2_2") |
| 424 | +) |
| 425 | +TORCH_VERSION_AFTER_2_5 = _BoolDeprecationWrapper( |
| 426 | + _torch_version_after("2.5.0.dev"), _get_torch_version_after_deprecation_msg("2_5") |
| 427 | +) |
| 428 | +TORCH_VERSION_AFTER_2_4 = _BoolDeprecationWrapper( |
| 429 | + _torch_version_after("2.4.0.dev"), _get_torch_version_after_deprecation_msg("2_4") |
| 430 | +) |
| 431 | +TORCH_VERSION_AFTER_2_3 = _BoolDeprecationWrapper( |
| 432 | + _torch_version_after("2.3.0.dev"), _get_torch_version_after_deprecation_msg("2_3") |
| 433 | +) |
| 434 | +TORCH_VERSION_AFTER_2_2 = _BoolDeprecationWrapper( |
| 435 | + _torch_version_after("2.2.0.dev"), _get_torch_version_after_deprecation_msg("2_2") |
| 436 | +) |
387 | 437 |
|
388 | 438 |
|
389 | 439 | """
|
@@ -766,11 +816,6 @@ def fill_defaults(args, n, defaults_tail):
|
766 | 816 | return r
|
767 | 817 |
|
768 | 818 |
|
769 |
| -## Deprecated, will be deleted in the future |
770 |
| -def _torch_version_at_least(min_version): |
771 |
| - return is_fbcode() or version("torch") >= min_version |
772 |
| - |
773 |
| - |
774 | 819 | # Supported AMD GPU Models and their LLVM gfx Codes:
|
775 | 820 | #
|
776 | 821 | # | AMD GPU Model | LLVM gfx Code |
|
@@ -857,12 +902,6 @@ def ceil_div(a, b):
|
857 | 902 | return (a + b - 1) // b
|
858 | 903 |
|
859 | 904 |
|
860 |
| -TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev") |
861 |
| -TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev") |
862 |
| -TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev") |
863 |
| -TORCH_VERSION_AFTER_2_2 = _torch_version_at_least("2.2.0.dev") |
864 |
| - |
865 |
| - |
866 | 905 | def is_package_at_least(package_name: str, min_version: str):
|
867 | 906 | package_exists = importlib.util.find_spec(package_name) is not None
|
868 | 907 | if not package_exists:
|
|
0 commit comments