Skip to content

Commit eaad1c2

Browse files
lucylqNinja91
andauthored
Arm backend: Add INT16 support to rescale operation (#14301)
Differential Revision: D80513725 Pull Request resolved: #13802 #13802 (comment) failed to cp to main Co-authored-by: Nitin Jain <[email protected]>
1 parent 7edb278 commit eaad1c2

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

backends/arm/operators/op_rescale.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,20 @@ def define_node(
4646
input_zp = cast(int, node.args[3])
4747
output_zp = cast(int, node.args[4])
4848

49-
if input_dtype != map_dtype(torch.int8, self.tosa_spec) and input_zp != 0:
49+
if (
50+
input_dtype
51+
not in [
52+
map_dtype(torch.int8, self.tosa_spec),
53+
map_dtype(torch.int16, self.tosa_spec),
54+
]
55+
and input_zp != 0
56+
):
5057
raise ValueError(
51-
f"If input dtype is not int8, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}"
58+
f"If input dtype is not int8 or int16, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}"
5259
)
53-
if output_dtype != torch.int8 and output_zp != 0:
60+
if output_dtype not in [torch.int8, torch.int16] and output_zp != 0:
5461
raise ValueError(
55-
f"If output dtype is not int8, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}"
62+
f"If output dtype is not int8 or int16, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}"
5663
)
5764

5865
build_rescale(

0 commit comments

Comments
 (0)