Skip to content

Commit a39fbd5

Browse files
committed
Address merge conflicts and typing issues
Signed-off-by: Tom Allsop <[email protected]> Change-Id: If9f6bbb9d23f00b9cea6d00c6aa97a6d5d3e77ab
1 parent df441f5 commit a39fbd5

File tree

2 files changed

+19
-15
lines changed

2 files changed

+19
-15
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
8989

9090
self.add_pass(AnnotateDecomposedMatmulPass())
9191
self.add_pass(QuantizeOperatorArguments())
92-
self.add_pass(FoldAndAnnotateQParamsPass())
92+
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
9393
self.add_pass(RetraceFoldedDtypesPass())
9494
self.add_pass(InsertTableOpsPass(exported_program))
9595

@@ -125,7 +125,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
125125

126126
self.add_pass(AnnotateDecomposedMatmulPass())
127127
self.add_pass(QuantizeOperatorArguments())
128-
self.add_pass(FoldAndAnnotateQParamsPass())
128+
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
129129
self.add_pass(RetraceFoldedDtypesPass())
130130
self.add_pass(InsertTableOpsPass(exported_program))
131131

backends/arm/operators/op_clamp.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree
66

7-
from numbers import Number
8-
from typing import List, Tuple
7+
from typing import Any, List, Tuple
98

10-
import serializer.tosa_serializer as ts
9+
import serializer.tosa_serializer as ts # type: ignore
1110

1211
import torch
1312
from executorch.backends.arm.operators.node_visitor import (
@@ -41,7 +40,7 @@ def _create_clamp_node(
4140
max_int: int,
4241
min_fp32: float,
4342
max_fp32: float,
44-
):
43+
) -> None:
4544
attr = ts.TosaSerializerAttribute()
4645
attr.ClampAttribute(
4746
tosa_graph.builder,
@@ -53,22 +52,27 @@ def _create_clamp_node(
5352
tosa_graph.addOperator(TosaOp.Op().CLAMP, [input_name], [output_name], attr)
5453

5554
def _get_min_max_arguments(
56-
self,
57-
node: Node,
58-
dtype_min: Number,
59-
dtype_max: Number,
60-
) -> Tuple[Number, Number]:
55+
self, node: Node, dtype_min: int | float, dtype_max: int | float
56+
) -> Tuple[int | float, int | float]:
57+
58+
def cast_type(value: Any) -> int | float:
59+
if isinstance(value, int):
60+
return value
61+
else:
62+
# Attempt to cast to float
63+
return float(value)
64+
6165
assert 2 <= len(node.args) <= 3
6266

6367
min_arg = dtype_min
6468
max_arg = dtype_max
6569

6670
if node.args[1] is not None:
67-
min_arg = node.args[1]
71+
min_arg = cast_type(node.args[1])
6872

6973
if len(node.args) > 2:
7074
if node.args[2] is not None:
71-
max_arg = node.args[2]
75+
max_arg = cast_type(node.args[2])
7276

7377
return min_arg, max_arg
7478

@@ -92,8 +96,8 @@ def define_node(
9296
tosa_graph,
9397
inputs[0].name,
9498
output.name,
95-
min_int8,
96-
max_int8,
99+
int(min_int8),
100+
int(max_int8),
97101
0,
98102
0,
99103
)

0 commit comments

Comments
 (0)