Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/arm/_passes/decompose_grouped_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from copy import copy

import torch
from executorch.backends.arm.tosa_quant_utils import QuantArgs
from executorch.backends.arm._passes.quant_args import QuantArgs
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
get_param_tensor,
is_param_node,
)
from executorch.backends.arm.constants import DQ_OPS, Q_OPS

from executorch.backends.arm.tosa_quant_utils import QuantArgs
from executorch.backends.arm._passes.quant_args import QuantArgs
from executorch.backends.arm.constants import DQ_OPS, Q_OPS

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/_passes/fuse_quantized_activation_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
# pyre-unsafe

import torch
from executorch.backends.arm._passes.quant_args import QuantArgs
from executorch.backends.arm.constants import Q_OPS
from executorch.backends.arm.tosa_quant_utils import QuantArgs
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import Node
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/_passes/insert_rescales_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm._passes.quant_args import QuantArgs
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
from executorch.backends.arm.tosa_quant_utils import QuantArgs
from executorch.exir.pass_base import ExportPass, PassResult
from torch import Tensor
from torch.fx import GraphModule, Node
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/_passes/insert_table_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm.tosa_quant_utils import QuantArgs
from executorch.backends.arm._passes.quant_args import QuantArgs
from executorch.exir import ExportedProgram

from executorch.exir.dialects._ops import ops as exir_ops
Expand Down
125 changes: 125 additions & 0 deletions backends/arm/_passes/quant_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, cast, NamedTuple

import torch
from executorch.exir.dialects._ops import ops as exir_ops

exir_ops = cast(Any, exir_ops)
from executorch.backends.arm.constants import PER_CHANNEL_QDQ_OPS, PER_TENSOR_QDQ_OPS
from torch import Tensor


class QuantArgs(NamedTuple):
scale: list[float] | float
zp: list[int] | int
qmin: int
qmax: int
dtype: torch.dtype
axis: int = 0
per_channel: bool = False

def quantize_value(self, x: torch.Tensor | float) -> Tensor:
"""Quantizes the input tensor or value to a quantized tensor. If the input is
not a tensor, it is converted to a tensor first. If self.per_channel is True,
the quantization is done per channel, otherwise it is done per tensor.
"""
if not isinstance(x, torch.Tensor):
x = torch.Tensor([x])
x = x.to(torch.float32)
if self.per_channel:
q_op = exir_ops.edge.quantized_decomposed.quantize_per_channel.default
args = (
x,
torch.tensor(self.scale),
torch.tensor(self.zp),
self.axis,
self.qmin,
self.qmax,
self.dtype,
)
else:
q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
args = (x, self.scale, self.zp, self.qmin, self.qmax, self.dtype) # type: ignore[assignment]
return q_op(*args)

def dequantize_value(self, qx: torch.Tensor) -> torch.Tensor:
"""Dequantizes the input tensor or value to a dequantized tensor If the input
is not a tensor, it is converted to a tensor first. If self.per_channel is True,
the dequantization is done per channel, otherwise it is done per tensor.
"""
if self.per_channel:
dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_channel.default
args = (
qx,
torch.tensor(self.scale),
torch.tensor(self.zp),
self.axis,
self.qmin,
self.qmax,
self.dtype,
)
else:
dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
args = (qx, self.scale, self.zp, self.qmin, self.qmax, self.dtype) # type: ignore[assignment]
return dq_op(*args)

@classmethod
def from_operator(cls, op, args):
if op in PER_TENSOR_QDQ_OPS:
return cls(
scale=cast(float, args[1]),
zp=cast(int, args[2]),
qmin=cast(int, args[3]),
qmax=cast(int, args[4]),
dtype=cast(torch.dtype, args[5]),
axis=0,
per_channel=False,
)
elif op in PER_CHANNEL_QDQ_OPS:
return cls(
scale=cast(list[float], args[1].tolist()),
zp=cast(list[int], args[2].tolist()),
axis=cast(int, args[3]),
qmin=cast(int, args[4]),
qmax=cast(int, args[5]),
dtype=cast(torch.dtype, args[6]),
per_channel=True,
)
else:
# We're only handling per tensor and per channel quantization
raise NotImplementedError(f"Unsupported quantization operation: {op}")

def get_scale_per_tensor(self) -> float:
if not isinstance(self.scale, float):
raise TypeError(
f"Expected scale {self.scale} to be a float but found scale of "
f"type {type(self.scale)}"
)
return self.scale

def get_zp_per_tensor(self) -> int:
if not isinstance(self.zp, int):
raise TypeError(
f"Expected zero point {self.zp} to be an int but found zp of "
f"type {type(self.zp)}"
)
return self.zp

