Skip to content

Commit 1341794

Browse files
pytorchbotzou3519
andauthored
Gracefully handle optree less than minimum version, part 2 (pytorch#151323)
Gracefully handle optree less than minimum version, part 2 (pytorch#151257) If optree is less than the minimum version, we should pretend it doesn't exist. The problem right now is: - Install optree==0.12.1 - `import torch._dynamo` - This raise an error "min optree version is 0.13.0" The fix is to pretend optree doesn't exist if it is less than the min version. There are ways to clean up this PR more (e.g. have a single source of truth for the version, some of the variables are redundant), but I am trying to reduce the risk as much as possible for this to go into 2.7. Test Plan: I verified the above problem was fixed. Also tried some other things, like the following, which now gives the expected behavior. ```py >>> import torch >>> import optree >>> optree.__version__ '0.12.1' >>> import torch._dynamo >>> import torch._dynamo.polyfills.pytree >>> import torch.utils._pytree >>> import torch.utils._cxx_pytree ImportError: torch.utils._cxx_pytree depends on optree, which is an optional dependency of PyTorch. To u se it, please upgrade your optree package to >= 0.13.0 ``` I also audited all non-test callsites of optree and torch.utils._cxx_pytree. Follow along with me: optree imports - torch.utils._cxx_pytree. This is fine. - [guarded by check] https://github.com/pytorch/pytorch/blob/f76b7ef33cc30f7378ef71a201f68a2bef18dba0/torch/_dynamo/polyfills/pytree.py#L29-L31 _cxx_pytree imports - [guarded by check] torch.utils._pytree (changed in this PR) - [guarded by check] torch/_dynamo/polyfills/pytree.py (changed in this PR) - [guarded by try-catch] https://github.com/pytorch/pytorch/blob/f76b7ef33cc30f7378ef71a201f68a2bef18dba0/torch/distributed/_functional_collectives.py#L17 - [guarded by try-catch] https://github.com/pytorch/pytorch/blob/f76b7ef33cc30f7378ef71a201f68a2bef18dba0/torch/distributed/tensor/_op_schema.py#L15 - [guarded by try-catch] https://github.com/pytorch/pytorch/blob/f76b7ef33cc30f7378ef71a201f68a2bef18dba0/torch/distributed/tensor/_dispatch.py#L35 - [guarded by try-catch] https://github.com/pytorch/pytorch/blob/f76b7ef33cc30f7378ef71a201f68a2bef18dba0/torch/_dynamo/variables/user_defined.py#L94 - [guarded by try-catch] https://github.com/pytorch/pytorch/blob/f76b7ef33cc30f7378ef71a201f68a2bef18dba0/torch/distributed/tensor/experimental/_func_map.py#L14 Pull Request resolved: pytorch#151257 Approved by: https://github.com/malfet, https://github.com/XuehaiPan (cherry picked from commit f1f18c7) Co-authored-by: rzou <[email protected]>
1 parent 0739127 commit 1341794

File tree

3 files changed

+16
-6
lines changed

3 files changed

+16
-6
lines changed

torch/_dynamo/polyfills/pytree.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
from collections.abc import Iterable
2121
from typing_extensions import Self
2222

23-
from torch.utils._cxx_pytree import PyTree
24-
2523

2624
__all__: list[str] = []
2725

@@ -32,6 +30,9 @@
3230

3331
import torch.utils._cxx_pytree as cxx_pytree
3432

33+
if TYPE_CHECKING:
34+
from torch.utils._cxx_pytree import PyTree
35+
3536
@substitute_in_graph(
3637
optree._C.is_dict_insertion_ordered,
3738
can_constant_fold_through=True,

torch/utils/_cxx_pytree.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from torch._vendor.packaging.version import Version
2525

2626

27+
# Keep the version in sync with torch.utils._cxx_pytree!
2728
if Version(optree.__version__) < Version("0.13.0"): # type: ignore[attr-defined]
2829
raise ImportError(
2930
"torch.utils._cxx_pytree depends on optree, which is an optional dependency "

torch/utils/_pytree.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,20 @@ class _SerializeNodeDef(NamedTuple):
163163
try:
164164
_optree_version = importlib.metadata.version("optree")
165165
except importlib.metadata.PackageNotFoundError:
166-
# optree was not imported
166+
# No optree package found
167167
_cxx_pytree_dynamo_traceable = _cxx_pytree_exists = False
168168
else:
169-
# optree was imported
170-
_cxx_pytree_exists = True
171-
_cxx_pytree_dynamo_traceable = True
169+
from torch._vendor.packaging.version import Version
170+
171+
# Keep this in sync with torch.utils._cxx_pytree!
172+
if Version(_optree_version) < Version("0.13.0"):
173+
# optree package less than our required minimum version.
174+
# Pretend the optree package doesn't exist.
175+
# NB: We will raise ImportError if the user directly tries to
176+
# `import torch.utils._cxx_pytree` (look in that file for the check).
177+
_cxx_pytree_dynamo_traceable = _cxx_pytree_exists = False
178+
else:
179+
_cxx_pytree_dynamo_traceable = _cxx_pytree_exists = True
172180

173181
_cxx_pytree_imported = False
174182
_cxx_pytree_pending_imports: list[Any] = []

0 commit comments

Comments
 (0)