Skip to content

Commit cde81e9

Browse files
mlazospytorchmergebot
authored andcommitted
[User-streams] Make torch.Event weakref compatible (pytorch#164522)
Pull Request resolved: pytorch#164522 Approved by: https://github.com/williamwen42 ghstack dependencies: pytorch#162903, pytorch#164343, pytorch#164344, pytorch#164507, pytorch#162901, pytorch#164304
1 parent bfc2050 commit cde81e9

File tree

4 files changed

+12
-23
lines changed

4 files changed

+12
-23
lines changed

test/dynamo/test_compile.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -234,27 +234,6 @@ def fn(x, y):
234234
with self.assertRaises(IndexError):
235235
fn(torch.randn(10), 99)
236236

237-
def test_list_bad_weakref(self):
238-
import weakref
239-
240-
a = torch.Event()
241-
with self.assertRaises(TypeError):
242-
weakref.ref(a)
243-
244-
@torch.compile(backend="eager")
245-
class Mod(torch.nn.Module):
246-
def __init__(self, event):
247-
super().__init__()
248-
self.event = event
249-
250-
def forward(self, x):
251-
return x * int(self.event.query())
252-
253-
e = torch.Event()
254-
m = Mod(e)
255-
a = torch.randn(10)
256-
self.assertEqual(m(a), a)
257-
258237

259238
# The private variants of the below functions are extensively tested
260239
# So as long as the signatures match we're good

test/dynamo/test_streams.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ def test_stream_weakref(self):
2020
s = torch.Stream()
2121
weakref.ref(s)
2222

23+
def test_event_weakref(self):
24+
e = torch.Event()
25+
weakref.ref(e)
26+
2327
@requires_cuda
2428
def test_run_opcheck(self):
2529
from torch._dynamo.variables.streams import fork_stream, join_stream

torch/csrc/Event.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ static PyObject* THPEvent_pynew(
4949
}
5050

5151
THPEvent* self = reinterpret_cast<THPEvent*>(ptr.get());
52+
self->weakreflist = nullptr;
5253

5354
// TODO: blocking and interprocess are not supported yet. To support them, the
5455
// flag system of c10::Event needs to be refactored. C10::Event should also
@@ -73,6 +74,7 @@ PyObject* THPEvent_new(c10::DeviceType device_type, c10::EventFlag flag) {
7374
auto self = THPObjectPtr{type->tp_alloc(type, 0)};
7475
TORCH_CHECK(self, "Failed to allocate memory for Event");
7576
auto self_ = reinterpret_cast<THPEvent*>(self.get());
77+
self_->weakreflist = nullptr;
7678
new (&self_->event) c10::Event(device_type, flag);
7779
return self.release();
7880
}
@@ -82,6 +84,7 @@ static void THPEvent_dealloc(THPEvent* self) {
8284
pybind11::gil_scoped_release no_gil{};
8385
self->event.~Event();
8486
}
87+
PyObject_ClearWeakRefs((PyObject*)self);
8588
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
8689
}
8790

@@ -282,7 +285,8 @@ static PyMethodDef THPEvent_methods[] = {
282285
{"synchronize", THPEvent_synchronize, METH_NOARGS, nullptr},
283286
{"ipc_handle", THPEvent_ipc_handle, METH_NOARGS, nullptr},
284287
{nullptr}};
285-
288+
#pragma GCC diagnostic push
289+
#pragma GCC diagnostic ignored "-Winvalid-offsetof"
286290
PyTypeObject THPEventType = {
287291
PyVarObject_HEAD_INIT(nullptr, 0)
288292
"torch.Event", /* tp_name */
@@ -308,7 +312,7 @@ PyTypeObject THPEventType = {
308312
nullptr, /* tp_traverse */
309313
nullptr, /* tp_clear */
310314
nullptr, /* tp_richcompare */
311-
0, /* tp_weaklistoffset */
315+
offsetof(THPEvent, weakreflist), /* tp_weaklistoffset */
312316
nullptr, /* tp_iter */
313317
nullptr, /* tp_iternext */
314318
THPEvent_methods, /* tp_methods */
@@ -323,6 +327,7 @@ PyTypeObject THPEventType = {
323327
nullptr, /* tp_alloc */
324328
THPEvent_pynew, /* tp_new */
325329
};
330+
#pragma GCC diagnostic pop
326331

327332
void THPEvent_init(PyObject* module) {
328333
THPEventClass = &THPEventType;

torch/csrc/Event.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
struct TORCH_API THPEvent {
88
PyObject_HEAD
99
c10::Event event;
10+
PyObject* weakreflist;
1011
};
1112
TORCH_API extern PyTypeObject* THPEventClass;
1213
TORCH_API extern PyTypeObject THPEventType;

0 commit comments

Comments
 (0)