Skip to content

Commit bfc2050

Browse files
mlazospytorchmergebot
authored andcommitted
[user-streams] Make device-agnostic streams weakref compatible (pytorch#164304)
Pull Request resolved: pytorch#164304 Approved by: https://github.com/williamwen42, https://github.com/colesbury ghstack dependencies: pytorch#162903, pytorch#164343, pytorch#164344, pytorch#164507, pytorch#162901
1 parent c5701d0 commit bfc2050

File tree

3 files changed

+10
-1
lines changed

3 files changed

+10
-1
lines changed

test/dynamo/test_streams.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Owner(s): ["module: dynamo"]
2+
import weakref
23

34
import torch
45
import torch._dynamo.test_case
@@ -15,6 +16,10 @@ def setUpClass(cls):
1516
def tearDownClass(cls):
1617
super().tearDownClass()
1718

19+
def test_stream_weakref(self):
20+
s = torch.Stream()
21+
weakref.ref(s)
22+
1823
@requires_cuda
1924
def test_run_opcheck(self):
2025
from torch._dynamo.variables.streams import fork_stream, join_stream

torch/csrc/Stream.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ static PyObject* THPStream_pynew(
9595
self->device_index = static_cast<int64_t>(stream_opt->device_index());
9696
self->device_type = static_cast<int64_t>(stream_opt->device_type());
9797
self->context = nullptr;
98+
self->weakreflist = nullptr;
9899

99100
return static_cast<PyObject*>(ptr.release());
100101
END_HANDLE_TH_ERRORS
@@ -114,11 +115,13 @@ PyObject* THPStream_Wrap(const c10::Stream& stream) {
114115
self->device_index = static_cast<int64_t>(stream.device_index());
115116
self->device_type = static_cast<int64_t>(stream.device_type());
116117
self->context = nullptr;
118+
self->weakreflist = nullptr;
117119
return ptr.release();
118120
END_HANDLE_TH_ERRORS
119121
}
120122

121123
static void THPStream_dealloc(THPStream* self) {
124+
PyObject_ClearWeakRefs((PyObject*)self);
122125
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
123126
}
124127

@@ -444,7 +447,7 @@ static PyTypeObject THPStreamType = {
444447
nullptr, /* tp_traverse */
445448
nullptr, /* tp_clear */
446449
THPStream_richcompare, /* tp_richcompare */
447-
0, /* tp_weaklistoffset */
450+
offsetof(THPStream, weakreflist), /* tp_weaklistoffset */
448451
nullptr, /* tp_iter */
449452
nullptr, /* tp_iternext */
450453
// NOLINTNEXTLINE(*const-cast)

torch/csrc/Stream.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ struct THPStream {
1313
int64_t device_index;
1414
// Used to switch stream context management, initialized lazily.
1515
PyObject* context;
16+
PyObject* weakreflist;
1617
};
1718
extern TORCH_API PyTypeObject* THPStreamClass;
1819

0 commit comments

Comments
 (0)