Skip to content

Commit b8fa964

Browse files
committed
Patch out refcounted RAII patterns from torch
1 parent d22906f commit b8fa964

File tree

1 file changed

+332
-0
lines changed

1 file changed

+332
-0
lines changed

graalpython/lib-graalpython/patches/torch/torch-1.13.1.patch

Lines changed: 332 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,51 @@ index 2ef8b7f2..1f4efd70 100644
121121
report("Building wheel {}-{}".format(package_name, version))
122122

123123
cmake = CMake()
124+
diff --git a/test/test_overrides.py b/test/test_overrides.py
125+
index e9e01684..f4069bc9 100644
126+
--- a/test/test_overrides.py
127+
+++ b/test/test_overrides.py
128+
@@ -1456,12 +1456,9 @@ class TestTorchFunctionMode(TestCase):
129+
pass
130+
131+
x = A(torch.randn(5))
132+
- with torch._C.DisableTorchFunction():
133+
- g = torch._C._EnableTorchFunction()
134+
- try:
135+
+ with torch._C.DisableTorchFunction(), \
136+
+ torch._C._EnableTorchFunction():
137+
self.assertIsInstance(torch.sum(x), A)
138+
- finally:
139+
- del g
140+
141+
def test_subclass_hash(self):
142+
class DiagTensor(torch.Tensor):
143+
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
144+
index dea96d19..4fb18b9e 100644
145+
--- a/test/test_python_dispatch.py
146+
+++ b/test/test_python_dispatch.py
147+
@@ -1671,16 +1671,16 @@ $0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memo
148+
class TestPythonDispatcher(TestCase):
149+
def test_basic(self):
150+
x = torch.randn(2, requires_grad=True)
151+
- r = torch._C._EnablePythonDispatcher()
152+
- torch.add(x, x)
153+
+ with torch._C._EnablePythonDispatcher():
154+
+ torch.add(x, x)
155+
156+
def test_lstsq(self):
157+
a = torch.randn(4, 3)
158+
b = torch.rand(4, 3)
159+
expected_shape = torch.linalg.lstsq(a, b).solution.shape
160+
- r = torch._C._EnablePythonDispatcher()
161+
- python_disp_shape = torch.linalg.lstsq(a, b).solution.shape
162+
- self.assertEqual(expected_shape, python_disp_shape)
163+
+ with torch._C._EnablePythonDispatcher():
164+
+ python_disp_shape = torch.linalg.lstsq(a, b).solution.shape
165+
+ self.assertEqual(expected_shape, python_disp_shape)
166+
167+
if __name__ == '__main__':
168+
run_tests()
124169
diff --git a/third_party/fbgemm/CMakeLists.txt b/third_party/fbgemm/CMakeLists.txt
125170
index 58dcb9ae..b0ad68aa 100644
126171
--- a/third_party/fbgemm/CMakeLists.txt
@@ -207,6 +252,60 @@ index db8f2401..bac3c0d9 100644
207252
auto *tb = reinterpret_cast<PyTracebackObject *>(m_trace.ptr());
208253

