
Commit 0468e28

mcremon-meta authored and facebook-github-bot committed
Add cadence.where_Scalar op (#9764)
Summary:
Pull Request resolved: #9764

The op replaces the regular where op when it uses two `aten.full` ops for the tensors. Those cases do not need broadcast (but would call it in the current state) and can sometimes be constant folded, if the `condition` tensor is constant. Since `aten.full` is _not_ currently constant folded, it would otherwise stay in the graph.

Reviewed By: skrtskrtfb

Differential Revision: D70539497
1 parent 7fd589d commit 0468e28
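
For context, a minimal before/after sketch of the pattern the new pass targets (this example is ours, not part of the commit):

```python
import torch

# The pattern the pass matches: both value operands of aten.where are
# aten.full ops of the same size as the condition, so no broadcast is needed.
cond = torch.randn(4, 8) > 0
a = torch.ops.aten.full.default((4, 8), 0.0)
b = torch.ops.aten.full.default((4, 8), 1.0)
before = torch.ops.aten.where.self(cond, a, b)

# After the pass, the graph would instead carry a single
# cadence::where_Scalar(cond, 0.0, 1.0) node; eagerly, its result is
# equivalent to:
after = torch.where(cond, torch.tensor(0.0), torch.tensor(1.0))
assert torch.equal(before, after)
```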

File tree: 3 files changed, +134 −0 lines changed

- backends/cadence/aot/ops_registrations.py
- backends/cadence/aot/replace_ops.py
- backends/cadence/aot/tests/test_replace_ops_passes.py

backends/cadence/aot/ops_registrations.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -151,6 +151,8 @@
     "quantized_fully_connected.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
     "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
 )
+lib.define("where_Scalar(Tensor condition, float self, float other) -> (Tensor Z)")
+lib.define("where_Scalar.out(Tensor condition, float self, float other, *, Tensor(a!) out) -> Tensor(a!)")
 
 # ------------------------------------ #
 #   Migrated from custom_ops.yaml      #
@@ -898,3 +900,11 @@ def transposed_im2row_meta(
     output_size = torch.Size((batch_size, output_length, n_output_plane))
 
     return input.new_empty(output_size, dtype=input.dtype)
+
+
+@register_fake("cadence::where_Scalar")
+def where_Scalar_meta(
+    condition: torch.Tensor,
+    self: float,
+    other: float,
+) -> torch.Tensor:
+    return condition.new_empty(condition.size(), dtype=torch.float32)
```
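
The fake kernel above only pins down the output's shape (same as `condition`) and dtype (always float32). As an illustration, the eager semantics one would expect from the op can be sketched as follows; the helper name `where_scalar_reference` is ours, not the commit's:

```python
import torch

# Hedged reference for the eager semantics implied by the schema and the
# fake kernel: output has the condition's shape and a fixed float32 dtype.
def where_scalar_reference(
    condition: torch.Tensor, self_val: float, other_val: float
) -> torch.Tensor:
    return torch.where(
        condition,
        torch.tensor(self_val, dtype=torch.float32),
        torch.tensor(other_val, dtype=torch.float32),
    )

out = where_scalar_reference(torch.randn(4, 8) > 0, 0.0, 1.0)
assert out.shape == (4, 8) and out.dtype == torch.float32
```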

backends/cadence/aot/replace_ops.py

Lines changed: 44 additions & 0 deletions
```diff
@@ -2051,6 +2051,49 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         return PassResult(ret.graph_module, modified)
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=1))
+class ReplaceWhereWithFullArgsWithWhereScalar(ExportPass):
+    """Replaces where ops whose two tensor arguments are full ops with the
+    scalar version, cadence.where_Scalar.
+    """
+
+    def call_operator(
+        self,
+        op,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
+        if op not in {
+            exir_ops.edge.aten.where.self,
+        }:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # If the args are not full ops, bail
+        # pyre-ignore[16]: `ProxyValue` has no attribute `node`.
+        if (args[1].node.target != exir_ops.edge.aten.full.default) or (
+            args[2].node.target != exir_ops.edge.aten.full.default
+        ):
+            return super().call_operator(op, args, kwargs, meta)
+
+        # If either full op has a different size than the cond tensor, the op
+        # would need to broadcast. Bail.
+        # pyre-ignore[16]: `ProxyValue` has no attribute `node`.
+        if list(args[0].to_tensor().shape) != args[1].node.args[0] or list(
+            args[0].to_tensor().shape
+        ) != args[2].node.args[0]:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Get the scalar values from the full ops
+        scalar_value_1 = args[1].node.args[1]
+        scalar_value_2 = args[2].node.args[1]
+
+        # Replace the where op with a scalar where op
+        return super().call_operator(
+            exir_ops.edge.cadence.where_Scalar.default,
+            (args[0], scalar_value_1, scalar_value_2),
+            kwargs,
+            meta,
+        )
+
+
 # This class encapsulates all the functions that replace/switch one op in the
 # graph with another.
 class CadenceReplaceOpsInGraph:
@@ -2089,4 +2132,5 @@ class CadenceReplaceOpsInGraph:
         ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
         ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
         ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
+        ReplaceWhereWithFullArgsWithWhereScalar,
     ]
```
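
The rewrite hinges on two graph-level checks: both value operands must be `aten.full.default` nodes, and their `size` argument (`node.args[0]`) must match the condition's shape, which the pass reads via `args[0].to_tensor().shape`. Below is a standalone sketch of that matching step using `torch.export`, independent of the `ExportPass` machinery above (ours, and simplified: it compares the two `full` sizes to each other rather than to the condition):

```python
import torch

class M(torch.nn.Module):
    def forward(self, cond: torch.Tensor):
        a = torch.full((4, 8), 0.0)
        b = torch.full((4, 8), 1.0)
        return torch.where(cond, a, b)

# torch.export keeps aten.full.default as graph nodes (it is not constant
# folded), which is exactly why the pass can pattern-match on them.
ep = torch.export.export(M(), (torch.randn(4, 8) > 0,))
full_op = torch.ops.aten.full.default
for node in ep.graph_module.graph.nodes:
    if node.op != "call_function" or node.target != torch.ops.aten.where.self:
        continue
    _cond, a_node, b_node = node.args
    # args of aten.full.default are (size, fill_value); bail unless both
    # operands are full ops of a matching size.
    if (
        getattr(a_node, "target", None) == full_op
        and getattr(b_node, "target", None) == full_op
        and a_node.args[0] == b_node.args[0]
    ):
        print("fold to where_Scalar with scalars:", a_node.args[1], b_node.args[1])
```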

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 80 additions & 0 deletions
```diff
@@ -44,6 +44,7 @@
     ReplaceTCopyWithTransposePass,
     ReplaceTransposedConvWithLinearPass,
     ReplaceTrivialConvWithLinear,
+    ReplaceWhereWithFullArgsWithWhereScalar,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass
@@ -1217,6 +1218,85 @@ def forward(self, x: torch.Tensor):
             1,
         )
 
+    def test_replace_aten_where_with_cadence_where_Scalar(self):
+        cond_shape, a_shape, b_shape, val1, val2 = [(4, 8), (4, 8), (4, 8), 0.0, 1.0]
+
+        class WhereScalarModel(torch.nn.Module):
+            def forward(self, cond: torch.Tensor):
+                a = torch.ops.aten.full.default(a_shape, val1)
+                b = torch.ops.aten.full.default(b_shape, val2)
+                return torch.where(cond > 0, a, b)
+
+        cond = torch.randn(cond_shape)
+
+        graph_module = (
+            export_to_edge(WhereScalarModel(), (cond,)).exported_program().graph_module
+        )
+
+        p = ReplaceWhereWithFullArgsWithWhereScalar()
+        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+
+        # Assert that the aten.where op was replaced by a cadence.where_Scalar op
+        self.assertEqual(
+            count_node(
+                graph_after_passes,
+                exir_ops.edge.aten.where.self,
+            ),
+            0,
+        )
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.cadence.where_Scalar.default),
+            1,
+        )
+
+        class WhereBroadcastModel(torch.nn.Module):
+            def forward(self, cond: torch.Tensor):
+                a = torch.ops.aten.full.default(a_shape, val1)
+                b = torch.ops.aten.full.default(b_shape, val2)
+                return torch.where(cond > 0, a, b)
+
+        # a tensor bigger than cond and b
+        cond_shape, a_shape, b_shape, val1, val2 = [(8,), (4, 8), (8,), 0.0, 1.0]
+        cond = torch.randn(cond_shape)
+
+        graph_module = (
+            export_to_edge(WhereBroadcastModel(), (cond,))
+            .exported_program()
+            .graph_module
+        )
+
+        p = ReplaceWhereWithFullArgsWithWhereScalar()
+        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+
+        # Assert that the aten.where op is still in the graph, since
+        # where_Scalar does not support broadcast
+        self.assertEqual(
+            count_node(
+                graph_after_passes,
+                exir_ops.edge.aten.where.self,
+            ),
+            1,
+        )
+
+        # cond tensor bigger than a and b
+        cond_shape, a_shape, b_shape, val1, val2 = [(4, 8), (8,), (8,), 0.0, 1.0]
+        cond = torch.randn(cond_shape)
+
+        graph_module = (
+            export_to_edge(WhereBroadcastModel(), (cond,))
+            .exported_program()
+            .graph_module
+        )
+
+        p = ReplaceWhereWithFullArgsWithWhereScalar()
+        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+
+        # Assert that the aten.where op is still in the graph, since
+        # where_Scalar does not support broadcast
+        self.assertEqual(
+            count_node(
+                graph_after_passes,
+                exir_ops.edge.aten.where.self,
+            ),
+            1,
+        )
+
 
 class TestReplaceIm2rowWithViewPass(unittest.TestCase):
     def test_no_replacement_for_conv(self):
```
