@@ -955,6 +955,133 @@ static PyObject* is_fwd_grad_enabled(PyObject* _unused, PyObject* arg) {
955955 END_HANDLE_TH_ERRORS
956956}
957957
958+ template <bool skip_tensors_in_non_tensorlist>
959+ static bool visit (
960+ PyObject* o,
961+ const std::function<bool (at::Tensor&)>& visit_tensor) {
962+ if (THPVariable_Check (o)) {
963+ auto t = THPVariable_Unpack (o);
964+ if (visit_tensor (t)) {
965+ return true ;
966+ }
967+ } else if (PyList_Check (o)) {
968+ // Check that this List is TensorList
969+ if constexpr (skip_tensors_in_non_tensorlist) {
970+ for (const auto i : c10::irange (PyList_GET_SIZE (o))) {
971+ if (!THPVariable_Check (PyList_GET_ITEM (o, i))) {
972+ return false ;
973+ }
974+ }
975+ }
976+ for (const auto i : c10::irange (PyList_GET_SIZE (o))) {
977+ if (visit<skip_tensors_in_non_tensorlist>(
978+ PyList_GET_ITEM (o, i), visit_tensor)) {
979+ return true ;
980+ };
981+ }
982+ }
983+ return false ;
984+ }
985+
986+ // Visiting of tensors in args and kwargs,
987+ // only List container is visited.
988+ // skip_tensors_in_non_tensorlist will skip any List with non-Tensor.
989+ // Lambda returning true means short circuit, traversal stops after that.
990+ template <bool skip_tensors_in_non_tensorlist>
991+ static void visit_tensors (
992+ PyObject* args,
993+ PyObject* kwargs,
994+ const std::function<bool (at::Tensor&)>& visit_tensor) {
995+ if (args && PyTuple_Check (args)) {
996+ for (const auto i : c10::irange (PyTuple_GET_SIZE (args))) {
997+ if (visit<skip_tensors_in_non_tensorlist>(
998+ PyTuple_GET_ITEM (args, i), visit_tensor)) {
999+ return ;
1000+ }
1001+ }
1002+ }
1003+ if (kwargs && PyDict_Check (kwargs)) {
1004+ auto vals = PyDict_Values (kwargs);
1005+ for (const auto i : c10::irange (PyList_GET_SIZE (vals))) {
1006+ if (visit<skip_tensors_in_non_tensorlist>(
1007+ PyList_GET_ITEM (vals, i), visit_tensor)) {
1008+ return ;
1009+ }
1010+ }
1011+ }
1012+ }
1013+
1014+ // Returns true if any of the args, kwargs tensor leaves have requires_grad.
1015+ // Only List[Tensor] container in args is supported.
1016+ static PyObject* any_requires_grad (
1017+ PyObject* _unused,
1018+ PyObject* args,
1019+ PyObject* kwargs) {
1020+ HANDLE_TH_ERRORS
1021+ bool has_requires_grad = false ;
1022+ visit_tensors<true >(args, kwargs, [&has_requires_grad](at::Tensor& t) {
1023+ if (t.requires_grad ()) {
1024+ has_requires_grad = true ;
1025+ return true ;
1026+ }
1027+ return false ;
1028+ });
1029+ if (has_requires_grad) {
1030+ Py_RETURN_TRUE;
1031+ }
1032+ Py_RETURN_FALSE;
1033+ END_HANDLE_TH_ERRORS
1034+ }
1035+
1036+ // Checks aliasing constraint for custom ops:
1037+ // Returns true if any of outputs is alias to any of inputs or another output
1038+ // Args:
1039+ // args[0] - inputs args
1040+ // args[1] - inputs kwargs
1041+ // args[2] - outputs
1042+ // Only List container is supported.
1043+ // Tensors in Lists that has not only Tensor are checked.
1044+ static PyObject* any_output_is_alias_to_input_or_output (
1045+ PyObject* _unused,
1046+ PyObject* args) {
1047+ HANDLE_TH_ERRORS
1048+ PyObject* inps = PyTuple_GET_ITEM (args, 0 );
1049+ PyObject* inps_kwargs = PyTuple_GET_ITEM (args, 1 );
1050+ PyObject* outs = PyTuple_GET_ITEM (args, 2 );
1051+ std::unordered_set<void *> s;
1052+ visit_tensors<false >(inps, inps_kwargs, [&s](at::Tensor& t) {
1053+ if (!t.storage ()) {
1054+ return false ;
1055+ }
1056+ auto * cp = t.storage ().data_ptr ().get_context ();
1057+ if (cp) {
1058+ s.insert (cp);
1059+ }
1060+ return false ;
1061+ });
1062+ bool ret = false ;
1063+ visit_tensors<false >(outs, nullptr , [&s, &ret](at::Tensor& t) {
1064+ if (!t.storage ()) {
1065+ return false ;
1066+ }
1067+ auto * cp = t.storage ().data_ptr ().get_context ();
1068+ if (!cp) {
1069+ return false ;
1070+ }
1071+ if (s.find (cp) != s.end ()) {
1072+ ret = true ;
1073+ return true ;
1074+ }
1075+ s.insert (cp);
1076+ return false ;
1077+ });
1078+ if (ret) {
1079+ Py_RETURN_TRUE;
1080+ }
1081+ Py_RETURN_FALSE;
1082+ END_HANDLE_TH_ERRORS
1083+ }
1084+
9581085static PyObject* set_multithreading_enabled (
9591086 PyObject* self,
9601087 PyObject* args,
@@ -1326,6 +1453,14 @@ static PyMethodDef methods[] = {
13261453 nullptr },
13271454 {" is_grad_enabled" , is_grad_enabled, METH_NOARGS, nullptr },
13281455 {" _set_fwd_grad_enabled" , set_fwd_grad_enabled, METH_O, nullptr },
1456+ {" _any_requires_grad" ,
1457+ castPyCFunctionWithKeywords (any_requires_grad),
1458+ METH_VARARGS | METH_KEYWORDS,
1459+ nullptr },
1460+ {" _any_output_is_alias_to_input_or_output" ,
1461+ any_output_is_alias_to_input_or_output,
1462+ METH_VARARGS,
1463+ nullptr },
13291464 {" _is_fwd_grad_enabled" , is_fwd_grad_enabled, METH_NOARGS, nullptr },
13301465 {" is_inference_mode_enabled" ,
13311466 is_inference_mode_enabled,
0 commit comments