209254
// Get the deepest trace possible.
255+
diff --git a/torch/_dispatch/python.py b/torch/_dispatch/python.py
256+
index 95b7fa05..8d6039c6 100644
257+
--- a/torch/_dispatch/python.py
258+
+++ b/torch/_dispatch/python.py
259+
@@ -3,18 +3,8 @@ from contextlib import contextmanager
260+
261+
__all__ = ['enable_python_dispatcher', 'no_python_dispatcher']
262+
263+
-@contextmanager
264+
def no_python_dispatcher():
265+
- g = torch._C._DisablePythonDispatcher()
266+
- try:
267+
- yield
268+
- finally:
269+
- del g
270+
+ return torch._C._DisablePythonDispatcher()
271+
272+
-@contextmanager
273+
def enable_python_dispatcher():
274+
- g = torch._C._EnablePythonDispatcher()
275+
- try:
276+
- yield
277+
- finally:
278+
- del g
279+
+ return torch._C._EnablePythonDispatcher()
280+
diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py
281+
index 986be67a..53b5126a 100644
282+
--- a/torch/_tensor_str.py
283+
+++ b/torch/_tensor_str.py
284+
@@ -632,6 +632,6 @@ def _functorch_wrapper_str_intern(tensor, *, tensor_contents=None):
285+
286+
287+
def _str(self, *, tensor_contents=None):
288+
- with torch.no_grad():
289+
- guard = torch._C._DisableFuncTorch()
290+
+ with torch.no_grad(), \
291+
+ torch._C._DisableFuncTorch():
292+
return _str_intern(self, tensor_contents=tensor_contents)
293+
diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py
294+
index b847129d..3219d1e1 100644
295+
--- a/torch/autograd/grad_mode.py
296+
+++ b/torch/autograd/grad_mode.py
297+
@@ -292,9 +292,10 @@ class inference_mode(_DecoratorContextManager):
298+
299+
def __enter__(self):
300+
self._inference_mode_raii_guard = torch._C._InferenceMode(self.mode)
301+
+ self._inference_mode_raii_guard.__enter__()
302+
303+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
304+
- del self._inference_mode_raii_guard
305+
+ self._inference_mode_raii_guard.__exit__(exc_type, exc_value, traceback)
306+
307+
def clone(self):
308+
return self.__class__(self.mode)
210309
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
211310
index 8e1ca3b1..b150ac3f 100644
212311
--- a/torch/csrc/Module.cpp
@@ -269,6 +368,164 @@ index 8e1ca3b1..b150ac3f 100644
269368
}
270369

