Skip to content

Commit 8e076d8

Browse files
swolchokpytorchmergebot
authored andcommitted
Don't call check_has_torch_dispatch in THPVariable_NewWithVar if we already know (pytorch#161591)
We already know when we're called from make_wrapper_subclass or make_dtensor. The check isn't particularly cheap. Differential Revision: [D81530099](https://our.internmc.facebook.com/intern/diff/D81530099) Pull Request resolved: pytorch#161591 Approved by: https://github.com/ezyang ghstack dependencies: pytorch#161466, pytorch#161586, pytorch#161590
1 parent f044fa2 commit 8e076d8

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

torch/csrc/autograd/python_variable.cpp

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,8 @@ PyObject* ParameterClass = nullptr;
209209
static PyObject* THPVariable_NewWithVar(
210210
PyTypeObject* type,
211211
const at::TensorBase& _var,
212-
bool allow_preexisting_pyobj = false);
212+
bool allow_preexisting_pyobj = false,
213+
std::optional<bool> has_torch_dispatch_if_known = std::nullopt);
213214

214215
// clang-tidy gets confused by static const
215216
static const char* VOLATILE_WARNING =
@@ -777,7 +778,13 @@ static PyObject* THPVariable_make_wrapper_subclass(
777778
tensor.unsafeGetTensorImpl()->set_python_custom_layout(true);
778779
}
779780

780-
return THPVariable_NewWithVar((PyTypeObject*)cls, tensor);
781+
return THPVariable_NewWithVar(
782+
(PyTypeObject*)cls,
783+
tensor,
784+
// false is the default
785+
/*allow_preexisting_pyobj=*/false,
786+
// we checked __torch_dispatch__ above; avoid checking again.
787+
/*has_torch_dispatch_if_known=*/true);
781788
END_HANDLE_TH_ERRORS
782789
}
783790

@@ -833,7 +840,14 @@ static PyObject* THPVariable_make_dtensor(
833840
/*storage_size=*/std::nullopt,
834841
extra_dispatch_keys);
835842
tensor.set_requires_grad(r.toBool(4));
836-
return THPVariable_NewWithVar((PyTypeObject*)cls, tensor);
843+
return THPVariable_NewWithVar(
844+
(PyTypeObject*)cls,
845+
tensor,
846+
// false is the default
847+
/*allow_preexisting_pyobj=*/false,
848+
// we know DTensor has __torch_dispatch__ and we double-checked
849+
// above; avoid checking again.
850+
/*has_torch_dispatch_if_known=*/true);
837851
END_HANDLE_TH_ERRORS
838852
}
839853

@@ -2093,7 +2107,8 @@ static void THPVariable_subclass_dealloc(PyObject* self) {
20932107
static PyObject* THPVariable_NewWithVar(
20942108
PyTypeObject* type,
20952109
const at::TensorBase& _var,
2096-
bool allow_preexisting_pyobj) {
2110+
bool allow_preexisting_pyobj,
2111+
std::optional<bool> has_torch_dispatch_if_known) {
20972112
// Make sure that the reinterpret into a THPVariable* will be valid
20982113
TORCH_CHECK(
20992114
PyType_IsSubtype(type, &THPVariableType),
@@ -2186,7 +2201,9 @@ static PyObject* THPVariable_NewWithVar(
21862201
v->cdata = MaybeOwned<Variable>::owned(Variable(_var));
21872202
const auto& var = THPVariable_Unpack(v);
21882203
var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(obj);
2189-
if (check_has_torch_dispatch(obj)) {
2204+
if (has_torch_dispatch_if_known.has_value()
2205+
? *has_torch_dispatch_if_known
2206+
: check_has_torch_dispatch(obj)) {
21902207
var.unsafeGetTensorImpl()->set_python_dispatch(true);
21912208
}
21922209
}

0 commit comments

Comments
 (0)