Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ jobs:
strategy:
fail-fast: false
matrix:
model: [linear, add, add_mul, resnet18]
model: [linear, add, add_mul, resnet18, conv1d]
with:
timeout: 90
runner: linux.g5.4xlarge.nvidia.gpu
Expand Down
9 changes: 9 additions & 0 deletions backends/cuda/cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,14 @@
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
from torch._inductor.decomposition import conv1d_to_conv2d
from torch.export.passes import move_to_device_pass
from torch.nn.attention import SDPBackend

cuda_decomposition_table = {
torch.ops.aten.conv1d.default: conv1d_to_conv2d,
}

# exist fallback operators in et namespace;
supported_fallback_kernels: Dict[str, Any] = {}

Expand Down Expand Up @@ -119,6 +124,10 @@ def preprocess(
# replace slice_copy with slice
ReplaceSliceCopyWithSlicePass()(cuda_edge_program.graph_module)

cuda_edge_program = cuda_edge_program.run_decompositions(
cuda_decomposition_table
)

edge_program_module = cuda_edge_program.module()

# Grab all input placeholders from the graph
Expand Down
19 changes: 19 additions & 0 deletions backends/cuda/tests/test_cuda_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,22 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
self.assertIsNotNone(
edge_program_manager, "Mathematical operations export failed"
)

def test_conv1d(self):
"""Test CUDA export for 1D convolution."""

class Conv1dModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv1d(3, 16, kernel_size=3, padding=1)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)

module = Conv1dModule()
module.eval()
inputs = (torch.randn(1, 3, 10),)

# Test export
edge_program_manager = self._export_to_cuda_with_lower(module, inputs)
self.assertIsNotNone(edge_program_manager, "Conv1d operation export failed")
2 changes: 2 additions & 0 deletions examples/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class Model(str, Enum):
Add = "add"
AddMul = "add_mul"
Softmax = "softmax"
Conv1d = "conv1d"
Dl3 = "dl3"
Edsr = "edsr"
EmformerTranscribe = "emformer_transcribe"
Expand Down Expand Up @@ -59,6 +60,7 @@ def __str__(self) -> str:
str(Model.Add): ("toy_model", "AddModule"),
str(Model.AddMul): ("toy_model", "AddMulModule"),
str(Model.Softmax): ("toy_model", "SoftmaxModule"),
str(Model.Conv1d): ("toy_model", "Conv1dModule"),
str(Model.Dl3): ("deeplab_v3", "DeepLabV3ResNet50Model"),
str(Model.Edsr): ("edsr", "EdsrModel"),
str(Model.EmformerTranscribe): ("emformer_rnnt", "EmformerRnntTranscriberModel"),
Expand Down
17 changes: 17 additions & 0 deletions examples/models/toy_model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,20 @@ def get_eager_model(self) -> torch.nn.Module:

def get_example_inputs(self):
return (torch.ones(2, 2),)


class Conv1dModule(torch.nn.Module, EagerModelBase):
def __init__(self):
super().__init__()
self.conv1d = torch.nn.Conv1d(
in_channels=3, out_channels=16, kernel_size=3, padding=1
)

def forward(self, x):
return self.conv1d(x)

def get_eager_model(self) -> torch.nn.Module:
return self

def get_example_inputs(self):
return (torch.randn(1, 3, 10),)
Loading