Skip to content

Commit 5677ab9

Browse files
malfetpytorchmergebot
authored andcommitted
[BE] Correctly pass exceptions raised from rpc_init to CPython (pytorch#154325)
By decorating function body with `HANDLE_TH_ERRORS` Partially addresses pytorch#154300 I.e. after that change, importing torch no longer crashes but returns a readable (and actionable exception) ``` >>> import torch Traceback (most recent call last): File "<python-input-0>", line 1, in <module> import torch File "/Users/malfet/git/pytorch/pytorch/torch/__init__.py", line 2134, in <module> from torch import _VF as _VF, functional as functional # usort: skip ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/malfet/git/pytorch/pytorch/torch/functional.py", line 8, in <module> import torch.nn.functional as F File "/Users/malfet/git/pytorch/pytorch/torch/nn/__init__.py", line 8, in <module> from torch.nn.modules import * # usort: skip # noqa: F403 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/malfet/git/pytorch/pytorch/torch/nn/modules/__init__.py", line 2, in <module> from .linear import Bilinear, Identity, LazyLinear, Linear # usort: skip ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/malfet/git/pytorch/pytorch/torch/nn/modules/linear.py", line 7, in <module> from torch.nn import functional as F, init File "/Users/malfet/git/pytorch/pytorch/torch/nn/functional.py", line 11, in <module> from torch._jit_internal import ( ...<5 lines>... ) File "/Users/malfet/git/pytorch/pytorch/torch/_jit_internal.py", line 42, in <module> import torch.distributed.rpc File "/Users/malfet/git/pytorch/pytorch/torch/distributed/rpc/__init__.py", line 37, in <module> from torch._C._distributed_rpc import ( # noqa: F401 ...<33 lines>... ) ImportError: cannot import name '_DEFAULT_NUM_WORKER_THREADS' from 'torch._C._distributed_rpc' (unknown location) ``` Pull Request resolved: pytorch#154325 Approved by: https://github.com/Skylion007
1 parent 31ae07b commit 5677ab9

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

torch/csrc/distributed/rpc/init.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ template <typename T>
3030
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
3131

3232
PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
33+
HANDLE_TH_ERRORS
3334
auto rpc_module =
3435
THPObjectPtr(PyImport_ImportModule("torch.distributed.rpc"));
3536
if (!rpc_module) {
@@ -845,6 +846,7 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
845846
module.def("_disable_jit_rref_pickle", &disableJitRRefPickle);
846847

847848
Py_RETURN_TRUE;
849+
END_HANDLE_TH_ERRORS
848850
}
849851

850852
} // namespace

0 commit comments

Comments
 (0)