Skip to content

Commit 49c446c

Browse files
swolchokpytorchmergebot
authored andcommitted
Add C++ function for torch.distributed.tensor._op_schema.is_view_op (pytorch#161595)
This seems to have been an especially slow one because of the repeated pybind access (schema is a pybind, as is arguments, and then we hit each argument). It's still ~~1% of total benchmark runtime because of the repeated single pybind function call, but that's a lot better. Differential Revision: [D81530095](https://our.internmc.facebook.com/intern/diff/D81530095) Pull Request resolved: pytorch#161595 Approved by: https://github.com/ezyang, https://github.com/bdhirsh ghstack dependencies: pytorch#161466, pytorch#161586, pytorch#161590, pytorch#161591
1 parent 8e076d8 commit 49c446c

File tree

3 files changed

+12
-4
lines changed

3 files changed

+12
-4
lines changed

torch/_C/__init__.pyi.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,7 @@ class FunctionSchema:
951951
is_vararg: _bool,
952952
is_varret: _bool,
953953
) -> None: ...
954+
def _is_view_op(self) -> _bool: ...
954955

955956
class _UpgraderEntry:
956957
bumped_at_version: _int

torch/csrc/jit/python/init.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1962,6 +1962,16 @@ void initJITBindings(PyObject* module) {
19621962
.def_property_readonly("overload_name", &FunctionSchema::overload_name)
19631963
.def_property_readonly("arguments", &FunctionSchema::arguments)
19641964
.def_property_readonly("returns", &FunctionSchema::returns)
1965+
.def(
1966+
"_is_view_op",
1967+
[](const FunctionSchema& self) -> bool {
1968+
for (const auto& arg : self.arguments()) {
1969+
if (arg.alias_info() && !arg.alias_info()->isWrite()) {
1970+
return true;
1971+
}
1972+
}
1973+
return false;
1974+
})
19651975
.def(
19661976
"is_backward_compatible_with",
19671977
// FunctionSchema::isBackwardCompatibleWith has an extra

torch/distributed/tensor/_op_schema.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -477,10 +477,7 @@ def is_out_variant_op(self) -> bool:
477477
return "out" in self.op._schema.overload_name
478478

479479
def is_view_op(self) -> bool:
480-
return any(
481-
a.alias_info is not None and not a.alias_info.is_write
482-
for a in self.op._schema.arguments
483-
)
480+
return self.op._schema._is_view_op()
484481

485482
def _recompute_comparison_key(self):
486483
if not self.schema_info:

0 commit comments

Comments
 (0)