Skip to content

Commit cc3a35f

Browse files
authored
unbreak lintrunner in backends/arm (#12075)
1 parent 3e19e67 commit cc3a35f

File tree

3 files changed

+15
-5
lines changed

3 files changed

+15
-5
lines changed

backends/arm/_passes/decompose_sqrt_pass.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7+
from typing import Tuple, Union
8+
79
import torch
810
from executorch.exir.dialects._ops import ops as exir_ops
911
from executorch.exir.pass_base import ExportPass
@@ -15,7 +17,7 @@
1517
)
1618

1719

18-
def get_sqrt_decomposition(op) -> tuple:
20+
def get_sqrt_decomposition(op) -> Union[Tuple, torch._ops.OpOverload]:
1921
# TODO : "MLETORCH-863 : Replace current sqrt -> pow.Tensor_Scalar workaround with pow.Tensor_Tensor"
2022
if op in edge_sqrt_ops:
2123
return exir_ops.edge.aten.pow.Tensor_Scalar

backends/arm/_passes/replace_scalar_with_tensor_pass.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# pyre-unsafe
77

88

9-
from typing import Dict
9+
from typing import Dict, Union
1010

1111
import torch
1212
from executorch.backends.transforms.replace_scalar_with_tensor import (
@@ -18,7 +18,10 @@
1818

1919

2020
# Operators that are included for both TOSA profiles
21-
_common_ops: Dict[EdgeOpOverload, EdgeOpOverload] = {
21+
_common_ops: Dict[
22+
Union[EdgeOpOverload, torch._ops.OpOverload],
23+
Union[EdgeOpOverload, torch._ops.OpOverload],
24+
] = {
2225
exir_ops.edge.aten.add.Scalar: exir_ops.edge.aten.add.Tensor,
2326
exir_ops.edge.aten.sub.Scalar: exir_ops.edge.aten.sub.Tensor,
2427
exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor,

backends/transforms/replace_scalar_with_tensor.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
from typing import Dict, Optional
8+
from typing import Dict, Optional, Union
99

1010
import torch
1111
from executorch.exir.dialects._ops import ops as exir_ops
@@ -32,7 +32,12 @@ class ReplaceScalarWithTensorArgPass(ExportPass):
3232

3333
def __init__(
3434
self,
35-
scalar_to_tensor_ops: Optional[Dict[EdgeOpOverload, EdgeOpOverload]] = None,
35+
scalar_to_tensor_ops: Optional[
36+
Dict[
37+
Union[EdgeOpOverload, torch._ops.OpOverload],
38+
Union[EdgeOpOverload, torch._ops.OpOverload],
39+
]
40+
] = None,
3641
):
3742
if scalar_to_tensor_ops is not None:
3843
self.scalar_to_tensor_ops = scalar_to_tensor_ops

0 commit comments

Comments
 (0)