Skip to content

Commit d2c7c55

Browse files
committed
Make it work
1 parent 5ff0919 commit d2c7c55

File tree

6 files changed

+27
-125
lines changed

6 files changed

+27
-125
lines changed

backends/aoti/passes/TARGETS

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
oncall("executorch")
4+
5+
runtime.python_library(
6+
name = "passes",
7+
srcs = [
8+
"replace_view_copy_with_view.py",
9+
],
10+
visibility = [
11+
"//executorch/...",
12+
],
13+
deps = [
14+
"//caffe2:torch",
15+
"//executorch/exir:pass_base",
16+
],
17+
)

backends/cuda/replace_view_copy_with_view.py renamed to backends/aoti/passes/replace_view_copy_with_view.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
# pyre-strict
7+
# This pass replaces view_copy ops with view ops. This is different than
8+
# exir/passes/replace_view_copy_with_view.py and exir/passes/reinplace.py
9+
# because this should only be used in the AOTInductor backend, as it
10+
# has less restrictions on whether the tensor memory is densely packed,
811

912
from typing import Dict, Iterable, Tuple
1013

backends/apple/metal/metal_backend.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from typing import Any, Dict, final, List, Optional, Set
1313

1414
import torch
15-
from executorch.backends.apple.metal.replace_slice_copy_with_slice import (
16-
ReplaceSliceCopyWithSlicePass,
15+
from executorch.backends.aoti.passes.replace_view_copy_with_view import (
16+
ReplaceViewCopyWithViewPass,
1717
)
1818
from executorch.exir._serialize._named_data_store import NamedDataStore
1919
from executorch.exir._warnings import experimental
@@ -93,7 +93,7 @@ def preprocess(
9393
mps_edge_program = move_to_device_pass(edge_program, "mps")
9494

9595
# replace slice_copy with slice
96-
ReplaceSliceCopyWithSlicePass()(mps_edge_program.graph_module)
96+
ReplaceViewCopyWithViewPass()(mps_edge_program.graph_module)
9797

9898
edge_program_module = mps_edge_program.module()
9999

backends/apple/metal/replace_slice_copy_with_slice.py

Lines changed: 0 additions & 118 deletions
This file was deleted.

backends/cuda/TARGETS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ runtime.python_library(
66
name = "cuda_backend",
77
srcs = [
88
"cuda_backend.py",
9-
"replace_view_copy_with_view.py",
109
],
1110
visibility = [
1211
"//executorch/...",
1312
],
1413
deps = [
1514
"//caffe2:torch",
15+
"//executorch/backends/aoti/passes:passes",
1616
"//executorch/exir/_serialize:lib",
1717
"//executorch/exir/backend:backend_details",
1818
"//executorch/exir/backend:compile_spec_schema",

backends/cuda/cuda_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from typing import Any, Dict, final, List, Optional, Set
1313

1414
import torch
15-
from executorch.backends.cuda.replace_view_copy_with_view import (
15+
from executorch.backends.aoti.passes.replace_view_copy_with_view import (
1616
ReplaceViewCopyWithViewPass,
1717
)
1818
from executorch.exir._serialize._named_data_store import NamedDataStore
@@ -123,7 +123,7 @@ def preprocess(
123123
# Move the edge_program from CPU to CUDA for aoti compile
124124
cuda_edge_program = move_to_device_pass(edge_program, "cuda")
125125

126-
# replace slice_copy with slice
126+
# replace slice_copy.Tensor with slice.Tensor, select_copy.int with select.int
127127
ReplaceViewCopyWithViewPass()(cuda_edge_program.graph_module)
128128

129129
cuda_edge_program = cuda_edge_program.run_decompositions(

0 commit comments

Comments
 (0)