Skip to content

Commit 564b6e0

Browse files
Arm backend: Move QuantArgs to its own file in _passes/ (#13124)
The is part of an effort to reduce the number of functions/classes in utility files. This class has no reason being in a utility file when it has a clear purpose on its own. Signed-off-by: Sebastian Larsson <[email protected]> Co-authored-by: Oscar Andersson <[email protected]>
1 parent 8dfcb21 commit 564b6e0

7 files changed

+132
-126
lines changed

backends/arm/_passes/decompose_grouped_conv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from copy import copy
77

88
import torch
9-
from executorch.backends.arm.tosa_quant_utils import QuantArgs
9+
from executorch.backends.arm._passes.quant_args import QuantArgs
1010
from executorch.exir.dialects._ops import ops as exir_ops
1111
from executorch.exir.pass_base import ExportPass
1212

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
get_param_tensor,
1616
is_param_node,
1717
)
18-
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
1918

20-
from executorch.backends.arm.tosa_quant_utils import QuantArgs
19+
from executorch.backends.arm._passes.quant_args import QuantArgs
20+
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
2121

2222
from executorch.exir.dialects._ops import ops as exir_ops
2323
from executorch.exir.dialects.edge._ops import EdgeOpOverload

backends/arm/_passes/fuse_quantized_activation_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
# pyre-unsafe
77

88
import torch
9+
from executorch.backends.arm._passes.quant_args import QuantArgs
910
from executorch.backends.arm.constants import Q_OPS
10-
from executorch.backends.arm.tosa_quant_utils import QuantArgs
1111
from executorch.exir.dialects._ops import ops as exir_ops
1212
from executorch.exir.pass_base import ExportPass, PassResult
1313
from torch.fx import Node

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
import torch
1111
from executorch.backends.arm._passes.arm_pass_utils import create_node
12+
from executorch.backends.arm._passes.quant_args import QuantArgs
1213
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
13-
from executorch.backends.arm.tosa_quant_utils import QuantArgs
1414
from executorch.exir.pass_base import ExportPass, PassResult
1515
from torch import Tensor
1616
from torch.fx import GraphModule, Node

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import torch
1212
from executorch.backends.arm._passes.arm_pass_utils import create_node
13-
from executorch.backends.arm.tosa_quant_utils import QuantArgs
13+
from executorch.backends.arm._passes.quant_args import QuantArgs
1414
from executorch.exir import ExportedProgram
1515

1616
from executorch.exir.dialects._ops import ops as exir_ops

