@@ -121,6 +121,51 @@ index 2ef8b7f2..1f4efd70 100644
121
121
report("Building wheel {}-{}".format(package_name, version))
122
122
123
123
cmake = CMake()
124
+ diff --git a/test/test_overrides.py b/test/test_overrides.py
125
+ index e9e01684..f4069bc9 100644
126
+ --- a/test/test_overrides.py
127
+ +++ b/test/test_overrides.py
128
+ @@ -1456,12 +1456,9 @@ class TestTorchFunctionMode(TestCase):
129
+ pass
130
+
131
+ x = A(torch.randn(5))
132
+ - with torch._C.DisableTorchFunction():
133
+ - g = torch._C._EnableTorchFunction()
134
+ - try:
135
+ + with torch._C.DisableTorchFunction(), \
136
+ + torch._C._EnableTorchFunction():
137
+ self.assertIsInstance(torch.sum(x), A)
138
+ - finally:
139
+ - del g
140
+
141
+ def test_subclass_hash(self):
142
+ class DiagTensor(torch.Tensor):
143
+ diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
144
+ index dea96d19..4fb18b9e 100644
145
+ --- a/test/test_python_dispatch.py
146
+ +++ b/test/test_python_dispatch.py
147
+ @@ -1671,16 +1671,16 @@ $0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memo
148
+ class TestPythonDispatcher(TestCase):
149
+ def test_basic(self):
150
+ x = torch.randn(2, requires_grad=True)
151
+ - r = torch._C._EnablePythonDispatcher()
152
+ - torch.add(x, x)
153
+ + with torch._C._EnablePythonDispatcher():
154
+ + torch.add(x, x)
155
+
156
+ def test_lstsq(self):
157
+ a = torch.randn(4, 3)
158
+ b = torch.rand(4, 3)
159
+ expected_shape = torch.linalg.lstsq(a, b).solution.shape
160
+ - r = torch._C._EnablePythonDispatcher()
161
+ - python_disp_shape = torch.linalg.lstsq(a, b).solution.shape
162
+ - self.assertEqual(expected_shape, python_disp_shape)
163
+ + with torch._C._EnablePythonDispatcher():
164
+ + python_disp_shape = torch.linalg.lstsq(a, b).solution.shape
165
+ + self.assertEqual(expected_shape, python_disp_shape)
166
+
167
+ if __name__ == '__main__':
168
+ run_tests()
124
169
diff --git a/third_party/fbgemm/CMakeLists.txt b/third_party/fbgemm/CMakeLists.txt
125
170
index 58dcb9ae..b0ad68aa 100644
126
171
--- a/third_party/fbgemm/CMakeLists.txt
@@ -207,6 +252,60 @@ index db8f2401..bac3c0d9 100644
207
252
auto *tb = reinterpret_cast<PyTracebackObject *>(m_trace.ptr());
208
253
209
254
// Get the deepest trace possible.
255
+ diff --git a/torch/_dispatch/python.py b/torch/_dispatch/python.py
256
+ index 95b7fa05..8d6039c6 100644
257
+ --- a/torch/_dispatch/python.py
258
+ +++ b/torch/_dispatch/python.py
259
+ @@ -3,18 +3,8 @@ from contextlib import contextmanager
260
+
261
+ __all__ = ['enable_python_dispatcher', 'no_python_dispatcher']
262
+
263
+ - @contextmanager
264
+ def no_python_dispatcher():
265
+ - g = torch._C._DisablePythonDispatcher()
266
+ - try:
267
+ - yield
268
+ - finally:
269
+ - del g
270
+ + return torch._C._DisablePythonDispatcher()
271
+
272
+ - @contextmanager
273
+ def enable_python_dispatcher():
274
+ - g = torch._C._EnablePythonDispatcher()
275
+ - try:
276
+ - yield
277
+ - finally:
278
+ - del g
279
+ + return torch._C._EnablePythonDispatcher()
280
+ diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py
281
+ index 986be67a..53b5126a 100644
282
+ --- a/torch/_tensor_str.py
283
+ +++ b/torch/_tensor_str.py
284
+ @@ -632,6 +632,6 @@ def _functorch_wrapper_str_intern(tensor, *, tensor_contents=None):
285
+
286
+
287
+ def _str(self, *, tensor_contents=None):
288
+ - with torch.no_grad():
289
+ - guard = torch._C._DisableFuncTorch()
290
+ + with torch.no_grad(), \
291
+ + torch._C._DisableFuncTorch():
292
+ return _str_intern(self, tensor_contents=tensor_contents)
293
+ diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py
294
+ index b847129d..3219d1e1 100644
295
+ --- a/torch/autograd/grad_mode.py
296
+ +++ b/torch/autograd/grad_mode.py
297
+ @@ -292,9 +292,10 @@ class inference_mode(_DecoratorContextManager):
298
+
299
+ def __enter__(self):
300
+ self._inference_mode_raii_guard = torch._C._InferenceMode(self.mode)
301
+ + self._inference_mode_raii_guard.__enter__()
302
+
303
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
304
+ - del self._inference_mode_raii_guard
305
+ + self._inference_mode_raii_guard.__exit__(exc_type, exc_value, traceback)
306
+
307
+ def clone(self):
308
+ return self.__class__(self.mode)
210
309
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
211
310
index 8e1ca3b1..b150ac3f 100644
212
311
--- a/torch/csrc/Module.cpp
@@ -269,6 +368,164 @@ index 8e1ca3b1..b150ac3f 100644
269
368
}
270
369
271
370
Py_INCREF(obj);
371
+ diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
372
+ index a1d6de21..4a8e6487 100644
373
+ --- a/torch/csrc/autograd/init.cpp
374
+ +++ b/torch/csrc/autograd/init.cpp
375
+ @@ -43,27 +43,42 @@ struct DisableFuncTorch {
376
+ c10::impl::ExcludeDispatchKeyGuard back_guard_;
377
+ };
378
+
379
+ + struct DisableFuncTorchWrapper {
380
+ + DisableFuncTorch* delegate = nullptr;
381
+ + };
382
+ +
383
+ struct EnableTorchFunction {
384
+ EnableTorchFunction()
385
+ : old_(at::impl::PythonTorchFunctionTLS::is_disabled()) {
386
+ - at::impl::PythonTorchFunctionTLS::set_disabled(false);
387
+ - }
388
+ - ~EnableTorchFunction() {
389
+ - at::impl::PythonTorchFunctionTLS::set_disabled(old_);
390
+ }
391
+ bool old_;
392
+ };
393
+
394
+ struct EnablePythonDispatcher {
395
+ EnablePythonDispatcher() : old_(c10::impl::PythonDispatcherTLS::get_state()) {
396
+ - c10::impl::PythonDispatcherTLS::set_state(getPyInterpreter());
397
+ - }
398
+ - ~EnablePythonDispatcher() {
399
+ - c10::impl::PythonDispatcherTLS::set_state(old_);
400
+ }
401
+ c10::impl::PyInterpreter* old_;
402
+ };
403
+
404
+ + struct DisableTorchDispatchWrapper {
405
+ + torch::DisableTorchDispatch* delegate = nullptr;
406
+ + };
407
+ +
408
+ + struct InferenceModeWrapper {
409
+ + bool enabled;
410
+ + torch::InferenceMode* delegate = nullptr;
411
+ +
412
+ + InferenceModeWrapper(bool enabled) : enabled(enabled) {}
413
+ + };
414
+ +
415
+ + struct RestorePythonTLSSnapshotWrapper {
416
+ + at::impl::RestorePythonTLSSnapshot* delegate = nullptr;
417
+ + };
418
+ +
419
+ + struct DisablePythonDispatcherWrapper {
420
+ + c10::impl::DisablePythonDispatcher* delegate = nullptr;
421
+ + };
422
+ +
423
+ } // namespace
424
+
425
+ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
426
+ @@ -337,23 +352,92 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
427
+
428
+ _C_m.def("_activate_cuda_trace", []() { activateCUDATrace(); });
429
+
430
+ - py::class_<c10::InferenceMode>(_C_m, "_InferenceMode").def(py::init<bool>());
431
+ + py::class_<InferenceModeWrapper>(_C_m, "_InferenceMode")
432
+ + .def(py::init<bool>())
433
+ + .def("__enter__", [&] (InferenceModeWrapper& w) {
434
+ + if (w.delegate) {
435
+ + delete w.delegate;
436
+ + }
437
+ + w.delegate = new c10::InferenceMode(w.enabled);
438
+ + })
439
+ + .def("__exit__", [&] (InferenceModeWrapper& w, py::object& excType, py::object& excValue, py::object& excTb) {
440
+ + if (w.delegate) {
441
+ + delete w.delegate;
442
+ + }
443
+ + });
444
+
445
+ - py::class_<at::impl::RestorePythonTLSSnapshot>(
446
+ + py::class_<RestorePythonTLSSnapshotWrapper>(
447
+ _C_m, "_RestorePythonTLSSnapshot")
448
+ - .def(py::init<>());
449
+ + .def(py::init<>())
450
+ + .def("__enter__", [&] (RestorePythonTLSSnapshotWrapper& w) {
451
+ + if (w.delegate) {
452
+ + delete w.delegate;
453
+ + }
454
+ + w.delegate = new at::impl::RestorePythonTLSSnapshot;
455
+ + })
456
+ + .def("__exit__", [&] (RestorePythonTLSSnapshotWrapper& w, py::object& excType, py::object& excValue, py::object& excTb) {
457
+ + if (w.delegate) {
458
+ + delete w.delegate;
459
+ + }
460
+ + });
461
+
462
+ // TODO: line up this binding with DisableTorchFunction
463
+ - py::class_<torch::DisableTorchDispatch>(_C_m, "_DisableTorchDispatch")
464
+ - .def(py::init<>());
465
+ + py::class_<DisableTorchDispatchWrapper>(_C_m, "_DisableTorchDispatch")
466
+ + .def(py::init<>())
467
+ + .def("__enter__", [&] (DisableTorchDispatchWrapper& w) {
468
+ + if (w.delegate) {
469
+ + delete w.delegate;
470
+ + }
471
+ + w.delegate = new torch::DisableTorchDispatch;
472
+ + })
473
+ + .def("__exit__", [&] (DisableTorchDispatchWrapper& w, py::object& excType, py::object& excValue, py::object& excTb) {
474
+ + if (w.delegate) {
475
+ + delete w.delegate;
476
+ + }
477
+ + });
478
+ py::class_<EnableTorchFunction>(_C_m, "_EnableTorchFunction")
479
+ - .def(py::init<>());
480
+ + .def(py::init<>())
481
+ + .def("__enter__", [&] (EnableTorchFunction& w) {
482
+ + at::impl::PythonTorchFunctionTLS::set_disabled(false);
483
+ + })
484
+ + .def("__exit__", [&] (EnableTorchFunction& w, py::object& excType, py::object& excValue, py::object& excTb) {
485
+ + at::impl::PythonTorchFunctionTLS::set_disabled(w.old_);
486
+ + });
487
+ py::class_<EnablePythonDispatcher>(_C_m, "_EnablePythonDispatcher")
488
+ - .def(py::init<>());
489
+ - py::class_<c10::impl::DisablePythonDispatcher>(
490
+ + .def(py::init<>())
491
+ + .def("__enter__", [&] (EnablePythonDispatcher& w) {
492
+ + c10::impl::PythonDispatcherTLS::set_state(getPyInterpreter());
493
+ + })
494
+ + .def("__exit__", [&] (EnablePythonDispatcher& w, py::object& excType, py::object& excValue, py::object& excTb) {
495
+ + c10::impl::PythonDispatcherTLS::set_state(w.old_);
496
+ + });
497
+ + py::class_<DisablePythonDispatcherWrapper>(
498
+ _C_m, "_DisablePythonDispatcher")
499
+ - .def(py::init<>());
500
+ - py::class_<DisableFuncTorch>(_C_m, "_DisableFuncTorch").def(py::init<>());
501
+ + .def(py::init<>())
502
+ + .def("__enter__", [&] (DisablePythonDispatcherWrapper& w) {
503
+ + if (w.delegate) {
504
+ + delete w.delegate;
505
+ + }
506
+ + w.delegate = new c10::impl::DisablePythonDispatcher;
507
+ + })
508
+ + .def("__exit__", [&] (DisablePythonDispatcherWrapper& w, py::object& excType, py::object& excValue, py::object& excTb) {
509
+ + if (w.delegate) {
510
+ + delete w.delegate;
511
+ + }
512
+ + });
513
+ + py::class_<DisableFuncTorchWrapper>(_C_m, "_DisableFuncTorch")
514
+ + .def(py::init<>())
515
+ + .def("__enter__", [&] (DisableFuncTorchWrapper& w) {
516
+ + if (w.delegate) {
517
+ + delete w.delegate;
518
+ + }
519
+ + w.delegate = new DisableFuncTorch;
520
+ + })
521
+ + .def("__exit__", [&] (DisableFuncTorchWrapper& w, py::object& excType, py::object& excValue, py::object& excTb) {
522
+ + if (w.delegate) {
523
+ + delete w.delegate;
524
+ + }
525
+ + });
526
+
527
+ py::class_<torch::autograd::SavedVariable>(m, "SavedTensor")
528
+ .def(py::init([]() -> torch::autograd::SavedVariable {
272
529
diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp
273
530
index 8c9ed1d7..5183a325 100644
274
531
--- a/torch/csrc/autograd/python_variable_indexing.cpp
@@ -344,3 +601,78 @@ index 44911527..d3ad9c54 100644
344
601
static inline PyObject* PyFrame_GetLocals(PyFrameObject* frame) {
345
602
PyFrame_FastToLocals(frame);
346
603
auto res = frame->f_locals;
604
+ diff --git a/torch/distributed/_shard/partial_tensor.py b/torch/distributed/_shard/partial_tensor.py
605
+ index dc8d09bd..d7e4ec6c 100644
606
+ --- a/torch/distributed/_shard/partial_tensor.py
607
+ +++ b/torch/distributed/_shard/partial_tensor.py
608
+ @@ -234,14 +234,11 @@ class _PartialTensor(torch.Tensor):
609
+ return _PARTIAL_TENSOR_OPS[func](types, args, kwargs, process_group)
610
+
611
+ # Need to disable all dispatch to print args and kwargs appropriately.
612
+ - guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]
613
+ - try:
614
+ - with torch._C.DisableTorchFunction():
615
+ - raise RuntimeError(
616
+ - f"torch function '{func.__name__}', with args: {args} and "
617
+ - f"kwargs: {kwargs} not supported for PartialTensor!")
618
+ - finally:
619
+ - del guard
620
+ + with torch._C._DisableTorchDispatch(), \
621
+ + torch._C.DisableTorchFunction():
622
+ + raise RuntimeError(
623
+ + f"torch function '{func.__name__}', with args: {args} and "
624
+ + f"kwargs: {kwargs} not supported for PartialTensor!")
625
+
626
+ @classmethod
627
+ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
628
+ diff --git a/torch/overrides.py b/torch/overrides.py
629
+ index dbee241b..dbc601e4 100644
630
+ --- a/torch/overrides.py
631
+ +++ b/torch/overrides.py
632
+ @@ -1885,9 +1885,10 @@ def _no_torch_function_mode() -> Iterator[None]:
633
+ class enable_reentrant_dispatch():
634
+ def __enter__(self):
635
+ self._raii_guard = torch._C._RestorePythonTLSSnapshot()
636
+ + self._raii_guard.__enter__()
637
+
638
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
639
+ - del self._raii_guard
640
+ + self._raii_guard.__exit__(exc_type, exc_value, traceback)
641
+
642
+ def get_buffer(tensor_subclass, data, prefix):
643
+ import ctypes
644
+ diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
645
+ index e3285090..7f9059aa 100644
646
+ --- a/torch/testing/_internal/common_utils.py
647
+ +++ b/torch/testing/_internal/common_utils.py
648
+ @@ -1347,13 +1347,8 @@ def set_rng_seed(seed):
649
+ np.random.seed(seed)
650
+
651
+
652
+ - @contextmanager
653
+ def disable_functorch():
654
+ - guard = torch._C._DisableFuncTorch() # type: ignore[attr-defined]
655
+ - try:
656
+ - yield
657
+ - finally:
658
+ - del guard
659
+ + return torch._C._DisableFuncTorch()
660
+
661
+
662
+ @contextlib.contextmanager
663
+ diff --git a/torch/utils/_mode_utils.py b/torch/utils/_mode_utils.py
664
+ index f9098c6d..adcd4920 100644
665
+ --- a/torch/utils/_mode_utils.py
666
+ +++ b/torch/utils/_mode_utils.py
667
+ @@ -8,10 +8,5 @@ T = TypeVar('T')
668
+ def all_same_mode(modes):
669
+ return all(tuple(mode == modes[0] for mode in modes))
670
+
671
+ - @contextmanager
672
+ def no_dispatch():
673
+ - guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]
674
+ - try:
675
+ - yield
676
+ - finally:
677
+ - del guard
678
+ + return torch._C._DisableTorchDispatch()
0 commit comments