
Commit d2d57cf

mcremon-meta authored and facebook-github-bot committed
Add cadence.where_Scalar op (#9764)
Summary: Pull Request resolved: #9764. The op replaces the regular where op when it uses two `aten.full` ops for the tensors. Those cases do not need broadcast (but would call it in the current state) and can sometimes be constant folded, if the `condition` is a constant tensor. Since `aten.full` is _not_ currently constant folded, it would stay in the graph. Reviewed By: skrtskrtfb, zonglinpeng. Differential Revision: D70539497
1 parent 97bca05 commit d2d57cf
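
As a rough illustration (not part of this commit) of the pattern the new pass targets, consider a module whose where takes two aten.full branches shaped like the condition; after the pass, the where and both full nodes should collapse into a single cadence.where_Scalar(cond > 0, 0.0, 1.0) call:

import torch

class WhereScalarPattern(torch.nn.Module):
    # Illustrative module: both branches of the where are aten.full tensors
    # with the condition's shape, so no broadcast is needed and the whole
    # pattern can be rewritten to cadence.where_Scalar(cond > 0, 0.0, 1.0).
    def forward(self, cond: torch.Tensor) -> torch.Tensor:
        a = torch.ops.aten.full.default((4, 8), 0.0)
        b = torch.ops.aten.full.default((4, 8), 1.0)
        return torch.where(cond > 0, a, b)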

File tree

3 files changed: +146, -0 lines


backends/cadence/aot/ops_registrations.py

Lines changed: 13 additions & 0 deletions
@@ -151,6 +151,10 @@
     "quantized_fully_connected.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
     "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
 )
+lib.define("where_Scalar(Tensor condition, float self, float other) -> (Tensor Z)")
+lib.define(
+    "where_Scalar.out(Tensor condition, float self, float other, *, Tensor(a!) out) -> Tensor(a!)"
+)
 
 # ------------------------------------ #
 # Migrated from custom_ops.yaml #
@@ -898,3 +902,12 @@ def transposed_im2row_meta(
     output_size = torch.Size((batch_size, output_length, n_output_plane))
 
     return input.new_empty(output_size, dtype=input.dtype)
+
+
+@register_fake("cadence::where_Scalar")
+def where_Scalar_meta(
+    condition: torch.Tensor,
+    self: float,
+    other: float,
+) -> torch.Tensor:
+    return condition.new_empty(condition.size(), dtype=torch.float32)
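
A minimal sketch of what the fake kernel above implies during tracing, assuming executorch.backends.cadence.aot.ops_registrations has been imported so the cadence library and this fake kernel are registered:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Under FakeTensorMode the registered fake kernel runs instead of a real
# kernel, so the call only propagates metadata: an empty float32 tensor
# with the condition's shape.
with FakeTensorMode():
    cond = torch.empty(4, 8, dtype=torch.bool)
    out = torch.ops.cadence.where_Scalar(cond, 0.0, 1.0)
    print(out.shape, out.dtype)  # torch.Size([4, 8]) torch.float32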

backends/cadence/aot/replace_ops.py

Lines changed: 49 additions & 0 deletions
@@ -2051,6 +2051,54 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         return PassResult(ret.graph_module, modified)
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=1))
+class ReplaceWhereWithFullArgsWithWhereScalar(ExportPass):
+    """Replaces where ops using two full ops as tensors with a scalar
+    version.
+    """
+
+    def call_operator(
+        self,
+        op,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
+        if op not in {
+            exir_ops.edge.aten.where.self,
+        }:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # If the args are not full ops, bail
+        # pyre-ignore[16]: `ProxyValue` has no attribute `node`.
+        if (args[1].node.target != exir_ops.edge.aten.full.default) or (
+            args[2].node.target != exir_ops.edge.aten.full.default
+        ):
+            return super().call_operator(op, args, kwargs, meta)
+
+        # If one of the full ops is a different size than the cond tensor, we need to broadcast. Bail.
+        if (
+            # pyre-ignore[16]: `ProxyValue` has no attribute `node`.
+            list(args[0].to_tensor().shape) != args[1].node.args[0]
+            or list(args[0].to_tensor().shape) != args[2].node.args[0]
+        ):
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Get the scalar values from the full ops
+        scalar_value_1 = args[1].node.args[1]
+        scalar_value_2 = args[2].node.args[1]
+
+        # Replace the where op with a scalar where op
+        return super().call_operator(
+            exir_ops.edge.cadence.where_Scalar.default,
+            (args[0], scalar_value_1, scalar_value_2),
+            kwargs,
+            meta,
+        )
+
+        return super().call_operator(op, args, kwargs, meta)
+
+
 # This class encapsulates all the functions that replace/switch one op in the
 # graph with another.
 class CadenceReplaceOpsInGraph:
@@ -2089,4 +2137,5 @@ class CadenceReplaceOpsInGraph:
         ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
         ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
         ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
+        ReplaceWhereWithFullArgsWithWhereScalar,
     ]
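
For reference, the semantics the pass relies on can be sketched in plain PyTorch (the helper below is hypothetical, not the Cadence kernel): cadence.where_Scalar selects between two Python scalars elementwise, which is exactly what aten.where computes when both branches are full tensors of the condition's shape, so no broadcasting is involved.

import torch

def where_scalar_reference(condition: torch.Tensor, self_val: float, other_val: float) -> torch.Tensor:
    # Elementwise select: self_val where condition is True, other_val elsewhere.
    return torch.where(
        condition,
        torch.full(condition.shape, self_val),
        torch.full(condition.shape, other_val),
    )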

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 84 additions & 0 deletions
@@ -44,6 +44,7 @@
     ReplaceTCopyWithTransposePass,
     ReplaceTransposedConvWithLinearPass,
     ReplaceTrivialConvWithLinear,
+    ReplaceWhereWithFullArgsWithWhereScalar,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass
@@ -1217,6 +1218,89 @@ def forward(self, x: torch.Tensor):
             1,
         )
 
+    def test_replace_aten_where_with_cadence_where_Scalar(self):
+        class WhereScalarModel(torch.nn.Module):
+            def forward(self, cond: torch.Tensor):
+                a = torch.ops.aten.full.default(a_shape, val1)
+                b = torch.ops.aten.full.default(b_shape, val2)
+                return torch.where(cond > 0, a, b)
+
+        cond_shape, a_shape, b_shape, val1, val2 = [(4, 8), (4, 8), (4, 8), 0.0, 1.0]
+        cond = torch.randn(cond_shape)
+
+        graph_module = (
+            export_to_edge(WhereScalarModel(), (cond,)).exported_program().graph_module
+        )
+
+        p = ReplaceWhereWithFullArgsWithWhereScalar()
+        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+
+        # Assert that aten.where op was replaced by a
+        # cadence.where_Scalar op
+        self.assertEqual(
+            count_node(
+                graph_after_passes,
+                exir_ops.edge.aten.where.self,
+            ),
+            0,
+        )
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.cadence.where_Scalar.default),
+            1,
+        )
+
+        class WhereBroadcastModel(torch.nn.Module):
+            def forward(self, cond: torch.Tensor):
+                a = torch.ops.aten.full.default(a_shape, val1)
+                b = torch.ops.aten.full.default(b_shape, val2)
+                return torch.where(cond > 0, a, b)
+
+        # a tensor bigger than cond and b
+        cond_shape, a_shape, b_shape, val1, val2 = [(8,), (4, 8), (8,), 0.0, 1.0]
+        cond = torch.randn(cond_shape)
+
+        graph_module = (
+            export_to_edge(WhereBroadcastModel(), (cond,))
+            .exported_program()
+            .graph_module
+        )
+
+        p = ReplaceWhereWithFullArgsWithWhereScalar()
+        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+
+        # Assert that aten.where op is still in the graph since where_Scalar does not
+        # support broadcast
+        self.assertEqual(
+            count_node(
+                graph_after_passes,
+                exir_ops.edge.aten.where.self,
+            ),
+            1,
+        )
+
+        # cond tensor bigger than a and b
+        cond_shape, a_shape, b_shape, val1, val2 = [(4, 8), (8,), (8,), 0.0, 1.0]
+        cond = torch.randn(cond_shape)
+
+        graph_module = (
+            export_to_edge(WhereBroadcastModel(), (cond,))
+            .exported_program()
+            .graph_module
+        )
+
+        p = ReplaceWhereWithFullArgsWithWhereScalar()
+        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+
+        # Assert that aten.where op is still in the graph since where_Scalar does not
+        # support broadcast
+        self.assertEqual(
+            count_node(
+                graph_after_passes,
+                exir_ops.edge.aten.where.self,
+            ),
+            1,
+        )
+
 
 class TestReplaceIm2rowWithViewPass(unittest.TestCase):
     def test_no_replacement_for_conv(self):
