Skip to content

Commit 1399061

Browse files
hsharma35facebook-github-bot
authored andcommitted
Enable both aten and exir for ops in program builder. (#13075)
Summary: Allows both aten and exir ops to be built using program builder. Reviewed By: zonglinpeng Differential Revision: D79477846
1 parent 047587e commit 1399061

File tree

2 files changed

+123
-10
lines changed

2 files changed

+123
-10
lines changed

backends/cadence/aot/program_builder.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
# pyre-strict
44

5+
from enum import auto, Enum
56
from typing import Optional
67

78
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
89
from executorch.exir import EdgeCompileConfig, EdgeProgramManager
910
from executorch.exir.pass_base import ProxyValue
1011
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
11-
1212
from torch import Tensor
13+
from torch._export.verifier import Verifier
1314
from torch.export import ExportedProgram
1415
from torch.export.graph_signature import (
1516
ExportGraphSignature,
@@ -21,14 +22,20 @@
2122
)
2223

2324

25+
class IrMode(Enum):
26+
EXIR = auto()
27+
ATEN = auto()
28+
29+
2430
class ProgramBuilder(GraphBuilder):
2531
"""Utility class to build a program from a graph module."""
2632

27-
def __init__(self) -> None:
33+
def __init__(self, mode: Optional[IrMode] = None) -> None:
2834
self.input_specs: list[InputSpec] = []
2935
self.output_specs: list[OutputSpec] = []
3036
self.constants: dict[str, Tensor] = {}
3137
self.state_dict: dict[str, Tensor] = {}
38+
self.mode: IrMode = mode or IrMode.EXIR
3239
super().__init__()
3340

3441
def insert_input_spec(
@@ -68,6 +75,16 @@ def output(
6875
)
6976
return super().output(results)
7077

78+
def get_verifiers(self) -> Optional[list[Verifier]]:
79+
if self.mode == IrMode.ATEN:
80+
return None
81+
return [
82+
EXIREdgeDialectVerifier(
83+
edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
84+
class_only=True,
85+
)
86+
]
87+
7188
def get_program(self) -> ExportedProgram:
7289
gm = self.get_graph_module()
7390
return ExportedProgram(
@@ -81,12 +98,8 @@ def get_program(self) -> ExportedProgram:
8198
state_dict=self.state_dict,
8299
range_constraints={},
83100
module_call_graph=[],
84-
verifiers=[
85-
EXIREdgeDialectVerifier(
86-
edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
87-
class_only=True,
88-
)
89-
],
101+
# pyre-ignore[6]: Incompatible parameter type.
102+
verifiers=self.get_verifiers(),
90103
)
91104

92105
def get_edge_program(self) -> EdgeProgramManager:

backends/cadence/aot/tests/test_program_builder.py

Lines changed: 102 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22

33
# pyre-strict
4-
54
import torch
6-
from executorch.backends.cadence.aot.program_builder import ProgramBuilder
5+
from executorch.backends.cadence.aot.program_builder import IrMode, ProgramBuilder
6+
from executorch.exir.dialects._ops import ops as exir_ops
77
from later.unittest import TestCase
8+
from torch._export.verifier import SpecViolationError
89
from torch.export.graph_signature import InputKind, OutputKind
910

1011

@@ -120,3 +121,102 @@ def test_user_input_mutation(self) -> None:
120121
self.assertEqual(
121122
program.graph_signature.output_specs[0].kind, OutputKind.USER_INPUT_MUTATION
122123
)
124+
125+
def test_get_verifier_exir_mode(self) -> None:
126+
"""Test that get_verifier returns EXIREdgeDialectVerifier for EXIR mode."""
127+
builder = ProgramBuilder(mode=IrMode.EXIR)
128+
verifiers = builder.get_verifiers()
129+
self.assertIsNotNone(verifiers)
130+
self.assertEqual(len(verifiers), 1)
131+
132+
def test_get_verifier_aten_mode(self) -> None:
133+
"""Test that get_verifier returns None for ATEN mode."""
134+
builder = ProgramBuilder(mode=IrMode.ATEN)
135+
verifiers = builder.get_verifiers()
136+
self.assertIsNone(verifiers)
137+
138+
def test_get_verifier_default_mode(self) -> None:
139+
"""Test that get_verifier returns EXIREdgeDialectVerifier for default mode."""
140+
builder = ProgramBuilder() # Should default to EXIR
141+
self.assertEqual(builder.mode, IrMode.EXIR)
142+
verifiers = builder.get_verifiers()
143+
self.assertIsNotNone(verifiers)
144+
self.assertEqual(len(verifiers), 1)
145+
146+
def test_aten_add_tensor_exir_mode(self) -> None:
147+
"""Test using torch.ops.aten.add.Tensor with EXIR mode."""
148+
inp = torch.randn([3, 5])
149+
buffer = torch.randn([5])
150+
151+
builder = ProgramBuilder(mode=IrMode.EXIR)
152+
inp_proxy = builder.placeholder("inp", inp)
153+
buffer_proxy = builder.placeholder(
154+
"buffer", buffer, input_kind=InputKind.BUFFER
155+
)
156+
add = builder.call_operator(
157+
torch.ops.aten.add.Tensor, (inp_proxy, buffer_proxy)
158+
)
159+
builder.output([add])
160+
builder.get_program()
161+
162+
def test_aten_add_tensor_aten_mode(self) -> None:
163+
"""Test using torch.ops.aten.add.Tensor with ATEN mode."""
164+
inp = torch.randn([3, 5])
165+
buffer = torch.randn([5])
166+
167+
builder = ProgramBuilder(mode=IrMode.ATEN)
168+
inp_proxy = builder.placeholder("inp", inp)
169+
buffer_proxy = builder.placeholder(
170+
"buffer", buffer, input_kind=InputKind.BUFFER
171+
)
172+
add = builder.call_operator(
173+
torch.ops.aten.add.Tensor, (inp_proxy, buffer_proxy)
174+
)
175+
builder.output([add])
176+
program = builder.get_program()
177+
178+
# Verify the program was created successfully
179+
self.assertEqual(len(program.graph_signature.input_specs), 2)
180+
self.assertEqual(len(program.graph_signature.output_specs), 1)
181+
self.assertEqual(builder.mode, IrMode.ATEN)
182+
183+
def test_exir_edge_aten_add_tensor_exir_mode(self) -> None:
184+
"""Test using exir_ops.edge.aten.add.Tensor with EXIR mode."""
185+
inp = torch.randn([3, 5])
186+
buffer = torch.randn([5])
187+
188+
builder_exir = ProgramBuilder(mode=IrMode.EXIR)
189+
inp_proxy_exir = builder_exir.placeholder("inp", inp)
190+
buffer_proxy_exir = builder_exir.placeholder(
191+
"buffer", buffer, input_kind=InputKind.BUFFER
192+
)
193+
add_exir = builder_exir.call_operator(
194+
exir_ops.edge.aten.add.Tensor, (inp_proxy_exir, buffer_proxy_exir)
195+
)
196+
builder_exir.output([add_exir])
197+
program_exir = builder_exir.get_program()
198+
199+
# Verify the program was created successfully
200+
self.assertEqual(len(program_exir.graph_signature.input_specs), 2)
201+
self.assertEqual(len(program_exir.graph_signature.output_specs), 1)
202+
self.assertEqual(builder_exir.mode, IrMode.EXIR)
203+
204+
def test_exir_edge_aten_add_tensor_aten_mode(self) -> None:
205+
"""Test using exir_ops.edge.aten.add.Tensor with ATEN mode."""
206+
inp = torch.randn([3, 5])
207+
buffer = torch.randn([5])
208+
209+
builder_aten = ProgramBuilder(mode=IrMode.ATEN)
210+
inp_proxy_aten = builder_aten.placeholder("inp", inp)
211+
buffer_proxy_aten = builder_aten.placeholder(
212+
"buffer", buffer, input_kind=InputKind.BUFFER
213+
)
214+
add_aten = builder_aten.call_operator(
215+
exir_ops.edge.aten.add.Tensor, (inp_proxy_aten, buffer_proxy_aten)
216+
)
217+
builder_aten.output([add_aten])
218+
219+
with self.assertRaises(
220+
SpecViolationError, msg="Operator '<EdgeOpOverload: aten.add.Tensor>"
221+
):
222+
builder_aten.get_program()

0 commit comments

Comments
 (0)