Skip to content

Commit 0739127

Browse files
pytorchbotzou3519
andauthored
Gracefully handle optree less than minimum version (pytorch#150977)
Gracefully handle optree less than minimum version (pytorch#150956) Summary: - We are saying the minimum version of pytree that PyTorch can use is 0.13.0 - If a user imports torch.utils._cxx_pytree, it will raise an ImportError if optree doesn't exist or exists and is less than the minimum version. Fixes pytorch#150889. There are actually two parts to that issue: 1. dtensor imports torch.utils._cxx_pytree, but the optree installed in the environment might be too old. Instead, raising ImportError in torch.utils._cxx_pytree solves the issue. 2. We emit an "optree too low version" warning. I've deleted the warning in favor of the more explicit ImportError. Test Plan: - code reading Pull Request resolved: pytorch#150956 Approved by: https://github.com/albanD, https://github.com/atalman, https://github.com/XuehaiPan (cherry picked from commit 061832b) Co-authored-by: rzou <[email protected]>
1 parent 0c236f3 commit 0739127

File tree

2 files changed

+16
-12
lines changed

2 files changed

+16
-12
lines changed

torch/utils/_cxx_pytree.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,19 @@
2020
from typing_extensions import deprecated, TypeIs
2121

2222
import optree
23+
24+
from torch._vendor.packaging.version import Version
25+
26+
27+
if Version(optree.__version__) < Version("0.13.0"): # type: ignore[attr-defined]
28+
raise ImportError(
29+
"torch.utils._cxx_pytree depends on optree, which is an optional dependency "
30+
"of PyTorch. To use it, please upgrade your optree package to >= 0.13.0"
31+
)
32+
33+
del Version
34+
35+
2336
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
2437

2538
import torch.utils._pytree as python_pytree

torch/utils/_pytree.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -163,21 +163,12 @@ class _SerializeNodeDef(NamedTuple):
163163
try:
164164
_optree_version = importlib.metadata.version("optree")
165165
except importlib.metadata.PackageNotFoundError:
166+
# optree was not imported
166167
_cxx_pytree_dynamo_traceable = _cxx_pytree_exists = False
167168
else:
169+
# optree was imported
168170
_cxx_pytree_exists = True
169-
from torch._vendor.packaging.version import Version
170-
171-
_cxx_pytree_dynamo_traceable = Version(_optree_version) >= Version("0.13.0")
172-
if not _cxx_pytree_dynamo_traceable:
173-
warnings.warn(
174-
"optree is installed but the version is too old to support PyTorch Dynamo in C++ pytree. "
175-
"C++ pytree support is disabled. "
176-
"Please consider upgrading optree using `python3 -m pip install --upgrade 'optree>=0.13.0'`.",
177-
FutureWarning,
178-
)
179-
180-
del Version
171+
_cxx_pytree_dynamo_traceable = True
181172

182173
_cxx_pytree_imported = False
183174
_cxx_pytree_pending_imports: list[Any] = []

0 commit comments

Comments
 (0)