Skip to content

Commit e159e65

Browse files
authored
Extend reinplace pass to select_copy.int (pytorch#15136)
This pull request refactors and centralizes the logic for replacing "view_copy" operations with "view" operations in the graph transformation passes for both the AOTInductor and CUDA backends. The main change is the creation of a unified pass in `backends/aoti/passes/replace_view_copy_with_view.py`, which replaces the previous backend-specific implementations and expands support to additional ops. The backend code is updated to use this new shared pass, and redundant files are removed. **Pass refactoring and centralization:** * Created a new unified pass `replace_view_copy_with_view.py` in `backends/aoti/passes` that replaces "view_copy" type ops (including `slice_copy` and `select_copy`) with their corresponding "view" ops for use in AOTInductor and CUDA backends. [[1]](diffhunk://#diff-725a4a1f4634a11f716ae6f649894f6eea64edb21f56ad56cde92f18fdd2f713L7-R12) [[2]](diffhunk://#diff-725a4a1f4634a11f716ae6f649894f6eea64edb21f56ad56cde92f18fdd2f713L18-R44) [[3]](diffhunk://#diff-374a8b362bdad92dce92e7c3bb474dd6106fc80d7253e6b5d5a1c9fb971dc76eR1-R17) * Removed the old backend-specific pass files (`replace_slice_copy_with_slice.py`) from both `backends/apple/metal` and `backends/cuda`. [[1]](diffhunk://#diff-c4a228b182f50f778545991d472609ad705d2325994342174093ff374738851dL1-L118) [[2]](diffhunk://#diff-f0e6cbb7940752204a85a43708b5424de89eb4556698043d6cc652c07eabd624L9-R15) **Backend integration and API updates:** * Updated both `metal_backend.py` and `cuda_backend.py` to import and use the new `ReplaceViewCopyWithViewPass` instead of the previous backend-specific implementations. [[1]](diffhunk://#diff-20452c18c868bce8db75555905fdbc3a6347536697bdfea9b7187bd6c765a24eL15-R16) [[2]](diffhunk://#diff-5b5ea2257772b3aba04b2534f5ea1429a0c631bfd25a7ef531f526e76c471d7aL15-R16) * Modified the preprocessing step in both backends to apply the new pass, which now handles both `slice_copy` and `select_copy` ops. 
[[1]](diffhunk://#diff-20452c18c868bce8db75555905fdbc3a6347536697bdfea9b7187bd6c765a24eL96-R96) [[2]](diffhunk://#diff-5b5ea2257772b3aba04b2534f5ea1429a0c631bfd25a7ef531f526e76c471d7aL126-R127)
1 parent 312267e commit e159e65

File tree

6 files changed

+37
-138
lines changed

6 files changed

+37
-138
lines changed

backends/aoti/passes/TARGETS

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
oncall("executorch")
4+
5+
runtime.python_library(
6+
name = "passes",
7+
srcs = [
8+
"replace_view_copy_with_view.py",
9+
],
10+
visibility = [
11+
"//executorch/...",
12+
],
13+
deps = [
14+
"//caffe2:torch",
15+
"//executorch/exir:pass_base",
16+
],
17+
)

backends/apple/metal/replace_slice_copy_with_slice.py renamed to backends/aoti/passes/replace_view_copy_with_view.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
# pyre-strict
7+
# This pass replaces view_copy ops with view ops. This is different from
8+
# exir/passes/replace_view_copy_with_view.py and exir/passes/reinplace.py
9+
# because this should only be used in the AOTInductor backend, as it
10+
# has fewer restrictions on whether the tensor memory is densely packed.
811

9-
from typing import Dict, Iterable, Tuple
12+
from typing import Dict, Iterable
1013

1114
import torch
1215
from executorch.exir.dialects._ops import ops
@@ -15,33 +18,30 @@
1518
from torch import fx
1619

1720

18-
_SLICE_COPY_TARGETS: Tuple[torch._ops.OpOverload | EdgeOpOverload] = (
19-
torch.ops.aten.slice_copy.Tensor,
20-
ops.edge.aten.slice_copy.Tensor,
21-
)
22-
23-
_SLICE_TARGETS: Dict[
21+
_VIEW_TARGETS: Dict[
2422
torch._ops.OpOverload | EdgeOpOverload, torch._ops.OpOverload | EdgeOpOverload
2523
] = {
2624
torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor,
2725
ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor,
26+
torch.ops.aten.select_copy.int: torch.ops.aten.select.int,
27+
ops.edge.aten.select_copy.int: ops.edge.aten.select.int,
2828
}
2929

3030

31-
class ReplaceSliceCopyWithSlicePass(ExportPass):
32-
"""Replace non-mutated ``slice_copy`` results with ``slice`` views."""
31+
class ReplaceViewCopyWithViewPass(ExportPass):
32+
"""Replace non-mutated ``view_copy`` type of ops with ``view`` ops."""
3333

3434
def call(self, graph_module: fx.GraphModule) -> PassResult:
3535
graph_changed = False
3636

3737
for node in graph_module.graph.nodes:
38-
if node.op != "call_function" or node.target not in _SLICE_COPY_TARGETS:
38+
if node.op != "call_function" or node.target not in _VIEW_TARGETS:
3939
continue
4040

4141
if self._has_blocking_user(node, node.users.keys()):
4242
continue
4343

44-
node.target = _SLICE_TARGETS[node.target]
44+
node.target = _VIEW_TARGETS[node.target]
4545
graph_changed = True
4646

4747
if graph_changed:

backends/apple/metal/metal_backend.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from typing import Any, Dict, final, List, Optional, Set
1313

1414
import torch
15-
from executorch.backends.apple.metal.replace_slice_copy_with_slice import (
16-
ReplaceSliceCopyWithSlicePass,
15+
from executorch.backends.aoti.passes.replace_view_copy_with_view import (
16+
ReplaceViewCopyWithViewPass,
1717
)
1818
from executorch.exir._serialize._named_data_store import NamedDataStore
1919
from executorch.exir._warnings import experimental
@@ -93,7 +93,7 @@ def preprocess(
9393
mps_edge_program = move_to_device_pass(edge_program, "mps")
9494

9595
# replace slice_copy.Tensor with slice.Tensor, select_copy.int with select.int
96-
ReplaceSliceCopyWithSlicePass()(mps_edge_program.graph_module)
96+
ReplaceViewCopyWithViewPass()(mps_edge_program.graph_module)
9797

9898
edge_program_module = mps_edge_program.module()
9999

backends/cuda/TARGETS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ runtime.python_library(
66
name = "cuda_backend",
77
srcs = [
88
"cuda_backend.py",
9-
"replace_slice_copy_with_slice.py",
109
],
1110
visibility = [
1211
"//executorch/...",
1312
],
1413
deps = [
1514
"//caffe2:torch",
15+
"//executorch/backends/aoti/passes:passes",
1616
"//executorch/exir/_serialize:lib",
1717
"//executorch/exir/backend:backend_details",
1818
"//executorch/exir/backend:compile_spec_schema",

backends/cuda/cuda_backend.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from typing import Any, Dict, final, List, Optional, Set
1313

1414
import torch
15-
from executorch.backends.cuda.replace_slice_copy_with_slice import (
16-
ReplaceSliceCopyWithSlicePass,
15+
from executorch.backends.aoti.passes.replace_view_copy_with_view import (
16+
ReplaceViewCopyWithViewPass,
1717
)
1818
from executorch.exir._serialize._named_data_store import NamedDataStore
1919
from executorch.exir._warnings import experimental
@@ -123,8 +123,8 @@ def preprocess(
123123
# Move the edge_program from CPU to CUDA for aoti compile
124124
cuda_edge_program = move_to_device_pass(edge_program, "cuda")
125125

126-
# replace slice_copy with slice
127-
ReplaceSliceCopyWithSlicePass()(cuda_edge_program.graph_module)
126+
# replace slice_copy.Tensor with slice.Tensor, select_copy.int with select.int
127+
ReplaceViewCopyWithViewPass()(cuda_edge_program.graph_module)
128128

129129
cuda_edge_program = cuda_edge_program.run_decompositions(
130130
cuda_decomposition_table

backends/cuda/replace_slice_copy_with_slice.py

Lines changed: 0 additions & 118 deletions
This file was deleted.

0 commit comments

Comments
 (0)