
Commit 787452b

Gasoonjia authored and facebook-github-bot committed
equip the first partition with same example input as original graph
Summary: This diff equips the partitioned submodules with example inputs. More specifically, if a submodule starts from and covers all of the original export program's inputs, its example inputs should be the same as the original export program's; otherwise, its example inputs should be None. The example inputs are essential for AOTI-driven backends for further compilation. Differential Revision: D82865677
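As a rough sketch of that rule (assumptions: the helper name below is hypothetical, and the actual logic lives in create_exported_program_from_submodule, which the new test references but which is not shown in this diff), the decision reduces to comparing the submodule's user inputs against the original program's user inputs:

# Minimal sketch only, assuming the graph_signature.user_inputs lists are the
# basis of the comparison (as in the test's _inputs_match_original helper);
# not the exact upstream implementation.
def example_inputs_for_submodule(
    submodule_user_inputs,    # e.g. submodule_program.graph_signature.user_inputs
    original_user_inputs,     # e.g. original_program.graph_signature.user_inputs
    original_example_inputs,  # original_program.example_inputs
):
    # Keep the original example inputs only when the submodule consumes
    # exactly the original program's user inputs; otherwise attach None.
    if list(submodule_user_inputs) != list(original_user_inputs):
        return None
    return original_example_inputs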
1 parent ce8916f commit 787452b

File tree

5 files changed: 371 additions & 0 deletions


backends/arm/tosa/quant_utils.py

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,7 @@
 
 from executorch.backends.arm.tosa.mapping import TosaArg
 from torch.fx import Node
+
 from tosa.RoundingMode import RoundingMode  # type: ignore
 
 
@@ -318,6 +319,7 @@ def build_rescale(
     per_channel=False,
 ):
     import serializer.tosa_serializer as ts  # type: ignore
+
     import tosa.Op as TosaOp  # type: ignore
 
     scaleWidth = 32

exir/backend/backend_api.py

Lines changed: 2 additions & 0 deletions
@@ -720,6 +720,8 @@ def to_backend(
         fake_edge_program = copy.deepcopy(edge_program)
         partitioner_result = partitioner_instance(fake_edge_program)
         tagged_exported_program = partitioner_result.tagged_exported_program
+        tagged_exported_program.example_inputs = edge_program.example_inputs
+
         method_to_tagged_exported_program[method_name] = tagged_exported_program
 
         # Check that the partitioner did not modify the original graph

exir/backend/test/TARGETS

Lines changed: 18 additions & 0 deletions
@@ -458,3 +458,21 @@ python_unittest(
         "//executorch/exir/backend/canonical_partitioners:group_partitioner_lib",
     ],
 )
+
+python_unittest(
+    name = "test_example_input_of_submodule",
+    srcs = [
+        "test_submodule_example_inputs.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:lib",
+        "//executorch/exir:lowered_backend_module",
+        "//executorch/exir/backend:backend_details",
+        "//executorch/exir/backend:compile_spec_schema",
+        "//executorch/exir/backend:partitioner",
+        "//executorch/exir/backend/test/demos/rpc:executor_backend_preprocess",
+        "//executorch/exir/backend:utils",
+        "//executorch/exir/dialects:lib",
+    ],
+)
exir/backend/test/test_submodule_example_inputs.py

Lines changed: 322 additions & 0 deletions
@@ -0,0 +1,322 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch import exir
from executorch.exir.backend.backend_details import CompileSpec, ExportedProgram
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
    ExecutorBackend,
)
from executorch.exir.backend.utils import get_delegates
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import export


class TestExampleInputOfSubmodule(unittest.TestCase):
    """
    Tests for verifying that create_exported_program_from_submodule correctly
    handles example inputs of subgraphs based on input signature matching.

    More specifically, if the partitioner delegates a subgraph that doesn't
    start from the original inputs or doesn't cover all of them, the example
    inputs of the delegate should be None. Otherwise, the example inputs
    should match the original inputs.
    """

    def test_multiple_subgraphs_first_matches_original_others_none(self):
        """
        Test partitioning a model into several submodules where:
        - First submodule starts from the very beginning (same inputs as original)
        - Verify first submodule has original example inputs
        - Verify rest of submodules have None example inputs
        """

        class ThreeStageModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight1 = torch.nn.Parameter(torch.tensor([2.0]))
                self.weight2 = torch.nn.Parameter(torch.tensor([3.0]))

            def forward(self, x, y):
                # Stage 1: Direct operation on original inputs (will be first partition)
                stage1 = x + y  # This should match original signature

                # Stage 2: Uses stage1 result (different signature)
                stage2 = stage1 * self.weight1

                # Stage 3: Uses stage2 result (different signature)
                stage3 = stage2 + self.weight2

                return stage3

        model = ThreeStageModel()
        example_inputs = (torch.tensor([1.0]), torch.tensor([2.0]))

        # Create partitioner that delegates each stage separately
        class ThreeStagePartitioner(Partitioner):
            def __init__(self):
                super().__init__()
                self.spec1 = DelegationSpec(
                    ExecutorBackend.__name__, [CompileSpec("stage1", bytes([1]))]
                )
                self.spec2 = DelegationSpec(
                    ExecutorBackend.__name__, [CompileSpec("stage2", bytes([2]))]
                )
                self.spec3 = DelegationSpec(
                    ExecutorBackend.__name__, [CompileSpec("stage3", bytes([3]))]
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}
                stage_counter = 1

                for node in edge_exported_program.graph.nodes:
                    if node.op == "call_function":
                        if (
                            node.target == exir_ops.edge.aten.add.Tensor
                            and stage_counter == 1
                        ):
                            # First add operation (x + y) - uses original inputs
                            node.meta["delegation_tag"] = "stage1"
                            partition_tags["stage1"] = self.spec1
                        elif node.target == exir_ops.edge.aten.mul.Tensor:
                            # Second operation (stage1 * weight1) - uses intermediate result
                            node.meta["delegation_tag"] = "stage2"
                            partition_tags["stage2"] = self.spec2
                        elif (
                            node.target == exir_ops.edge.aten.add.Tensor
                            and stage_counter > 1
                        ):
                            # Third operation (stage2 + weight2) - uses intermediate result
                            node.meta["delegation_tag"] = "stage3"
                            partition_tags["stage3"] = self.spec3

                        stage_counter += 1

                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        exported_program = export(model, example_inputs, strict=True)
        edge_program = exir.to_edge(exported_program)

        partitioned_program = edge_program.to_backend(ThreeStagePartitioner())

        # Get all delegate modules
        delegates = get_delegates(partitioned_program.exported_program().graph)
        self.assertGreater(
            len(delegates), 1, "Should have multiple delegate submodules"
        )

        # Check each delegate's example inputs
        delegate_modules = []
        for delegate_node in delegates:
            delegate_module = getattr(
                partitioned_program.exported_program().graph_module, delegate_node.name
            )
            delegate_modules.append(delegate_module)

        # Sort delegates by creation order (first should be the one with original inputs)
        delegate_modules.sort(
            key=lambda x: (
                x.backend_id if hasattr(x, "backend_id") else x.processed_bytes
            )
        )

        # Verify first delegate has example inputs (should match original)
        first_delegate = delegate_modules[0]
        self.assertIsNotNone(
            first_delegate.original_module, "First delegate should have original_module"
        )

        # The key test: first submodule should have example inputs
        if hasattr(first_delegate.original_module, "example_inputs"):
            first_example_inputs = first_delegate.original_module.example_inputs
            if first_example_inputs is not None:
                # Verify they match original inputs structure
                self.assertEqual(
                    len(first_example_inputs),
                    len(example_inputs),
                    "First submodule example inputs should match original count",
                )

        # Verify remaining delegates have None example inputs
        for i, delegate in enumerate(delegate_modules[1:], 1):
            if hasattr(delegate.original_module, "example_inputs"):
                subsequent_example_inputs = delegate.original_module.example_inputs
                self.assertIsNone(
                    subsequent_example_inputs,
                    f"Delegate {i+1} should have None example inputs",
                )

    def test_single_subgraph_not_starting_from_original_input(self):
        """
        Test partitioning into only one subgraph that doesn't start from the original
        inputs, and verify that this subgraph has None example inputs.
        """

        class IntermediateOnlyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.multiplier = torch.nn.Parameter(torch.tensor([2.0]))

            def forward(self, x, y):
                # Step 1: Create intermediate (not delegated)
                intermediate = x + y

                # Step 2: Process intermediate (this will be delegated)
                # This doesn't use original x, y directly - uses intermediate result
                result = intermediate * self.multiplier
                return result

        model = IntermediateOnlyModel()
        example_inputs = (torch.tensor([1.0]), torch.tensor([2.0]))

        # Partitioner that only delegates the multiplication step
        class IntermediateOnlyPartitioner(Partitioner):
            def __init__(self):
                super().__init__()
                self.delegation_spec = DelegationSpec(
                    ExecutorBackend.__name__,
                    [CompileSpec("intermediate_only", bytes([1]))],
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}

                for node in edge_exported_program.graph.nodes:
                    if node.op == "call_function":
                        # Only delegate the multiplication (intermediate * multiplier)
                        # NOT the addition (x + y) which uses original inputs
                        if node.target == exir_ops.edge.aten.mul.Tensor:
                            node.meta["delegation_tag"] = "intermediate_only"
                            partition_tags["intermediate_only"] = self.delegation_spec

                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        exported_program = export(model, example_inputs, strict=True)
        edge_program = exir.to_edge(exported_program)

        partitioned_program = edge_program.to_backend(IntermediateOnlyPartitioner())

        # Verify functionality
        output = partitioned_program.exported_program().module()(*example_inputs)
        expected_output = model(*example_inputs)
        self.assertTrue(
            torch.allclose(output, expected_output),
            "Partitioned program should produce same results as original",
        )

        # Get the single delegate
        delegates = get_delegates(partitioned_program.exported_program().graph)
        self.assertEqual(len(delegates), 1, "Should have exactly one delegate")

        # Get the delegate module
        delegate_node = delegates[0]
        delegate_module = getattr(
            partitioned_program.exported_program().graph_module, delegate_node.name
        )

        # Key verification: This delegate doesn't start from original inputs,
        # so it should have None example inputs
        self.assertIsNotNone(
            delegate_module.original_module, "Delegate should have original_module"
        )

        if hasattr(delegate_module.original_module, "example_inputs"):
            delegate_example_inputs = delegate_module.original_module.example_inputs
            self.assertIsNone(
                delegate_example_inputs,
                "Delegate not starting from original inputs should have None example inputs",
            )

    def test_inputs_match_original_unit_logic(self):
        """
        Unit test for the core logic that determines if subgraph inputs match original inputs.
        This directly tests the _inputs_match_original function behavior.
        """

        # Create a test model with multiple inputs
        class MultiInputModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.tensor([1.0]))
                self.register_buffer("buffer", torch.tensor([2.0]))

            def forward(self, x, y):
                return x + y + self.param + self.buffer

        model = MultiInputModel()
        example_inputs = (torch.tensor([1.0]), torch.tensor([2.0]))
        original_program = export(model, example_inputs, strict=True)

        # Helper function that replicates the logic from create_exported_program_from_submodule
        def _inputs_match_original(subgraph_user_inputs, original_user_inputs):
            """
            Core matching logic: check if user inputs match exactly
            """
            if len(subgraph_user_inputs) != len(original_user_inputs):
                return False

            return subgraph_user_inputs == original_user_inputs

        # Get original user inputs for reference
        original_user_inputs = original_program.graph_signature.user_inputs
        self.assertEqual(
            len(original_user_inputs), 2, "Original should have 2 user inputs"
        )

        # Test Case 1: Matching user inputs (same as original)
        matching_user_inputs = original_user_inputs  # Exact same structure
        self.assertTrue(
            _inputs_match_original(matching_user_inputs, original_user_inputs),
            "Should return True when user inputs match exactly",
        )

        # Test Case 2: Different count of user inputs (subset)
        different_count_inputs = original_user_inputs[:1]  # Only first input
        self.assertFalse(
            _inputs_match_original(different_count_inputs, original_user_inputs),
            "Should return False when number of user inputs differs",
        )

        # Test Case 3: Empty inputs
        empty_inputs = []
        self.assertFalse(
            _inputs_match_original(empty_inputs, original_user_inputs),
            "Should return False when subgraph has no user inputs",
        )

        # Test Case 4: Test with a completely different signature
        # Create a different model to get genuinely different user inputs
        class SingleInputModel(torch.nn.Module):
            def forward(self, x):
                return x * 2

        single_input_model = SingleInputModel()
        single_input_example = (torch.tensor([5.0]),)

        single_input_program = export(
            single_input_model, single_input_example, strict=True
        )
        different_user_inputs = single_input_program.graph_signature.user_inputs

        self.assertFalse(
            _inputs_match_original(different_user_inputs, original_user_inputs),
            "Should return False when user inputs from different model signature",
        )
