Commit 53b4cb9

Replace aten linalg svd with cadence version.
Differential Revision: D81199281
Pull Request resolved: #13907
1 parent bbe8943 commit 53b4cb9
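For context, the op being retargeted enters graphs when torch.linalg.svd is exported and lowered to the edge dialect. The sketch below is not part of this commit; it assumes the standard export/to_edge flow and uses the _core_aten_ops_exception_list escape hatch (also threaded through ProgramBuilder in this commit), since aten._linalg_svd is not a core ATen op and would otherwise trip edge-dialect verification.

import torch
from torch.export import export
from executorch.exir import EdgeCompileConfig, to_edge


class Svd(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        # torch.linalg.svd decomposes to aten._linalg_svd during lowering;
        # that is the node the new cadence pass retargets.
        return torch.linalg.svd(x, full_matrices=False)


edge = to_edge(
    export(Svd(), (torch.randn(4, 4),)),
    compile_config=EdgeCompileConfig(
        # Allowlist the non-core op so edge-dialect verification accepts it.
        _core_aten_ops_exception_list=[torch.ops.aten._linalg_svd.default],
    ),
)
print(edge.exported_program().graph)  # expected to contain _linalg_svd.default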

File tree

3 files changed: +70, -3 lines

backends/cadence/aot/program_builder.py

Lines changed: 18 additions & 3 deletions
@@ -11,6 +11,7 @@
 from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
 from torch import Tensor
 from torch._export.verifier import Verifier
+from torch._ops import OpOverload
 from torch.export import ExportedProgram
 from torch.export.exported_program import ModuleCallEntry, ModuleCallSignature
 from torch.export.graph_signature import (
@@ -32,12 +33,19 @@ class IrMode(Enum):
 class ProgramBuilder(GraphBuilder):
     """Utility class to build a program from a graph module."""
 
-    def __init__(self, mode: Optional[IrMode] = None) -> None:
+    def __init__(
+        self,
+        mode: Optional[IrMode] = None,
+        _core_aten_ops_exception_list: Optional[list[OpOverload]] = None,
+    ) -> None:
         self.input_specs: list[InputSpec] = []
         self.output_specs: list[OutputSpec] = []
         self.constants: dict[str, Tensor] = {}
         self.state_dict: dict[str, Tensor] = {}
         self.mode: IrMode = mode or IrMode.EXIR
+        self._core_aten_ops_exception_list: list[OpOverload] = (
+            _core_aten_ops_exception_list or []
+        )
         super().__init__()
 
     def insert_input_spec(
@@ -82,7 +90,11 @@ def get_verifiers(self) -> Optional[list[Verifier]]:
             return None
         return [
             EXIREdgeDialectVerifier(
-                edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
+                edge_compile_config=EdgeCompileConfig(
+                    _check_ir_validity=False,
+                    _core_aten_ops_exception_list=self._core_aten_ops_exception_list,
+                ),
+                core_aten_ops_exception_list=self._core_aten_ops_exception_list,
                 class_only=True,
             )
         ]
@@ -113,4 +125,7 @@ def get_program(self) -> ExportedProgram:
         )
 
     def get_edge_program(self) -> EdgeProgramManager:
-        return EdgeProgramManager(self.get_program())
+        return EdgeProgramManager(
+            self.get_program(),
+            core_aten_ops_exception_list=self._core_aten_ops_exception_list,
+        )

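The verifier wiring above is easier to read outside the diff. Below is a minimal sketch of what get_verifiers() now constructs, with a hypothetical one-element exception list (the list contents are illustrative; callers pass whatever non-core OpOverloads they need):

import torch
from executorch.exir import EdgeCompileConfig
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier

# Hypothetical list; ProgramBuilder simply forwards whatever the caller supplied.
exception_list = [torch.ops.aten._linalg_svd.default]

verifier = EXIREdgeDialectVerifier(
    edge_compile_config=EdgeCompileConfig(
        _check_ir_validity=False,
        _core_aten_ops_exception_list=exception_list,
    ),
    # The same list is also passed directly, mirroring get_verifiers() above.
    core_aten_ops_exception_list=exception_list,
    class_only=True,
)

get_edge_program() forwards the same list to EdgeProgramManager, so both verification paths apply a consistent exception set.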
backends/cadence/aot/replace_ops.py

Lines changed: 16 additions & 0 deletions
@@ -2242,10 +2242,26 @@ def call_operator(self, op, args, kwargs, meta):
         )
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass(ExportPass):
+    """
+    Replace aten linalg svd op with cadence custom op.
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op != exir_ops.edge.aten._linalg_svd.default:
+            return super().call_operator(op, args, kwargs, meta)
+
+        return super().call_operator(
+            exir_ops.edge.cadence.linalg_svd.default, args, kwargs, meta
+        )
+
+
 # This class encapsulates all the functions that replace/switch one op in the
 # graph with another.
 class CadenceReplaceOpsInGraph:
     passes = [
+        ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass,
         ReplaceEmptyTensorsWithFullPass,
         ReplaceFunctionallyEquivalentOpTargets,
         ReplacePermuteWithTransposePass,

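As a usage sketch (mirroring how the test below drives the pass, and assuming an edge-dialect graph module that already contains aten._linalg_svd), the pass is a pure op-target swap: the args (input, full_matrices, compute_uv) and kwargs such as driver pass through unchanged. The lower_svd helper is hypothetical.

from torch.fx.passes.infra.pass_base import PassResult

from executorch.backends.cadence.aot.replace_ops import (
    ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass,
)


def lower_svd(edge_gm) -> PassResult:
    """Retarget aten._linalg_svd nodes in edge_gm to cadence.linalg_svd."""
    result = ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass()(edge_gm)
    assert result is not None
    # Former exir_ops.edge.aten._linalg_svd.default call sites now target
    # exir_ops.edge.cadence.linalg_svd.default with identical args/kwargs.
    return result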
backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 36 additions & 0 deletions
@@ -22,6 +22,7 @@
     ReplaceAddMMWithLinearPass,
     ReplaceAtenApproxGeluWithApproxGeluPass,
     ReplaceAtenConvolutionWithCadenceConvolutionPass,
+    ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass,
     ReplaceConstantPadNdWithSlicePass,
     ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
     ReplaceConvWithChannelLastConvPass,
@@ -2045,3 +2046,38 @@ def test_replace_adaptive_avg_pool_with_aten_avg_pool_irregular(self) -> None:
             len(avg_pool2d_nodes),
             0,
         )
+
+
+class TestReplaceLinalgSvdPass(unittest.TestCase):
+    @expand(
+        [
+            ("2x2", (2, 2)),
+            ("3x3", (3, 3)),
+            ("4x5", (4, 5)),
+            ("10x10", (10, 10)),
+        ]
+    )
+    @torch.no_grad()
+    def test_replace_aten_linalg_svd_with_cadence_linalg_svd(
+        self, _: str, shape: Tuple[int, int]
+    ) -> None:
+        x = torch.randn(shape, dtype=torch.float32)
+        original_gm = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten._linalg_svd.default,
+            args=(x, False, True),
+            kwargs={"driver": None},
+        )
+
+        p = ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass()
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
+
+        # Assert that the aten linalg_svd op was replaced with cadence linalg_svd op
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten._linalg_svd.default),
+            0,
+        )
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.cadence.linalg_svd.default),
+            1,
+        )
