|  | 
|  | 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. | 
|  | 2 | +# All rights reserved. | 
|  | 3 | +# | 
|  | 4 | +# This source code is licensed under the BSD-style license found in the | 
|  | 5 | +# LICENSE file in the root directory of this source tree. | 
|  | 6 | + | 
|  | 7 | +import unittest | 
|  | 8 | + | 
|  | 9 | +import torch | 
|  | 10 | +from executorch import exir | 
|  | 11 | +from executorch.exir.backend.backend_details import CompileSpec, ExportedProgram | 
|  | 12 | +from executorch.exir.backend.partitioner import ( | 
|  | 13 | +    DelegationSpec, | 
|  | 14 | +    Partitioner, | 
|  | 15 | +    PartitionResult, | 
|  | 16 | +) | 
|  | 17 | +from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import ( | 
|  | 18 | +    ExecutorBackend, | 
|  | 19 | +) | 
|  | 20 | +from executorch.exir.backend.utils import get_delegates | 
|  | 21 | +from executorch.exir.dialects._ops import ops as exir_ops | 
|  | 22 | +from torch.export import export | 
|  | 23 | + | 
|  | 24 | + | 
|  | 25 | +class TestExampleInputOfSubmodule(unittest.TestCase): | 
|  | 26 | +    """ | 
|  | 27 | +    Tests for verifying that create_exported_program_from_submodule correctly | 
|  | 28 | +    handles example inputs of subgraphs based on input signature matching. | 
|  | 29 | +
 | 
|  | 30 | +    More specifically, if the partitioner delegates a subgraph that doesn't | 
|  | 31 | +    start from the original inputs or not cover all or them, the example inputs | 
|  | 32 | +    of the delegate should  be None. Otherwise, the example inputs should match | 
|  | 33 | +    the original inputs. | 
|  | 34 | +    """ | 
|  | 35 | + | 
|  | 36 | +    def test_multiple_subgraphs_first_matches_original_others_none(self): | 
|  | 37 | +        """ | 
|  | 38 | +        Test partitioning a model into several submodules where: | 
|  | 39 | +        - First submodule starts from the very beginning (same inputs as original) | 
|  | 40 | +        - Verify first submodule has original example inputs | 
|  | 41 | +        - Verify rest of submodules have None example inputs | 
|  | 42 | +        """ | 
|  | 43 | + | 
|  | 44 | +        class ThreeStageModel(torch.nn.Module): | 
|  | 45 | +            def __init__(self): | 
|  | 46 | +                super().__init__() | 
|  | 47 | +                self.weight1 = torch.nn.Parameter(torch.tensor([2.0])) | 
|  | 48 | +                self.weight2 = torch.nn.Parameter(torch.tensor([3.0])) | 
|  | 49 | + | 
|  | 50 | +            def forward(self, x, y): | 
|  | 51 | +                # Stage 1: Direct operation on original inputs (will be first partition) | 
|  | 52 | +                stage1 = x + y  # This should match original signature | 
|  | 53 | + | 
|  | 54 | +                # Stage 2: Uses stage1 result (different signature) | 
|  | 55 | +                stage2 = stage1 * self.weight1 | 
|  | 56 | + | 
|  | 57 | +                # Stage 3: Uses stage2 result (different signature) | 
|  | 58 | +                stage3 = stage2 + self.weight2 | 
|  | 59 | + | 
|  | 60 | +                return stage3 | 
|  | 61 | + | 
|  | 62 | +        model = ThreeStageModel() | 
|  | 63 | +        example_inputs = (torch.tensor([1.0]), torch.tensor([2.0])) | 
|  | 64 | + | 
|  | 65 | +        # Create partitioner that delegates each stage separately | 
|  | 66 | +        class ThreeStagePartitioner(Partitioner): | 
|  | 67 | +            def __init__(self): | 
|  | 68 | +                super().__init__() | 
|  | 69 | +                self.spec1 = DelegationSpec( | 
|  | 70 | +                    ExecutorBackend.__name__, [CompileSpec("stage1", bytes([1]))] | 
|  | 71 | +                ) | 
|  | 72 | +                self.spec2 = DelegationSpec( | 
|  | 73 | +                    ExecutorBackend.__name__, [CompileSpec("stage2", bytes([2]))] | 
|  | 74 | +                ) | 
|  | 75 | +                self.spec3 = DelegationSpec( | 
|  | 76 | +                    ExecutorBackend.__name__, [CompileSpec("stage3", bytes([3]))] | 
|  | 77 | +                ) | 
|  | 78 | + | 
|  | 79 | +            def partition( | 
|  | 80 | +                self, edge_exported_program: ExportedProgram | 
|  | 81 | +            ) -> PartitionResult: | 
|  | 82 | +                partition_tags = {} | 
|  | 83 | +                stage_counter = 1 | 
|  | 84 | + | 
|  | 85 | +                for node in edge_exported_program.graph.nodes: | 
|  | 86 | +                    if node.op == "call_function": | 
|  | 87 | +                        if node.target == exir_ops.edge.aten.add.Tensor: | 
|  | 88 | +                            # First add operation (x + y) - uses original inputs | 
|  | 89 | +                            node.meta["delegation_tag"] = "stage1" | 
|  | 90 | +                            partition_tags["stage1"] = self.spec1 | 
|  | 91 | +                        elif node.target == exir_ops.edge.aten.mul.Tensor: | 
|  | 92 | +                            # Second operation (stage1 * weight1) - uses intermediate result | 
|  | 93 | +                            node.meta["delegation_tag"] = "stage2" | 
|  | 94 | +                            partition_tags["stage2"] = self.spec2 | 
|  | 95 | +                        elif ( | 
|  | 96 | +                            node.target == exir_ops.edge.aten.add.Tensor | 
|  | 97 | +                            and stage_counter > 1 | 
|  | 98 | +                        ): | 
|  | 99 | +                            # Third operation (stage2 + weight2) - uses intermediate result | 
|  | 100 | +                            node.meta["delegation_tag"] = "stage3" | 
|  | 101 | +                            partition_tags["stage3"] = self.spec3 | 
|  | 102 | + | 
|  | 103 | +                        stage_counter += 1 | 
|  | 104 | + | 
|  | 105 | +                return PartitionResult( | 
|  | 106 | +                    tagged_exported_program=edge_exported_program, | 
|  | 107 | +                    partition_tags=partition_tags, | 
|  | 108 | +                ) | 
|  | 109 | + | 
|  | 110 | +        exported_program = export(model, example_inputs, strict=True) | 
|  | 111 | +        edge_program = exir.to_edge(exported_program) | 
|  | 112 | + | 
|  | 113 | +        partitioned_program = edge_program.to_backend(ThreeStagePartitioner()) | 
|  | 114 | + | 
|  | 115 | +        # Get all delegate modules | 
|  | 116 | +        delegates = get_delegates(partitioned_program.exported_program().graph) | 
|  | 117 | +        self.assertGreater( | 
|  | 118 | +            len(delegates), 1, "Should have multiple delegate submodules" | 
|  | 119 | +        ) | 
|  | 120 | + | 
|  | 121 | +        # Check each delegate's example inputs | 
|  | 122 | +        delegate_modules = [] | 
|  | 123 | +        for delegate_node in delegates: | 
|  | 124 | +            delegate_module = getattr( | 
|  | 125 | +                partitioned_program.exported_program().graph_module, delegate_node.name | 
|  | 126 | +            ) | 
|  | 127 | +            delegate_modules.append(delegate_module) | 
|  | 128 | + | 
|  | 129 | +        # Sort delegates by creation order (first should be the one with original inputs) | 
|  | 130 | +        delegate_modules.sort( | 
|  | 131 | +            key=lambda x: ( | 
|  | 132 | +                x.backend_id if hasattr(x, "backend_id") else x.processed_bytes | 
|  | 133 | +            ) | 
|  | 134 | +        ) | 
|  | 135 | + | 
|  | 136 | +        # Verify first delegate has example inputs (should match original) | 
|  | 137 | +        first_delegate = delegate_modules[0] | 
|  | 138 | +        self.assertIsNotNone( | 
|  | 139 | +            first_delegate.original_module, "First delegate should have original_module" | 
|  | 140 | +        ) | 
|  | 141 | + | 
|  | 142 | +        # The key test: first submodule should have example inputs | 
|  | 143 | +        if hasattr(first_delegate.original_module, "example_inputs"): | 
|  | 144 | +            first_example_inputs = first_delegate.original_module.example_inputs | 
|  | 145 | +            if first_example_inputs is not None: | 
|  | 146 | +                # Verify they match original inputs structure | 
|  | 147 | +                self.assertEqual( | 
|  | 148 | +                    len(first_example_inputs), | 
|  | 149 | +                    len(example_inputs), | 
|  | 150 | +                    "First submodule example inputs should match original count", | 
|  | 151 | +                ) | 
|  | 152 | + | 
|  | 153 | +        # Verify remaining delegates have None example inputs | 
|  | 154 | +        for i, delegate in enumerate(delegate_modules[1:], 1): | 
|  | 155 | +            if hasattr(delegate.original_module, "example_inputs"): | 
|  | 156 | +                subsequent_example_inputs = delegate.original_module.example_inputs | 
|  | 157 | +                self.assertIsNone( | 
|  | 158 | +                    subsequent_example_inputs, | 
|  | 159 | +                    f"Delegate {i+1} should have None example inputs", | 
|  | 160 | +                ) | 
|  | 161 | + | 
|  | 162 | +    def test_single_subgraph_not_starting_from_original_input(self): | 
|  | 163 | +        """ | 
|  | 164 | +        Test partitioning into only one subgraph that doesn't start from the original | 
|  | 165 | +        inputs, and verify that this subgraph has None example inputs. | 
|  | 166 | +        """ | 
|  | 167 | + | 
|  | 168 | +        class IntermediateOnlyModel(torch.nn.Module): | 
|  | 169 | +            def __init__(self): | 
|  | 170 | +                super().__init__() | 
|  | 171 | +                self.multiplier = torch.nn.Parameter(torch.tensor([2.0])) | 
|  | 172 | + | 
|  | 173 | +            def forward(self, x, y): | 
|  | 174 | +                # Step 1: Create intermediate (not delegated) | 
|  | 175 | +                intermediate = x + y | 
|  | 176 | + | 
|  | 177 | +                # Step 2: Process intermediate (this will be delegated) | 
|  | 178 | +                # This doesn't use original x, y directly - uses intermediate result | 
|  | 179 | +                result = intermediate * self.multiplier | 
|  | 180 | +                return result | 
|  | 181 | + | 
|  | 182 | +        model = IntermediateOnlyModel() | 
|  | 183 | +        example_inputs = (torch.tensor([1.0]), torch.tensor([2.0])) | 
|  | 184 | + | 
|  | 185 | +        # Partitioner that only delegates the multiplication step | 
|  | 186 | +        class IntermediateOnlyPartitioner(Partitioner): | 
|  | 187 | +            def __init__(self): | 
|  | 188 | +                super().__init__() | 
|  | 189 | +                self.delegation_spec = DelegationSpec( | 
|  | 190 | +                    ExecutorBackend.__name__, | 
|  | 191 | +                    [CompileSpec("intermediate_only", bytes([1]))], | 
|  | 192 | +                ) | 
|  | 193 | + | 
|  | 194 | +            def partition( | 
|  | 195 | +                self, edge_exported_program: ExportedProgram | 
|  | 196 | +            ) -> PartitionResult: | 
|  | 197 | +                partition_tags = {} | 
|  | 198 | + | 
|  | 199 | +                for node in edge_exported_program.graph.nodes: | 
|  | 200 | +                    if node.op == "call_function": | 
|  | 201 | +                        # Only delegate the multiplication (intermediate * multiplier) | 
|  | 202 | +                        # NOT the addition (x + y) which uses original inputs | 
|  | 203 | +                        if node.target == exir_ops.edge.aten.mul.Tensor: | 
|  | 204 | +                            node.meta["delegation_tag"] = "intermediate_only" | 
|  | 205 | +                            partition_tags["intermediate_only"] = self.delegation_spec | 
|  | 206 | + | 
|  | 207 | +                return PartitionResult( | 
|  | 208 | +                    tagged_exported_program=edge_exported_program, | 
|  | 209 | +                    partition_tags=partition_tags, | 
|  | 210 | +                ) | 
|  | 211 | + | 
|  | 212 | +        exported_program = export(model, example_inputs, strict=True) | 
|  | 213 | +        edge_program = exir.to_edge(exported_program) | 
|  | 214 | + | 
|  | 215 | +        partitioned_program = edge_program.to_backend(IntermediateOnlyPartitioner()) | 
|  | 216 | + | 
|  | 217 | +        # Verify functionality | 
|  | 218 | +        output = partitioned_program.exported_program().module()(*example_inputs) | 
|  | 219 | +        expected_output = model(*example_inputs) | 
|  | 220 | +        self.assertTrue( | 
|  | 221 | +            torch.allclose(output, expected_output), | 
|  | 222 | +            "Partitioned program should produce same results as original", | 
|  | 223 | +        ) | 
|  | 224 | + | 
|  | 225 | +        # Get the single delegate | 
|  | 226 | +        delegates = get_delegates(partitioned_program.exported_program().graph) | 
|  | 227 | +        self.assertEqual(len(delegates), 1, "Should have exactly one delegate") | 
|  | 228 | + | 
|  | 229 | +        # Get the delegate module | 
|  | 230 | +        delegate_node = delegates[0] | 
|  | 231 | +        delegate_module = getattr( | 
|  | 232 | +            partitioned_program.exported_program().graph_module, delegate_node.name | 
|  | 233 | +        ) | 
|  | 234 | + | 
|  | 235 | +        # Key verification: This delegate doesn't start from original inputs, | 
|  | 236 | +        # so it should have None example inputs | 
|  | 237 | +        self.assertIsNotNone( | 
|  | 238 | +            delegate_module.original_module, "Delegate should have original_module" | 
|  | 239 | +        ) | 
|  | 240 | + | 
|  | 241 | +        if hasattr(delegate_module.original_module, "example_inputs"): | 
|  | 242 | +            delegate_example_inputs = delegate_module.original_module.example_inputs | 
|  | 243 | +            self.assertIsNone( | 
|  | 244 | +                delegate_example_inputs, | 
|  | 245 | +                "Delegate not starting from original inputs should have None example inputs", | 
|  | 246 | +            ) | 
|  | 247 | + | 
|  | 248 | +    def test_inputs_match_original_unit_logic(self): | 
|  | 249 | +        """ | 
|  | 250 | +        Unit test for the core logic that determines if subgraph inputs match original inputs. | 
|  | 251 | +        This directly tests the _inputs_match_original function behavior. | 
|  | 252 | +        """ | 
|  | 253 | + | 
|  | 254 | +        # Create a test model with multiple inputs | 
|  | 255 | +        class MultiInputModel(torch.nn.Module): | 
|  | 256 | +            def __init__(self): | 
|  | 257 | +                super().__init__() | 
|  | 258 | +                self.param = torch.nn.Parameter(torch.tensor([1.0])) | 
|  | 259 | +                self.register_buffer("buffer", torch.tensor([2.0])) | 
|  | 260 | + | 
|  | 261 | +            def forward(self, x, y): | 
|  | 262 | +                return x + y + self.param + self.buffer | 
|  | 263 | + | 
|  | 264 | +        model = MultiInputModel() | 
|  | 265 | +        example_inputs = (torch.tensor([1.0]), torch.tensor([2.0])) | 
|  | 266 | +        original_program = export(model, example_inputs, strict=True) | 
|  | 267 | + | 
|  | 268 | +        # Helper function that replicates the logic from create_exported_program_from_submodule | 
|  | 269 | +        def _inputs_match_original(subgraph_user_inputs, original_user_inputs): | 
|  | 270 | +            """ | 
|  | 271 | +            Core matching logic: check if user inputs match exactly | 
|  | 272 | +            """ | 
|  | 273 | +            if len(subgraph_user_inputs) != len(original_user_inputs): | 
|  | 274 | +                return False | 
|  | 275 | + | 
|  | 276 | +            return subgraph_user_inputs == original_user_inputs | 
|  | 277 | + | 
|  | 278 | +        # Get original user inputs for reference | 
|  | 279 | +        original_user_inputs = original_program.graph_signature.user_inputs | 
|  | 280 | +        self.assertEqual( | 
|  | 281 | +            len(original_user_inputs), 2, "Original should have 2 user inputs" | 
|  | 282 | +        ) | 
|  | 283 | + | 
|  | 284 | +        # Test Case 1: Matching user inputs (same as original) | 
|  | 285 | +        matching_user_inputs = original_user_inputs  # Exact same structure | 
|  | 286 | +        self.assertTrue( | 
|  | 287 | +            _inputs_match_original(matching_user_inputs, original_user_inputs), | 
|  | 288 | +            "Should return True when user inputs match exactly", | 
|  | 289 | +        ) | 
|  | 290 | + | 
|  | 291 | +        # Test Case 2: Different count of user inputs (subset) | 
|  | 292 | +        different_count_inputs = original_user_inputs[:1]  # Only first input | 
|  | 293 | +        self.assertFalse( | 
|  | 294 | +            _inputs_match_original(different_count_inputs, original_user_inputs), | 
|  | 295 | +            "Should return False when number of user inputs differs", | 
|  | 296 | +        ) | 
|  | 297 | + | 
|  | 298 | +        # Test Case 3: Empty inputs | 
|  | 299 | +        empty_inputs = [] | 
|  | 300 | +        self.assertFalse( | 
|  | 301 | +            _inputs_match_original(empty_inputs, original_user_inputs), | 
|  | 302 | +            "Should return False when subgraph has no user inputs", | 
|  | 303 | +        ) | 
|  | 304 | + | 
|  | 305 | +        # Test Case 4: Test with a completely different signature | 
|  | 306 | +        # Create a different model to get genuinely different user inputs | 
|  | 307 | +        class SingleInputModel(torch.nn.Module): | 
|  | 308 | +            def forward(self, x): | 
|  | 309 | +                return x * 2 | 
|  | 310 | + | 
|  | 311 | +        single_input_model = SingleInputModel() | 
|  | 312 | +        single_input_example = (torch.tensor([5.0]),) | 
|  | 313 | + | 
|  | 314 | +        single_input_program = export( | 
|  | 315 | +            single_input_model, single_input_example, strict=True | 
|  | 316 | +        ) | 
|  | 317 | +        different_user_inputs = single_input_program.graph_signature.user_inputs | 
|  | 318 | + | 
|  | 319 | +        self.assertFalse( | 
|  | 320 | +            _inputs_match_original(different_user_inputs, original_user_inputs), | 
|  | 321 | +            "Should return False when user inputs from different model signature", | 
|  | 322 | +        ) | 
0 commit comments