Skip to content

Commit 481a827

Browse files
committed
Arm backend: Replace Tensor_Scalar with Tensor_Tensor in sqrt op
Signed-off-by: Elena Zhelezina <[email protected]> Change-Id: Ia3e596f855ec97b0ad59161bccc906b13e96c770
1 parent 7565342 commit 481a827

File tree

7 files changed

+77
-25
lines changed

7 files changed

+77
-25
lines changed

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
6161
"""
6262
if node.op == "placeholder":
6363
# node is an input, weight or bias node
64+
if not node.users:
65+
return False
6466
consumer_node = list(node.users)[0]
6567
if self.is_weight_node_for_depthwise_conv2d(consumer_node):
6668
return True

backends/arm/_passes/decompose_sqrt_pass.py

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,36 +4,62 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7+
8+
from typing import Any, Dict, Tuple
9+
710
import torch
11+
812
from executorch.exir.dialects._ops import ops as exir_ops
913
from executorch.exir.pass_base import ExportPass
1014

11-
edge_sqrt_ops = (exir_ops.edge.aten.sqrt.default,)
12-
aten_sqrt_ops = (
13-
torch.ops.aten.sqrt.default,
14-
torch.ops.aten.sqrt_.default,
15-
)
1615

16+
class DecomposeSqrtPass(ExportPass):
17+
def __init__(self) -> None:
18+
super().__init__()
19+
20+
# We cache constant tensor for the exponent
21+
self._half_cache: Dict[Tuple[Any, Any], Any] = {}
22+
self.SQRT_TO_POW = {
23+
exir_ops.edge.aten.sqrt.default: exir_ops.edge.aten.pow.Tensor_Tensor,
24+
torch.ops.aten.sqrt.default: torch.ops.aten.pow.Tensor_Tensor,
25+
torch.ops.aten.sqrt_.default: torch.ops.aten.pow.Tensor_Tensor,
26+
}
1727

18-
def get_sqrt_decomposition(op) -> tuple:
19-
# TODO : "MLETORCH-863 : Replace current sqrt -> pow.Tensor_Scalar workaround with pow.Tensor_Tensor"
20-
if op in edge_sqrt_ops:
21-
return exir_ops.edge.aten.pow.Tensor_Scalar
22-
if op in aten_sqrt_ops:
23-
return torch.ops.aten.pow.Tensor_Scalar
24-
raise RuntimeError(f"Can't get sqrt decomposition for op {op}")
28+
def _get_half_tensor(
29+
self,
30+
dtype: Any,
31+
device: Any,
32+
meta: Any,
33+
) -> Any:
34+
# Choose a floating dtype for 0.5
35+
if dtype in (torch.float16, torch.float32, torch.float64):
36+
half_dtype = dtype
37+
else:
38+
half_dtype = torch.float32
2539

40+
key = (half_dtype, device)
41+
if key not in self._half_cache:
42+
half = super().call_operator(
43+
exir_ops.edge.aten.full.default,
44+
([], 0.5),
45+
{"dtype": half_dtype, "device": device},
46+
meta,
47+
)
48+
self._half_cache[key] = half
2649

27-
class DecomposeSqrtPass(ExportPass):
50+
return self._half_cache[key]
2851

29-
def call_operator(self, op, args, kwargs, meta):
30-
"""
31-
Decomposes `sqrt(x)` into `pow(x, 0.5)` for backend support.
32-
"""
52+
def call_operator(self, op: Any, args: tuple, kwargs: dict, meta: Any) -> Any:
3353

34-
if op not in (edge_sqrt_ops + aten_sqrt_ops):
54+
if op not in self.SQRT_TO_POW:
3555
return super().call_operator(op, args, kwargs, meta)
3656

37-
pow_op = get_sqrt_decomposition(op)
57+
if len(args) != 1:
58+
raise ValueError(f"Expected 1 arg to sqrt, got {len(args)}")
59+
60+
x = args[0]
61+
pow_op = self.SQRT_TO_POW[op]
62+
63+
half = self._get_half_tensor(x.data.dtype, x.data.device, meta)
3864

39-
return super().call_operator(pow_op, (args[0], 0.5), {}, meta)
65+
return super().call_operator(pow_op, (x, half), {}, meta)

backends/arm/_passes/insert_table_ops.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class TableOps:
5757

5858
# Targets that must be treated explicitly
5959
special_table_ops: Set[EdgeOpOverload] = {
60+
exir_ops.edge.aten.pow.Tensor_Tensor,
6061
exir_ops.edge.aten.pow.Tensor_Scalar,
6162
exir_ops.edge.aten.gelu.default,
6263
}
@@ -75,6 +76,13 @@ def __getitem__(self, node: Node):
7576
return self.unary_table_ops[target]
7677
elif target in self.special_table_ops:
7778
match target:
79+
case exir_ops.edge.aten.pow.Tensor_Tensor:
80+
# Exponent is a constant. Embed it into a lambda.
81+
exp_node = node.args[1]
82+
exp = float(
83+
self.exported_program.state_dict[exp_node.name].item() # type: ignore[union-attr]
84+
)
85+
return lambda x: torch.pow(x, exp).flatten()
7886
case exir_ops.edge.aten.pow.Tensor_Scalar:
7987
# Exponent is a constant. Embed it into a lambda.
8088
exp = cast(int, node.args[1])
@@ -283,8 +291,16 @@ def call(self, graph_module: GraphModule) -> PassResult:
283291
modified = True
284292

285293
if modified:
294+
graph_module.graph.eliminate_dead_code()
295+
296+
# Remove any placeholder with zero users
297+
for ph in list(graph_module.graph.nodes):
298+
if ph.op == "placeholder" and len(ph.users) == 0:
299+
graph_module.graph.erase_node(ph)
300+
self.exported_program.state_dict.pop(ph.name, None)
301+
286302
# retrace the graph to update the fake tensor types
287303
graph_module = super().call(graph_module).graph_module
288-
289304
graph_module.recompile()
305+
290306
return PassResult(graph_module, modified)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def is_node_supported(
226226
exir_ops.edge.aten.squeeze_copy.dims,
227227
exir_ops.edge.aten.pow.Tensor_Scalar,
228228
exir_ops.edge.aten.pow.Tensor_Tensor,
229+
torch.ops.aten.pow.Tensor_Tensor,
229230
exir_ops.edge.aten.where.self,
230231
operator.getitem,
231232
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
@@ -306,8 +307,6 @@ class CheckProperQuantization(OperatorSupportBase):
306307
exir_ops.edge.aten.avg_pool2d.default,
307308
exir_ops.edge.aten.bmm.default,
308309
exir_ops.edge.aten.convolution.default,
309-
exir_ops.edge.aten.full.default,
310-
exir_ops.edge.aten.full_like.default,
311310
exir_ops.edge.aten.hardtanh.default,
312311
exir_ops.edge.aten.linear.default,
313312
exir_ops.edge.aten.max_pool2d_with_indices.default,
@@ -410,6 +409,7 @@ def is_node_supported(
410409

411410
input_quantized = input_quantized or all(
412411
(input_node.target in dq_ops)
412+
or (node.name == "aten_pow_tensor_tensor")
413413
or (not get_first_fake_tensor(input_node).dtype.is_floating_point)
414414
for input_node in node.all_input_nodes
415415
)

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def _match_pattern(
212212
torch.ops.aten.hardswish_.default,
213213
torch.ops.aten.full_like.default,
214214
torch.ops.aten.pow.Tensor_Scalar,
215+
torch.ops.aten.pow.Tensor_Tensor,
215216
torch.ops.aten.gelu.default,
216217
]
217218

backends/arm/test/ops/test_sqrt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ class Sqrt(torch.nn.Module):
2121
aten_op_MI = "torch.ops.aten.sqrt.default"
2222
exir_op_MI = "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Tensor"
2323

24-
aten_op_BI = "torch.ops.aten.pow.Tensor_Scalar"
25-
exir_op_BI = "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar"
24+
aten_op_BI = "torch.ops.aten.pow.Tensor_Tensor"
25+
exir_op_BI = "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Tensor"
2626

2727
def __init__(self):
2828
super().__init__()

backends/arm/tosa_partitioner.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,13 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
118118

119119
if is_partitioned(node):
120120
for input in node.all_input_nodes:
121+
if input.target in (
122+
exir_ops.edge.aten.full.default,
123+
exir_ops.edge.aten.full_like.default,
124+
):
125+
continue
126+
if is_dequant_node(input):
127+
continue
121128
if is_partitioned(input):
122129
continue
123130
if get_first_fake_tensor(input).dtype.is_floating_point:

0 commit comments

Comments (0)