Skip to content

Commit 3c46dd2

Browse files
hsharma35facebook-github-bot
authored andcommitted
Replace aten linalg svd with cadence version.
Summary: Adds a replacement pass for `aten::_linalg_svd.default`. This op produces non-contiguous outputs so we replace it with cadence version that produces contiguous outputs. Differential Revision: D81199281
1 parent 5abad6c commit 3c46dd2

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2242,10 +2242,26 @@ def call_operator(self, op, args, kwargs, meta):
22422242
)
22432243

22442244

2245+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
2246+
class ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass(ExportPass):
2247+
"""
2248+
Replace aten linalg svd op with cadence custom op.
2249+
"""
2250+
2251+
def call_operator(self, op, args, kwargs, meta):
2252+
if op != exir_ops.edge.aten._linalg_svd.default:
2253+
return super().call_operator(op, args, kwargs, meta)
2254+
2255+
return super().call_operator(
2256+
exir_ops.edge.cadence.linalg_svd.default, args, kwargs, meta
2257+
)
2258+
2259+
22452260
# This class encapsulates all the functions that replace/switch one op in the
22462261
# graph with another.
22472262
class CadenceReplaceOpsInGraph:
22482263
passes = [
2264+
ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass,
22492265
ReplaceEmptyTensorsWithFullPass,
22502266
ReplaceFunctionallyEquivalentOpTargets,
22512267
ReplacePermuteWithTransposePass,

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
ReplaceAddMMWithLinearPass,
2323
ReplaceAtenApproxGeluWithApproxGeluPass,
2424
ReplaceAtenConvolutionWithCadenceConvolutionPass,
25+
ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass,
2526
ReplaceConstantPadNdWithSlicePass,
2627
ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
2728
ReplaceConvWithChannelLastConvPass,
@@ -2045,3 +2046,38 @@ def test_replace_adaptive_avg_pool_with_aten_avg_pool_irregular(self) -> None:
20452046
len(avg_pool2d_nodes),
20462047
0,
20472048
)
2049+
2050+
2051+
class TestReplaceLinalgSvdPass(unittest.TestCase):
2052+
@expand(
2053+
[
2054+
("2x2", (2, 2)),
2055+
("3x3", (3, 3)),
2056+
("4x5", (4, 5)),
2057+
("10x10", (10, 10)),
2058+
]
2059+
)
2060+
@torch.no_grad()
2061+
def test_replace_aten_linalg_svd_with_cadence_linalg_svd(
2062+
self, _: str, shape: Tuple[int, int]
2063+
) -> None:
2064+
x = torch.randn(shape, dtype=torch.float32)
2065+
original_gm = single_op_builder(
2066+
placeholders=(x,),
2067+
op=exir_ops.edge.aten._linalg_svd.default,
2068+
args=(x, False, True),
2069+
kwargs={"driver": None},
2070+
)
2071+
2072+
p = ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass()
2073+
graph_after_passes = cast(PassResult, p(original_gm)).graph_module
2074+
2075+
# Assert that the aten linalg_svd op was replaced with cadence linalg_svd op
2076+
self.assertEqual(
2077+
count_node(graph_after_passes, exir_ops.edge.aten._linalg_svd.default),
2078+
0,
2079+
)
2080+
self.assertEqual(
2081+
count_node(graph_after_passes, exir_ops.edge.cadence.linalg_svd.default),
2082+
1,
2083+
)

0 commit comments

Comments
 (0)