Skip to content

Commit 53ccfd0

Browse files
authored
Fix cuda export test failures from #14715 (#14753)
1 parent a1652f9 commit 53ccfd0

File tree

4 files changed

+16
-7
lines changed

4 files changed

+16
-7
lines changed

backends/cuda/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ runtime.python_library(
66
name = "cuda_backend",
77
srcs = [
88
"cuda_backend.py",
9+
"replace_slice_copy_with_slice.py",
910
],
1011
visibility = [
1112
"//executorch/...",

backends/cuda/cuda_backend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,9 @@ def preprocess(
144144
}
145145

146146
with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel(
147-
[SDPBackend.MATH]
147+
[
148+
SDPBackend.MATH # pyre-ignore[16]: Module `torch.nn.attention` has no attribute `SDPBackend`.
149+
]
148150
), torch.no_grad():
149151
# torch._logging.set_logs(post_grad_graphs=True)
150152
so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type]

backends/cuda/replace_slice_copy_with_slice.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,23 @@
66

77
# pyre-strict
88

9-
from typing import Iterable
9+
from typing import Dict, Iterable, Tuple
1010

1111
import torch
1212
from executorch.exir.dialects._ops import ops
13+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
1314
from executorch.exir.pass_base import ExportPass, PassResult
1415
from torch import fx
1516

1617

17-
_SLICE_COPY_TARGETS = (
18+
_SLICE_COPY_TARGETS: Tuple[torch._ops.OpOverload | EdgeOpOverload] = (
1819
torch.ops.aten.slice_copy.Tensor,
1920
ops.edge.aten.slice_copy.Tensor,
2021
)
2122

22-
_SLICE_TARGETS = {
23+
_SLICE_TARGETS: Dict[
24+
torch._ops.OpOverload | EdgeOpOverload, torch._ops.OpOverload | EdgeOpOverload
25+
] = {
2326
torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor,
2427
ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor,
2528
}
@@ -99,8 +102,8 @@ def _is_view_user(self, node: fx.Node, user: fx.Node) -> bool:
99102
return False
100103

101104
def _argument_mutates(
102-
self, schema: torch._C.FunctionSchema, key
103-
) -> bool: # pyre-ignore[11]
105+
self, schema: torch._C.FunctionSchema, key: int | str
106+
) -> bool:
104107
arguments = schema.arguments
105108
if isinstance(key, int):
106109
if key >= len(arguments):

backends/cuda/tests/test_cuda_export.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Tuple
99

1010
import torch
11+
from executorch.backends.cuda.cuda_backend import CudaBackend
1112
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
1213
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
1314
from torch.export import export
@@ -30,7 +31,9 @@ def _export_to_cuda_with_lower(
3031
exported_program = export(module, inputs, strict=True)
3132

3233
# Create partitioner and compile specs
33-
partitioner = CudaPartitioner([])
34+
partitioner = CudaPartitioner(
35+
[CudaBackend.generate_method_name_compile_spec("forward")]
36+
)
3437

3538
# Use to_edge_transform_and_lower for complete pipeline
3639
edge_program_manager = to_edge_transform_and_lower(

0 commit comments

Comments
 (0)