271370
Py_INCREF(obj);
371+
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
372+
index a1d6de21..4a8e6487 100644
373+
--- a/torch/csrc/autograd/init.cpp
374+
+++ b/torch/csrc/autograd/init.cpp
375+
@@ -43,27 +43,42 @@ struct DisableFuncTorch {
376+
c10::impl::ExcludeDispatchKeyGuard back_guard_;
377+
};
378+
379+
+struct DisableFuncTorchWrapper {
380+
+ DisableFuncTorch* delegate = nullptr;
381+
+};
382+
+
383+
struct EnableTorchFunction {
384+
EnableTorchFunction()
385+
: old_(at::impl::PythonTorchFunctionTLS::is_disabled()) {
386+
- at::impl::PythonTorchFunctionTLS::set_disabled(false);
387+
- }
388+
- ~EnableTorchFunction() {
389+
- at::impl::PythonTorchFunctionTLS::set_disabled(old_);
390+
}
391+
bool old_;
392+
};
393+
394+
struct EnablePythonDispatcher {
395+
EnablePythonDispatcher() : old_(c10::impl::PythonDispatcherTLS::get_state()) {
396+
- c10::impl::PythonDispatcherTLS::set_state(getPyInterpreter());
397+
- }
398+
- ~EnablePythonDispatcher() {
399+
- c10::impl::PythonDispatcherTLS::set_state(old_);
400+
}
401+
c10::impl::PyInterpreter* old_;
402+
};
403+
404+
+struct DisableTorchDispatchWrapper {
405+
+ torch::DisableTorchDispatch* delegate = nullptr;
406+
+};
407+
+
408+
+struct InferenceModeWrapper {
409+
+ bool enabled;
410+
+ torch::InferenceMode* delegate = nullptr;
411+
+
412+
+ InferenceModeWrapper(bool enabled) : enabled(enabled) {}
413+
+};
414+
+
415+
+struct RestorePythonTLSSnapshotWrapper {
416+
+ at::impl::RestorePythonTLSSnapshot* delegate = nullptr;
417+
+};
418+
+
419+
+struct DisablePythonDispatcherWrapper {
420+
+ c10::impl::DisablePythonDispatcher* delegate = nullptr;
421+
+};
422+
+
423+
} // namespace
424+
425+
PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
426+
@@ -337,23 +352,92 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
427+
428+
_C_m.def("_activate_cuda_trace", []() { activateCUDATrace(); });
429+
430+
- py::class_<c10::InferenceMode>(_C_m, "_InferenceMode").def(py::init<bool>());
431+
+ py::class_<InferenceModeWrapper>(_C_m, "_InferenceMode")
432+
+ .def(py::init<bool>())
433+
+ .def("__enter__", [&] (InferenceModeWrapper& w) {
434+
+ if (w.delegate) {
435+
+ delete w.delegate;
436+
+ }
437+
+ w.delegate = new c10::InferenceMode(w.enabled);
438+
+ })
439+
+ .def("__exit__", [&] (InferenceModeWrapper& w, py::object& excType, py::object& excValue, py::object& excTb) {
440+
+ if (w.delegate) {
441+
+ delete w.delegate;
442+
+ }
443+
+ });
444+
445+
- py::class_<at::impl::RestorePythonTLSSnapshot>(
446+
+ py::class_<RestorePythonTLSSnapshotWrapper>(
447+
_C_m, "_RestorePythonTLSSnapshot")
448+
- .def(py::init<>());
449+
+ .def(py::init<>())
450+
+ .def("__enter__", [&] (RestorePythonTLSSnapshotWrapper& w) {
451+
+ if (w.delegate) {
452+
+ delete w.delegate;
453+
+ }
454+
+ w.delegate = new at::impl::RestorePythonTLSSnapshot;
455+
+ })
456+
+ .def("__exit__", [&] (RestorePythonTLSSnapshotWrapper& w, py::object& excType, py::object& excValue, py::object& excTb) {
457+
+ if (w.delegate) {
458+
+ delete w.delegate;
459+
+ }
460+
+ });
461+
462+
// TODO: line up this binding with DisableTorchFunction
463+
- py::class_<torch::DisableTorchDispatch>(_C_m, "_DisableTorchDispatch")
464+
- .def(py::init<>());
465+
+ py::class_<DisableTorchDispatchWrapper>(_C_m, "_DisableTorchDispatch")
466+
+ .def(py::init<>())
467+
+ .def("__enter__", [&] (DisableTorchDispatchWrapper& w) {
468+
+ if (w.delegate) {
469+
+ delete w.delegate;
470+
+ }
471+
+ w.delegate = new torch::DisableTorchDispatch;
472+
+ })
473+
+ .def("__exit__", [&] (DisableTorchDispatchWrapper& w, py::object& excType, py::object& excValue, py::object& excTb) {
474+
+ if (w.delegate) {
475+
+ delete w.delegate;
476+
+ }
477+
+ });
478+
py::class_<EnableTorchFunction>(_C_m, "_EnableTorchFunction")
479+
- .def(py::init<>());
480+
+ .def(py::init<>())
481+
+ .def("__enter__", [&] (EnableTorchFunction& w) {
482+
+ at::impl::PythonTorchFunctionTLS::set_disabled(false);
483+
+ })
484+
+ .def("__exit__", [&] (EnableTorchFunction& w, py::object& excType, py::object& excValue, py::object& excTb) {
485+
+ at::impl::PythonTorchFunctionTLS::set_disabled(w.old_);
486+
+ });
487+
py::class_<EnablePythonDispatcher>(_C_m, "_EnablePythonDispatcher")
488+
- .def(py::init<>());
489+
- py::class_<c10::impl::DisablePythonDispatcher>(
490+
+ .def(py::init<>())
491+
+ .def("__enter__", [&] (EnablePythonDispatcher& w) {
492+
+ c10::impl::PythonDispatcherTLS::set_state(getPyInterpreter());
493+
+ })
494+
+ .def("__exit__", [&] (EnablePythonDispatcher& w, py::object& excType, py::object& excValue, py::object& excTb) {
495+
+ c10::impl::PythonDispatcherTLS::set_state(w.old_);
496+
+ });
497+
+ py::class_<DisablePythonDispatcherWrapper>(
498+
_C_m, "_DisablePythonDispatcher")
499+
- .def(py::init<>());
500+
- py::class_<DisableFuncTorch>(_C_m, "_DisableFuncTorch").def(py::init<>());
501+
+ .def(py::init<>())
502+
+ .def("__enter__", [&] (DisablePythonDispatcherWrapper& w) {
503+
+ if (w.delegate) {
504+
+ delete w.delegate;
505+
+ }
506+
+ w.delegate = new c10::impl::DisablePythonDispatcher;
507+
+ })
508+
+ .def("__exit__", [&] (DisablePythonDispatcherWrapper& w, py::object& excType, py::object& excValue, py::object& excTb) {
509+
+ if (w.delegate) {
510+
+ delete w.delegate;
511+
+ }
512+
+ });
513+
+ py::class_<DisableFuncTorchWrapper>(_C_m, "_DisableFuncTorch")
514+
+ .def(py::init<>())
515+
+ .def("__enter__", [&] (DisableFuncTorchWrapper& w) {
516+
+ if (w.delegate) {
517+
+ delete w.delegate;
518+
+ }
519+
+ w.delegate = new DisableFuncTorch;
520+
+ })
521+
+ .def("__exit__", [&] (DisableFuncTorchWrapper& w, py::object& excType, py::object& excValue, py::object& excTb) {
522+
+ if (w.delegate) {
523+
+ delete w.delegate;
524+
+ }
525+
+ });
526+
527+
py::class_<torch::autograd::SavedVariable>(m, "SavedTensor")
528+
.def(py::init([]() -> torch::autograd::SavedVariable {
272529
diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp
273530
index 8c9ed1d7..5183a325 100644
274531
--- a/torch/csrc/autograd/python_variable_indexing.cpp
@@ -344,3 +601,78 @@ index 44911527..d3ad9c54 100644
344601
static inline PyObject* PyFrame_GetLocals(PyFrameObject* frame) {
345602
PyFrame_FastToLocals(frame);
346603
auto res = frame->f_locals;
604+
diff --git a/torch/distributed/_shard/partial_tensor.py b/torch/distributed/_shard/partial_tensor.py
605+
index dc8d09bd..d7e4ec6c 100644
606+
--- a/torch/distributed/_shard/partial_tensor.py
607+
+++ b/torch/distributed/_shard/partial_tensor.py
608+
@@ -234,14 +234,11 @@ class _PartialTensor(torch.Tensor):
609+
return _PARTIAL_TENSOR_OPS[func](types, args, kwargs, process_group)
610+
611+
# Need to disable all dispatch to print args and kwargs appropriately.
612+
- guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]
613+
- try:
614+
- with torch._C.DisableTorchFunction():
615+
- raise RuntimeError(
616+
- f"torch function '{func.__name__}', with args: {args} and "
617+
- f"kwargs: {kwargs} not supported for PartialTensor!")
618+
- finally:
619+
- del guard
620+
+ with torch._C._DisableTorchDispatch(), \
621+
+ torch._C.DisableTorchFunction():
622+
+ raise RuntimeError(
623+
+ f"torch function '{func.__name__}', with args: {args} and "
624+
+ f"kwargs: {kwargs} not supported for PartialTensor!")
625+
626+
@classmethod
627+
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
628+
diff --git a/torch/overrides.py b/torch/overrides.py
629+
index dbee241b..dbc601e4 100644
630+
--- a/torch/overrides.py
631+
+++ b/torch/overrides.py
632+
@@ -1885,9 +1885,10 @@ def _no_torch_function_mode() -> Iterator[None]:
633+
class enable_reentrant_dispatch():
634+
def __enter__(self):
635+
self._raii_guard = torch._C._RestorePythonTLSSnapshot()
636+
+ self._raii_guard.__enter__()
637+
638+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
639+
- del self._raii_guard
640+
+ self._raii_guard.__exit__(exc_type, exc_value, traceback)
641+
642+
def get_buffer(tensor_subclass, data, prefix):
643+
import ctypes
644+
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
645+
index e3285090..7f9059aa 100644
646+
--- a/torch/testing/_internal/common_utils.py
647+
+++ b/torch/testing/_internal/common_utils.py
648+
@@ -1347,13 +1347,8 @@ def set_rng_seed(seed):
649+
np.random.seed(seed)
650+
651+
652+
-@contextmanager
653+
def disable_functorch():
654+
- guard = torch._C._DisableFuncTorch() # type: ignore[attr-defined]
655+
- try:
656+
- yield
657+
- finally:
658+
- del guard
659+
+ return torch._C._DisableFuncTorch()
660+
661+
662+
@contextlib.contextmanager
663+
diff --git a/torch/utils/_mode_utils.py b/torch/utils/_mode_utils.py
664+
index f9098c6d..adcd4920 100644
665+
--- a/torch/utils/_mode_utils.py
666+
+++ b/torch/utils/_mode_utils.py
667+
@@ -8,10 +8,5 @@ T = TypeVar('T')
668+
def all_same_mode(modes):
669+
return all(tuple(mode == modes[0] for mode in modes))
670+
671+
-@contextmanager
672+
def no_dispatch():
673+
- guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]
674+
- try:
675+
- yield
676+
- finally:
677+
- del guard
678+
+ return torch._C._DisableTorchDispatch()

0 commit comments

Comments
 (0)