- 
                Notifications
    You must be signed in to change notification settings 
- Fork 706
cuda partioner supported #14477
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
          
     Merged
      
      
    
  
     Merged
                    cuda partioner supported #14477
Changes from all commits
      Commits
    
    
  File filter
Filter by extension
Conversations
          Failed to load comments.   
        
        
          
      Loading
        
  Jump to
        
          Jump to file
        
      
      
          Failed to load files.   
        
        
          
      Loading
        
  Diff view
Diff view
There are no files selected for viewing
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,18 @@ | ||
| load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") | ||
|  | ||
| oncall("executorch") | ||
|  | ||
| runtime.python_library( | ||
| name = "cuda_partitioner", | ||
| srcs = [ | ||
| "cuda_partitioner.py", | ||
| ], | ||
| visibility = [ | ||
| "//executorch/...", | ||
| ], | ||
| deps = [ | ||
| "//caffe2:torch", | ||
| "//executorch/exir/backend:partitioner", | ||
| "//executorch/exir/backend:utils", | ||
| ], | ||
| ) | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,69 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|  | ||
| from typing import Callable, Dict, final, List, Optional, Tuple | ||
|  | ||
| import torch | ||
| from executorch.exir.backend.compile_spec_schema import CompileSpec | ||
| from executorch.exir.backend.partitioner import ( | ||
| DelegationSpec, | ||
| Partitioner, | ||
| PartitionResult, | ||
| ) | ||
| from executorch.exir.backend.utils import tag_constant_data | ||
| from torch.export.exported_program import ExportedProgram | ||
|  | ||
|  | ||
| @final | ||
| class CudaPartitioner(Partitioner): | ||
| """ | ||
| CUDA partitioner for AOTInductor backend integration. | ||
|  | ||
| This partitioner creates a single partition containing all operators from the input graph. | ||
| It skips core ATen decomposition, allowing the CUDA backend to handle decomposition using | ||
| AOTInductor's CUDA-specific decomposition table. | ||
|  | ||
| Only operators that cannot be handled by the aoti-cuda library will be excluded from | ||
| the partition and fall back to ExecuTorch's default or custom handling. | ||
| """ | ||
|  | ||
| def __init__(self, compile_spec: List[CompileSpec]) -> None: | ||
| self.delegation_spec = DelegationSpec("CudaBackend", compile_spec) | ||
|  | ||
| def partition(self, exported_program: ExportedProgram) -> PartitionResult: | ||
| """ | ||
| Fully delegate the graph to AOTInductor by tagging all nodes as a single partition. | ||
| """ | ||
|  | ||
| partition_tags: Dict[str, DelegationSpec] = {} | ||
| for node in exported_program.graph.nodes: | ||
| if node.op != "call_function": | ||
| continue | ||
| tag = "tag0" | ||
| node.meta["delegation_tag"] = tag | ||
| partition_tags[tag] = self.delegation_spec | ||
|  | ||
| tag_constant_data(exported_program) | ||
|  | ||
| return PartitionResult( | ||
| tagged_exported_program=exported_program, partition_tags=partition_tags | ||
| ) | ||
|  | ||
| def ops_to_not_decompose( | ||
| self, ep: ExportedProgram | ||
| ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: | ||
| """ | ||
| Return a list of operations that should not be decomposed and let the AOT compiler handle them. | ||
| Currently we skip ATen decompositon for all ops, and let the cuda backend handle them. | ||
| """ | ||
| do_not_decompose = set() | ||
|  | ||
| for node in ep.graph.nodes: | ||
| if node.op == "call_function" and isinstance( | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can aoti eat control flow hops? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. didn't test that. Let me add a test later. | ||
| node.target, torch._ops.OpOverload | ||
| ): | ||
| do_not_decompose.add(node.target) | ||
| return list(do_not_decompose), None | ||
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") | ||
| load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") | ||
|  | ||
| oncall("executorch") | ||
|  | ||
| python_unittest( | ||
| name = "test_cuda_partitioner", | ||
| srcs = [ | ||
| "test_cuda_partitioner.py", | ||
| ], | ||
| visibility = [ | ||
| "//executorch/...", | ||
| ], | ||
| deps = [ | ||
| "//caffe2:torch", | ||
| "//executorch/backends/cuda:cuda_partitioner", | ||
| "//executorch/exir:lib", | ||
| "//executorch/exir/backend:compile_spec_schema", | ||
| ], | ||
| ) | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,143 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|  | ||
| import unittest | ||
| from typing import Tuple | ||
|  | ||
| import torch | ||
| from executorch.backends.cuda.cuda_partitioner import CudaPartitioner | ||
| from executorch.exir.backend.compile_spec_schema import CompileSpec | ||
| from executorch.exir.backend.partitioner import PartitionResult | ||
| from torch.export import export | ||
|  | ||
|  | ||
| class TestCudaPartitioner(unittest.TestCase): | ||
| """ | ||
| Test CUDA partitioner functionality. | ||
|  | ||
| After CUDA partitioning, there should be exactly one partitioned graph that contains | ||
| all operators from the input graph. This means all operators should be tagged with | ||
| the same delegation tag, indicating they will all be executed by the CUDA backend. | ||
| """ | ||
|  | ||
| def _get_partition_result( | ||
| self, module: torch.nn.Module, inputs: Tuple[torch.Tensor, ...] | ||
| ) -> PartitionResult: | ||
| """Helper method to get partition result for a given module.""" | ||
| # Export the model | ||
| exported_program = export(module, inputs, strict=True) | ||
|  | ||
| # Create partitioner and compile specs | ||
| compile_specs = [CompileSpec("cuda_compile_options", b"")] | ||
| partitioner = CudaPartitioner(compile_specs) | ||
|  | ||
| # Get partition result | ||
| partition_result = partitioner.partition(exported_program) | ||
|  | ||
| # Verify partition result structure | ||
| self.assertIsNotNone(partition_result) | ||
| self.assertTrue(hasattr(partition_result, "tagged_exported_program")) | ||
| self.assertTrue(hasattr(partition_result, "partition_tags")) | ||
|  | ||
| return partition_result | ||
|  | ||
| def _check_fully_partitioned(self, partition_result: PartitionResult) -> bool: | ||
| """Check if the graph is fully partitioned (all operators have the same tag).""" | ||
| tagged_nodes = [] | ||
| untagged_ops = [] | ||
|  | ||
| for node in partition_result.tagged_exported_program.graph.nodes: | ||
| if node.op == "call_function": | ||
| if hasattr(node, "meta") and "delegation_tag" in node.meta: | ||
| tagged_nodes.append(node) | ||
| else: | ||
| untagged_ops.append(node) | ||
|  | ||
| # Check if we have any tagged nodes | ||
| if not tagged_nodes: | ||
| return False | ||
|  | ||
| # Check if all tagged nodes have the same tag | ||
| first_tag = tagged_nodes[0].meta["delegation_tag"] | ||
| all_same_tag = all( | ||
| node.meta.get("delegation_tag") == first_tag for node in tagged_nodes | ||
| ) | ||
|  | ||
| # Should have no untagged operations for full partitioning | ||
| fully_partitioned = len(untagged_ops) == 0 and all_same_tag | ||
|  | ||
| return fully_partitioned | ||
|  | ||
| def test_simple_add_partition(self): | ||
| """ | ||
| Test that CUDA partitioner creates exactly one partition containing all operators. | ||
| Simple element-wise addition should result in a single graph with all ops tagged identically. | ||
| """ | ||
|  | ||
| class AddModule(torch.nn.Module): | ||
| def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | ||
| return x + y | ||
|  | ||
| module = AddModule() | ||
| inputs = (torch.randn(3, 4), torch.randn(3, 4)) | ||
|  | ||
| partition_result = self._get_partition_result(module, inputs) | ||
| fully_partitioned = self._check_fully_partitioned(partition_result) | ||
|  | ||
| self.assertTrue( | ||
| fully_partitioned, | ||
| "Graph should be fully partitioned with all operators having the same tag", | ||
| ) | ||
|  | ||
| def test_conv2d_partition(self): | ||
| """ | ||
| Test that CUDA partitioner creates exactly one partition containing all operators. | ||
| Conv2D operation should result in a single graph with all ops tagged identically. | ||
| """ | ||
|  | ||
| class Conv2dModule(torch.nn.Module): | ||
| def __init__(self): | ||
| super().__init__() | ||
| self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1) | ||
|  | ||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| return self.conv(x) | ||
|  | ||
| module = Conv2dModule() | ||
| inputs = (torch.randn(1, 3, 32, 32),) | ||
|  | ||
| partition_result = self._get_partition_result(module, inputs) | ||
| fully_partitioned = self._check_fully_partitioned(partition_result) | ||
|  | ||
| self.assertTrue( | ||
| fully_partitioned, | ||
| "Graph should be fully partitioned with all operators having the same tag", | ||
| ) | ||
|  | ||
| def test_linear_partition(self): | ||
| """ | ||
| Test that CUDA partitioner creates exactly one partition containing all operators. | ||
| Linear layer operation should result in a single graph with all ops tagged identically. | ||
| """ | ||
|  | ||
| class LinearModule(torch.nn.Module): | ||
| def __init__(self): | ||
| super().__init__() | ||
| self.linear = torch.nn.Linear(128, 64) | ||
|  | ||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| return self.linear(x) | ||
|  | ||
| module = LinearModule() | ||
| inputs = (torch.randn(8, 128),) | ||
|  | ||
| partition_result = self._get_partition_result(module, inputs) | ||
| fully_partitioned = self._check_fully_partitioned(partition_result) | ||
|  | ||
| self.assertTrue( | ||
| fully_partitioned, | ||
| "Graph should be fully partitioned with all operators having the same tag", | ||
| ) | 
      
      Oops, something went wrong.
        
    
  
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mark as experimental