File tree Expand file tree Collapse file tree 3 files changed +10
-1
lines changed Expand file tree Collapse file tree 3 files changed +10
-1
lines changed Original file line number Diff line number Diff line change 11# Owner(s): ["module: dynamo"]
2+ import weakref
23
34import torch
45import torch ._dynamo .test_case
@@ -15,6 +16,10 @@ def setUpClass(cls):
1516 def tearDownClass (cls ):
1617 super ().tearDownClass ()
1718
19+ def test_stream_weakref (self ):
20+ s = torch .Stream ()
21+ weakref .ref (s )
22+
1823 @requires_cuda
1924 def test_run_opcheck (self ):
2025 from torch ._dynamo .variables .streams import fork_stream , join_stream
Original file line number Diff line number Diff line change @@ -95,6 +95,7 @@ static PyObject* THPStream_pynew(
9595 self->device_index = static_cast <int64_t >(stream_opt->device_index ());
9696 self->device_type = static_cast <int64_t >(stream_opt->device_type ());
9797 self->context = nullptr ;
98+ self->weakreflist = nullptr ;
9899
99100 return static_cast <PyObject*>(ptr.release ());
100101 END_HANDLE_TH_ERRORS
@@ -114,11 +115,13 @@ PyObject* THPStream_Wrap(const c10::Stream& stream) {
114115 self->device_index = static_cast <int64_t >(stream.device_index ());
115116 self->device_type = static_cast <int64_t >(stream.device_type ());
116117 self->context = nullptr ;
118+ self->weakreflist = nullptr ;
117119 return ptr.release ();
118120 END_HANDLE_TH_ERRORS
119121}
120122
121123static void THPStream_dealloc (THPStream* self) {
124+ PyObject_ClearWeakRefs ((PyObject*)self);
122125 Py_TYPE (self)->tp_free (reinterpret_cast <PyObject*>(self));
123126}
124127
@@ -444,7 +447,7 @@ static PyTypeObject THPStreamType = {
444447 nullptr , /* tp_traverse */
445448 nullptr , /* tp_clear */
446449 THPStream_richcompare, /* tp_richcompare */
447- 0 , /* tp_weaklistoffset */
450+ offsetof (THPStream, weakreflist) , /* tp_weaklistoffset */
448451 nullptr , /* tp_iter */
449452 nullptr , /* tp_iternext */
450453 // NOLINTNEXTLINE(*const-cast)
Original file line number Diff line number Diff line change @@ -13,6 +13,7 @@ struct THPStream {
1313 int64_t device_index ;
1414 // Used to switch stream context management, initialized lazily.
1515 PyObject * context ;
16+ PyObject * weakreflist ;
1617};
1718extern TORCH_API PyTypeObject * THPStreamClass ;
1819
You can’t perform that action at this time.
0 commit comments