
Commit 5d0f639

eee4017 authored and pytorchmergebot committed
Make Tensor.__dlpack__(stream=None) capture-safe during CUDA Graph capture (pytorch#163242)
Many extensions (including pybind helpers) call `Tensor.__dlpack__()` without a stream argument. Before pytorch#150217, `stream=None` behaved like “no cross-stream sync” and was safe inside CUDA Graph capture. After pytorch#150217, `stream=None` maps to the legacy default stream, adding a cross-stream wait that invalidates capture when running on a non-default stream. See this example:

```
import torch

s = torch.cuda.Stream()
x = torch.randn(8, device="cuda")
g = torch.cuda.CUDAGraph()

with torch.cuda.stream(s):
    with torch.cuda.graph(g):
        _ = x + 1
        cap = x.__dlpack__()
        _ = torch.utils.dlpack.from_dlpack(cap)
```

This PR partially reverts pytorch#150217 so that `stream=None` again defaults to no synchronization.

Pull Request resolved: pytorch#163242
Approved by: https://github.com/ngimel
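If a producer does want the cross-stream wait after this change, a minimal sketch (not part of the commit; the stream and tensor names are illustrative) is to pass the consumer stream's raw CUDA handle explicitly, as the DLPack protocol describes:

```
import torch
from torch.utils.dlpack import from_dlpack

# Illustrative consumer-side stream; Stream.cuda_stream exposes the raw
# cudaStream_t handle as an integer.
consumer = torch.cuda.Stream()
x = torch.randn(8, device="cuda")

# Passing the handle asks the producer to make the capsule safe to use on
# `consumer`; the new default (stream=-1) skips this synchronization.
capsule = x.__dlpack__(stream=consumer.cuda_stream)
y = from_dlpack(capsule)
```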
1 parent 9d0d98a commit 5d0f639

File tree

1 file changed: +7 additions, −4 deletions


torch/_tensor.py

Lines changed: 7 additions & 4 deletions
```diff
@@ -1699,7 +1699,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
     def __dlpack__(
         self,
         *,
-        stream: Optional[Any] = None,
+        stream: Optional[Any] = -1,
         max_version: Optional[tuple[int, int]] = None,
         dl_device: Optional[tuple[enum.IntEnum, int]] = None,
         copy: Optional[bool] = None,
@@ -1717,9 +1717,12 @@ def __dlpack__(
             pointer to a CUDA stream. The current stream is synchronized with
             this stream before the capsule is created, and since the capsule
             shares its storage with the tensor this make it safe to access from
-            both streams. If None or -1 is passed then no synchronization is performed.
+            both streams. If -1 is passed then no synchronization is performed.
             If 1 (on CUDA) or 0 (on ROCM) then the default stream is used for
-            synchronization.
+            synchronization. This API intentionally slightly deviates from the DLPack
+            guidance: the default stream is -1 (stream-preserving; no cross-stream sync),
+            because many from_dlpack implementations intend stream preservation.
+            For non-CUDA devices, -1 is treated the same as None.

             max_version (tuple[int, int] or None): An optional Python tuple with
             2 integers, representing the maximum version the caller supports. If
@@ -1797,7 +1800,7 @@ def __dlpack__(
                 event.record(current_stream)
                 stream.wait_event(event)
             elif self.device.type == "cpu":
-                assert stream is None, "stream should be None on cpu."
+                assert stream is None or stream == -1, "stream should be None on cpu."

         if self.device.type == "xla":
             import torch_xla
```
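As a usage note on the relaxed CPU assertion, here is a minimal sketch (not part of the diff) showing that a CPU tensor now accepts both `stream=None` and the new default `stream=-1`:

```
import torch
from torch.utils.dlpack import from_dlpack

t = torch.arange(4)                        # CPU tensor
a = from_dlpack(t.__dlpack__())            # implicit stream=-1, now allowed
b = from_dlpack(t.__dlpack__(stream=None))
c = from_dlpack(t.__dlpack__(stream=-1))   # would have tripped the old assert
```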
