@@ -209,7 +209,8 @@ PyObject* ParameterClass = nullptr;
209209static PyObject* THPVariable_NewWithVar (
210210 PyTypeObject* type,
211211 const at::TensorBase& _var,
212- bool allow_preexisting_pyobj = false );
212+ bool allow_preexisting_pyobj = false ,
213+ std::optional<bool > has_torch_dispatch_if_known = std::nullopt );
213214
214215// clang-tidy gets confused by static const
215216static const char * VOLATILE_WARNING =
@@ -777,7 +778,13 @@ static PyObject* THPVariable_make_wrapper_subclass(
777778 tensor.unsafeGetTensorImpl ()->set_python_custom_layout (true );
778779 }
779780
780- return THPVariable_NewWithVar ((PyTypeObject*)cls, tensor);
781+ return THPVariable_NewWithVar (
782+ (PyTypeObject*)cls,
783+ tensor,
784+ // false is the default
785+ /* allow_preexisting_pyobj=*/ false ,
786+ // we checked __torch_dispatch__ above; avoid checking again.
787+ /* has_torch_dispatch_if_known=*/ true );
781788 END_HANDLE_TH_ERRORS
782789}
783790
@@ -833,7 +840,14 @@ static PyObject* THPVariable_make_dtensor(
833840 /* storage_size=*/ std::nullopt ,
834841 extra_dispatch_keys);
835842 tensor.set_requires_grad (r.toBool (4 ));
836- return THPVariable_NewWithVar ((PyTypeObject*)cls, tensor);
843+ return THPVariable_NewWithVar (
844+ (PyTypeObject*)cls,
845+ tensor,
846+ // false is the default
847+ /* allow_preexisting_pyobj=*/ false ,
848+ // we know DTensor has __torch_dispatch__ and we double-checked
849+ // above; avoid checking again.
850+ /* has_torch_dispatch_if_known=*/ true );
837851 END_HANDLE_TH_ERRORS
838852}
839853
@@ -2093,7 +2107,8 @@ static void THPVariable_subclass_dealloc(PyObject* self) {
20932107static PyObject* THPVariable_NewWithVar (
20942108 PyTypeObject* type,
20952109 const at::TensorBase& _var,
2096- bool allow_preexisting_pyobj) {
2110+ bool allow_preexisting_pyobj,
2111+ std::optional<bool > has_torch_dispatch_if_known) {
20972112 // Make sure that the reinterpret into a THPVariable* will be valid
20982113 TORCH_CHECK (
20992114 PyType_IsSubtype (type, &THPVariableType),
@@ -2186,7 +2201,9 @@ static PyObject* THPVariable_NewWithVar(
21862201 v->cdata = MaybeOwned<Variable>::owned (Variable (_var));
21872202 const auto & var = THPVariable_Unpack (v);
21882203 var.unsafeGetTensorImpl ()->pyobj_slot ()->init_pyobj (obj);
2189- if (check_has_torch_dispatch (obj)) {
2204+ if (has_torch_dispatch_if_known.has_value ()
2205+ ? *has_torch_dispatch_if_known
2206+ : check_has_torch_dispatch (obj)) {
21902207 var.unsafeGetTensorImpl ()->set_python_dispatch (true );
21912208 }
21922209 }
0 commit comments