backends/arm/_passes/quant_args.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from typing import Any, cast, NamedTuple
6+
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
10+
exir_ops = cast(Any, exir_ops)
11+
from executorch.backends.arm.constants import PER_CHANNEL_QDQ_OPS, PER_TENSOR_QDQ_OPS
12+
from torch import Tensor
13+
14+
15+
class QuantArgs(NamedTuple):
16+
scale: list[float] | float
17+
zp: list[int] | int
18+
qmin: int
19+
qmax: int
20+
dtype: torch.dtype
21+
axis: int = 0
22+
per_channel: bool = False
23+
24+
def quantize_value(self, x: torch.Tensor | float) -> Tensor:
25+
"""Quantizes the input tensor or value to a quantized tensor. If the input is
26+
not a tensor, it is converted to a tensor first. If self.per_channel is True,
27+
the quantization is done per channel, otherwise it is done per tensor.
28+
"""
29+
if not isinstance(x, torch.Tensor):
30+
x = torch.Tensor([x])
31+
x = x.to(torch.float32)
32+
if self.per_channel:
33+
q_op = exir_ops.edge.quantized_decomposed.quantize_per_channel.default
34+
args = (
35+
x,
36+
torch.tensor(self.scale),
37+
torch.tensor(self.zp),
38+
self.axis,
39+
self.qmin,
40+
self.qmax,
41+
self.dtype,
42+
)
43+
else:
44+
q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
45+
args = (x, self.scale, self.zp, self.qmin, self.qmax, self.dtype) # type: ignore[assignment]
46+
return q_op(*args)
47+
48+
def dequantize_value(self, qx: torch.Tensor) -> torch.Tensor:
49+
"""Dequantizes the input tensor or value to a dequantized tensor If the input
50+
is not a tensor, it is converted to a tensor first. If self.per_channel is True,
51+
the dequantization is done per channel, otherwise it is done per tensor.
52+
"""
53+
if self.per_channel:
54+
dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_channel.default
55+
args = (
56+
qx,
57+
torch.tensor(self.scale),
58+
torch.tensor(self.zp),
59+
self.axis,
60+
self.qmin,
61+
self.qmax,
62+
self.dtype,
63+
)
64+
else:
65+
dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
66+
args = (qx, self.scale, self.zp, self.qmin, self.qmax, self.dtype) # type: ignore[assignment]
67+
return dq_op(*args)
68+
69+
@classmethod
70+
def from_operator(cls, op, args):
71+
if op in PER_TENSOR_QDQ_OPS:
72+
return cls(
73+
scale=cast(float, args[1]),
74+
zp=cast(int, args[2]),
75+
qmin=cast(int, args[3]),
76+
qmax=cast(int, args[4]),
77+
dtype=cast(torch.dtype, args[5]),
78+
axis=0,
79+
per_channel=False,
80+
)
81+
elif op in PER_CHANNEL_QDQ_OPS:
82+
return cls(
83+
scale=cast(list[float], args[1].tolist()),
84+
zp=cast(list[int], args[2].tolist()),
85+
axis=cast(int, args[3]),
86+
qmin=cast(int, args[4]),
87+
qmax=cast(int, args[5]),
88+
dtype=cast(torch.dtype, args[6]),
89+
per_channel=True,
90+
)
91+
else:
92+
# We're only handling per tensor and per channel quantization
93+
raise NotImplementedError(f"Unsupported quantization operation: {op}")
94+
95+
def get_scale_per_tensor(self) -> float:
96+
if not isinstance(self.scale, float):
97+
raise TypeError(
98+
f"Expected scale {self.scale} to be a float but found scale of "
99+
f"type {type(self.scale)}"
100+
)
101+
return self.scale
102+
103+
def get_zp_per_tensor(self) -> int:
104+
if not isinstance(self.zp, int):
105+
raise TypeError(
106+
f"Expected zero point {self.zp} to be an int but found zp of "
107+
f"type {type(self.zp)}"
108+
)
109+
return self.zp
110+
111+
def get_scale_per_channel(self) -> list[float]:
112+
if not isinstance(self.scale, list):
113+
raise TypeError(
114+
f"Expected scale {self.scale} to be a list but found scale of "
115+
f"type {type(self.scale)}"
116+
)
117+
return self.scale
118+
119+
def get_zp_per_channel(self) -> list[int]:
120+
if not isinstance(self.zp, list):
121+
raise TypeError(
122+
f"Expected zero point {self.zp} to be a list but found zp of "
123+
f"type {type(self.zp)}"
124+
)
125+
return self.zp

backends/arm/tosa_quant_utils.py

Lines changed: 1 addition & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,14 @@
99

1010
import math
1111

12-
from typing import Any, cast, NamedTuple, Tuple
12+
from typing import Any, Tuple
1313

1414
import executorch.backends.arm.tosa_specification as tosa_specification
1515

1616
import torch.fx
1717
import torch.fx.node
18-
from executorch.backends.arm.constants import PER_CHANNEL_QDQ_OPS, PER_TENSOR_QDQ_OPS
1918

2019
from executorch.backends.arm.tosa_mapping import TosaArg
21-
from executorch.exir.dialects._ops import ops as exir_ops
22-
from torch import Tensor
2320
from torch.fx import Node
2421
from tosa.RoundingMode import RoundingMode # type: ignore
2522

@@ -109,122 +106,6 @@ def insert_rescale_op_to_int8(
109106
)
110107

111108

