From df40d627364b815e1e9cfbfd398cf331e9b2740f Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Fri, 10 Oct 2025 13:15:15 -0700 Subject: [PATCH] decompose conv1d to conv2d for cuda backend this diff decomposes conv1d into conv2d for aoti-cuda backend support Differential Revision: [D84296877](https://our.internmc.facebook.com/intern/diff/D84296877/) [ghstack-poisoned] --- .github/workflows/cuda.yml | 2 +- backends/cuda/cuda_backend.py | 9 +++++++++ backends/cuda/tests/test_cuda_export.py | 19 +++++++++++++++++++ examples/models/__init__.py | 2 ++ examples/models/toy_model/model.py | 17 +++++++++++++++++ 5 files changed, 48 insertions(+), 1 deletion(-) diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml index 8dbbb254ac3..cc5222436d6 100644 --- a/.github/workflows/cuda.yml +++ b/.github/workflows/cuda.yml @@ -71,7 +71,7 @@ jobs: strategy: fail-fast: false matrix: - model: [linear, add, add_mul, resnet18] + model: [linear, add, add_mul, resnet18, conv1d] with: timeout: 90 runner: linux.g5.4xlarge.nvidia.gpu diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index ef98de29f23..0f68e084bda 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -24,9 +24,14 @@ ) from executorch.exir.backend.compile_spec_schema import CompileSpec from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu +from torch._inductor.decomposition import conv1d_to_conv2d from torch.export.passes import move_to_device_pass from torch.nn.attention import SDPBackend +cuda_decomposition_table = { + torch.ops.aten.conv1d.default: conv1d_to_conv2d, +} + # exist fallback operators in et namespace; supported_fallback_kernels: Dict[str, Any] = {} @@ -119,6 +124,10 @@ def preprocess( # replace slice_copy with slice ReplaceSliceCopyWithSlicePass()(cuda_edge_program.graph_module) + cuda_edge_program = cuda_edge_program.run_decompositions( + cuda_decomposition_table + ) + edge_program_module = cuda_edge_program.module() # Grab all input 
placeholders from the graph diff --git a/backends/cuda/tests/test_cuda_export.py b/backends/cuda/tests/test_cuda_export.py index d794a4f042c..ef43a3ab3cb 100644 --- a/backends/cuda/tests/test_cuda_export.py +++ b/backends/cuda/tests/test_cuda_export.py @@ -251,3 +251,22 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: self.assertIsNotNone( edge_program_manager, "Mathematical operations export failed" ) + + def test_conv1d(self): + """Test CUDA export for 1D convolution.""" + + class Conv1dModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d(3, 16, kernel_size=3, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + module = Conv1dModule() + module.eval() + inputs = (torch.randn(1, 3, 10),) + + # Test export + edge_program_manager = self._export_to_cuda_with_lower(module, inputs) + self.assertIsNotNone(edge_program_manager, "Conv1d operation export failed") diff --git a/examples/models/__init__.py b/examples/models/__init__.py index b997ba1bf6e..5f9791843aa 100644 --- a/examples/models/__init__.py +++ b/examples/models/__init__.py @@ -14,6 +14,7 @@ class Model(str, Enum): Add = "add" AddMul = "add_mul" Softmax = "softmax" + Conv1d = "conv1d" Dl3 = "dl3" Edsr = "edsr" EmformerTranscribe = "emformer_transcribe" @@ -59,6 +60,7 @@ def __str__(self) -> str: str(Model.Add): ("toy_model", "AddModule"), str(Model.AddMul): ("toy_model", "AddMulModule"), str(Model.Softmax): ("toy_model", "SoftmaxModule"), + str(Model.Conv1d): ("toy_model", "Conv1dModule"), str(Model.Dl3): ("deeplab_v3", "DeepLabV3ResNet50Model"), str(Model.Edsr): ("edsr", "EdsrModel"), str(Model.EmformerTranscribe): ("emformer_rnnt", "EmformerRnntTranscriberModel"), diff --git a/examples/models/toy_model/model.py b/examples/models/toy_model/model.py index 9ebe42e6621..e1dd290b829 100644 --- a/examples/models/toy_model/model.py +++ b/examples/models/toy_model/model.py @@ -88,3 +88,20 @@ def 
get_eager_model(self) -> torch.nn.Module: def get_example_inputs(self): return (torch.ones(2, 2),) + + +class Conv1dModule(torch.nn.Module, EagerModelBase): + def __init__(self): + super().__init__() + self.conv1d = torch.nn.Conv1d( + in_channels=3, out_channels=16, kernel_size=3, padding=1 + ) + + def forward(self, x): + return self.conv1d(x) + + def get_eager_model(self) -> torch.nn.Module: + return self + + def get_example_inputs(self): + return (torch.randn(1, 3, 10),)