def get_scale_per_channel(self) -> list[float]:
if not isinstance(self.scale, list):
raise TypeError(
f"Expected scale {self.scale} to be a list but found scale of "
f"type {type(self.scale)}"
)
return self.scale

def get_zp_per_channel(self) -> list[int]:
if not isinstance(self.zp, list):
raise TypeError(
f"Expected zero point {self.zp} to be a list but found zp of "
f"type {type(self.zp)}"
)
return self.zp
121 changes: 1 addition & 120 deletions backends/arm/tosa_quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,14 @@

import math

from typing import Any, cast, NamedTuple, Tuple
from typing import Any, Tuple

import executorch.backends.arm.tosa_specification as tosa_specification

import torch.fx
import torch.fx.node
from executorch.backends.arm.constants import PER_CHANNEL_QDQ_OPS, PER_TENSOR_QDQ_OPS

from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.exir.dialects._ops import ops as exir_ops
from torch import Tensor
from torch.fx import Node
from tosa.RoundingMode import RoundingMode # type: ignore

Expand Down Expand Up @@ -109,122 +106,6 @@ def insert_rescale_op_to_int8(
)


class QuantArgs(NamedTuple):
scale: list[float] | float
zp: list[int] | int
qmin: int
qmax: int
dtype: torch.dtype
axis: int = 0
per_channel: bool = False

def quantize_value(self, x: torch.Tensor | float) -> Tensor:
"""Quantizes the input tensor or value to a quantized tensor. If the input is
not a tensor, it is converted to a tensor first. If self.per_channel is True,
the quantization is done per channel, otherwise it is done per tensor.
"""
if not isinstance(x, torch.Tensor):
x = torch.Tensor([x])
x = x.to(torch.float32)
if self.per_channel:
q_op = exir_ops.edge.quantized_decomposed.quantize_per_channel.default
args = (
x,
torch.tensor(self.scale),
torch.tensor(self.zp),
self.axis,
self.qmin,
self.qmax,
self.dtype,
)
else:
q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
args = (x, self.scale, self.zp, self.qmin, self.qmax, self.dtype) # type: ignore[assignment]

return q_op(*args)

def dequantize_value(self, qx: torch.Tensor) -> torch.Tensor:
"""Dequantizes the input tensor or value to a dequantized tensor If the input
is not a tensor, it is converted to a tensor first. If self.per_channel is True,
the dequantization is done per channel, otherwise it is done per tensor.
"""
if self.per_channel:
dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_channel.default
args = (
qx,
torch.tensor(self.scale),
torch.tensor(self.zp),
self.axis,
self.qmin,
self.qmax,
self.dtype,
)
else:
dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
args = (qx, self.scale, self.zp, self.qmin, self.qmax, self.dtype) # type: ignore[assignment]

return dq_op(*args)

@classmethod
def from_operator(cls, op, args):
if op in PER_TENSOR_QDQ_OPS:
return cls(
scale=cast(float, args[1]),
zp=cast(int, args[2]),
qmin=cast(int, args[3]),
qmax=cast(int, args[4]),
dtype=cast(torch.dtype, args[5]),
axis=0,
per_channel=False,
)
elif op in PER_CHANNEL_QDQ_OPS:
return cls(
scale=cast(list[float], args[1].tolist()),
zp=cast(list[int], args[2].tolist()),
axis=cast(int, args[3]),
qmin=cast(int, args[4]),
qmax=cast(int, args[5]),
dtype=cast(torch.dtype, args[6]),
per_channel=True,
)

else:
# We're only handling per tensor and per channel quantization
raise NotImplementedError(f"Unsupported quantization operation: {op}")

def get_scale_per_tensor(self) -> float:
if not isinstance(self.scale, float):
raise TypeError(
f"Expected scale {self.scale} to be a float but found scale of "
f"type {type(self.scale)}"
)
return self.scale

def get_zp_per_tensor(self) -> int:
if not isinstance(self.zp, int):
raise TypeError(
f"Expected zero point {self.zp} to be an int but found zp of "
f"type {type(self.zp)}"
)
return self.zp

def get_scale_per_channel(self) -> list[float]:
if not isinstance(self.scale, list):
raise TypeError(
f"Expected scale {self.scale} to be a list but found scale of "
f"type {type(self.scale)}"
)
return self.scale

def get_zp_per_channel(self) -> list[int]:
if not isinstance(self.zp, list):
raise TypeError(
f"Expected zero point {self.zp} to be a list but found zp of "
f"type {type(self.zp)}"
)
return self.zp


# TOSA uses the RESCALE operation to scale between values with differing precision.
# The RESCALE operator is defined using an integer multiply, add, and shift.
# This utility function is for calculating the multier and shift given a scale.
Expand Down
Loading