112-
class QuantArgs(NamedTuple):
113-
scale: list[float] | float
114-
zp: list[int] | int
115-
qmin: int
116-
qmax: int
117-
dtype: torch.dtype
118-
axis: int = 0
119-
per_channel: bool = False
120-
121-
def quantize_value(self, x: torch.Tensor | float) -> Tensor:
122-
"""Quantizes the input tensor or value to a quantized tensor. If the input is
123-
not a tensor, it is converted to a tensor first. If self.per_channel is True,
124-
the quantization is done per channel, otherwise it is done per tensor.
125-
"""
126-
if not isinstance(x, torch.Tensor):
127-
x = torch.Tensor([x])
128-
x = x.to(torch.float32)
129-
if self.per_channel:
130-
q_op = exir_ops.edge.quantized_decomposed.quantize_per_channel.default
131-
args = (
132-
x,
133-
torch.tensor(self.scale),
134-
torch.tensor(self.zp),
135-
self.axis,
136-
self.qmin,
137-
self.qmax,
138-
self.dtype,
139-
)
140-
else:
141-
q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
142-
args = (x, self.scale, self.zp, self.qmin, self.qmax, self.dtype) # type: ignore[assignment]
143-
144-
return q_op(*args)
145-
146-
def dequantize_value(self, qx: torch.Tensor) -> torch.Tensor:
147-
"""Dequantizes the input tensor or value to a dequantized tensor If the input
148-
is not a tensor, it is converted to a tensor first. If self.per_channel is True,
149-
the dequantization is done per channel, otherwise it is done per tensor.
150-
"""
151-
if self.per_channel:
152-
dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_channel.default
153-
args = (
154-
qx,
155-
torch.tensor(self.scale),
156-
torch.tensor(self.zp),
157-
self.axis,
158-
self.qmin,
159-
self.qmax,
160-
self.dtype,
161-
)
162-
else:
163-
dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
164-
args = (qx, self.scale, self.zp, self.qmin, self.qmax, self.dtype) # type: ignore[assignment]
165-
166-
return dq_op(*args)
167-
168-
@classmethod
169-
def from_operator(cls, op, args):
170-
if op in PER_TENSOR_QDQ_OPS:
171-
return cls(
172-
scale=cast(float, args[1]),
173-
zp=cast(int, args[2]),
174-
qmin=cast(int, args[3]),
175-
qmax=cast(int, args[4]),
176-
dtype=cast(torch.dtype, args[5]),
177-
axis=0,
178-
per_channel=False,
179-
)
180-
elif op in PER_CHANNEL_QDQ_OPS:
181-
return cls(
182-
scale=cast(list[float], args[1].tolist()),
183-
zp=cast(list[int], args[2].tolist()),
184-
axis=cast(int, args[3]),
185-
qmin=cast(int, args[4]),
186-
qmax=cast(int, args[5]),
187-
dtype=cast(torch.dtype, args[6]),
188-
per_channel=True,
189-
)
190-
191-
else:
192-
# We're only handling per tensor and per channel quantization
193-
raise NotImplementedError(f"Unsupported quantization operation: {op}")
194-
195-
def get_scale_per_tensor(self) -> float:
196-
if not isinstance(self.scale, float):
197-
raise TypeError(
198-
f"Expected scale {self.scale} to be a float but found scale of "
199-
f"type {type(self.scale)}"
200-
)
201-
return self.scale
202-
203-
def get_zp_per_tensor(self) -> int:
204-
if not isinstance(self.zp, int):
205-
raise TypeError(
206-
f"Expected zero point {self.zp} to be an int but found zp of "
207-
f"type {type(self.zp)}"
208-
)
209-
return self.zp
210-
211-
def get_scale_per_channel(self) -> list[float]:
212-
if not isinstance(self.scale, list):
213-
raise TypeError(
214-
f"Expected scale {self.scale} to be a list but found scale of "
215-
f"type {type(self.scale)}"
216-
)
217-
return self.scale
218-
219-
def get_zp_per_channel(self) -> list[int]:
220-
if not isinstance(self.zp, list):
221-
raise TypeError(
222-
f"Expected zero point {self.zp} to be a list but found zp of "
223-
f"type {type(self.zp)}"
224-
)
225-
return self.zp
226-
227-
228109
# TOSA uses the RESCALE operation to scale between values with differing precision.
229110
# The RESCALE operator is defined using an integer multiply, add, and shift.
230111
# This utility function is for calculating the multier and shift given a scale.

0 commit comments

Comments
 (0)