5 changes: 2 additions & 3 deletions backends/vulkan/quantizer/TARGETS
@@ -9,8 +9,7 @@ python_library(
],
deps = [
"//caffe2:torch",
"//pytorch/ao/torchao/quantization/pt2e:quantizer",
"//pytorch/ao/torchao/quantization/pt2e:utils",
"//pytorch/ao:torchao", # @manual
],
)

@@ -22,6 +21,6 @@ python_library(
deps = [
":vulkan_quantizer_utils",
"//caffe2:torch",
"//pytorch/ao/torchao/quantization/pt2e:quantizer",
"//pytorch/ao:torchao", # @manual
],
)
104 changes: 70 additions & 34 deletions backends/vulkan/quantizer/vulkan_quantizer.py
@@ -14,11 +14,12 @@
import torch
from executorch.backends.vulkan.quantizer.vulkan_quantizer_utils import (
_convert_scalars_to_attrs,
bits_to_range,
OP_TO_ANNOTATOR,
propagate_annotation,
)
from torch.fx import Node
from torchao.quantization.pt2e import PerChannelMinMaxObserver
from torchao.quantization.pt2e import PerChannelMinMaxObserver, PlaceholderObserver
from torchao.quantization.pt2e.quantizer import (
QuantizationConfig,
QuantizationSpec,
@@ -28,50 +29,86 @@

__all__ = [
"VulkanQuantizer",
"get_linear_weight_qcs_qspec",
"get_linear_weight_only_qcs_xnn_qconfig",
"get_symmetric_quantization_config",
]


def get_linear_weight_qcs_qspec(quant_bits: int) -> QuantizationSpec:
@functools.lru_cache
def get_symmetric_quantization_config(
is_dynamic: bool = False,
weight_bits: int = 8,
act_bits: int = 8,
act_qmin: Optional[int] = None,
act_qmax: Optional[int] = None,
weight_qmin: Optional[int] = None,
weight_qmax: Optional[int] = None,
) -> QuantizationConfig:
"""
Return a QuantizationSpec to perform per-channel symmetric (i.e. "qcs") quantization
of weight tensors of linear layers to the number of bits specified by quant_bits.
Return a QuantizationConfig for Vulkan quantizer.
Args:
is_dynamic: If False, weight-only quantization. If True, dynamic quantization (activation + weight)
weight_bits: Number of bits for weight quantization (4 or 8)
act_bits: Number of bits for activation quantization (8)
act_qmin: Minimum quantization value for activations (auto-calculated if None)
act_qmax: Maximum quantization value for activations (auto-calculated if None)
weight_qmin: Minimum quantization value for weights (auto-calculated if None)
weight_qmax: Maximum quantization value for weights (auto-calculated if None)
"""
weight_observer = PerChannelMinMaxObserver
assert quant_bits in {
assert weight_bits in {
8,
4,
}, f"Unsupported weight quantization bits: {quant_bits}"
}, f"Unsupported weight quantization bits: {weight_bits}"

assert act_bits in {
8,
}, f"Unsupported activation quantization bits: {act_bits}"

quant_min = -(2 ** (quant_bits - 1))
quant_max = 2 ** (quant_bits - 1) - 1
qscheme = torch.per_channel_symmetric
# Auto-calculate weight ranges if not provided
if weight_qmin is None or weight_qmax is None:
weight_range = bits_to_range(weight_bits)
weight_qmin = weight_qmin if weight_qmin is not None else weight_range[0]
weight_qmax = weight_qmax if weight_qmax is not None else weight_range[1]

return QuantizationSpec(
# Weight quantization: per-channel symmetric for Vulkan
weight_quantization_spec = QuantizationSpec(
dtype=torch.int8,
quant_min=quant_min,
quant_max=quant_max,
qscheme=qscheme,
quant_min=weight_qmin,
quant_max=weight_qmax,
qscheme=torch.per_channel_symmetric,
ch_axis=0,
is_dynamic=False,
observer_or_fake_quant_ctr=weight_observer,
observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
)


@functools.lru_cache
def get_linear_weight_only_qcs_xnn_qconfig(quant_bits: int) -> QuantizationConfig:
"""
Return a XNNPACKQuantizer QuantizationConfig class instance that specifies
quantizing the weight tensors of linear layers using per-channel symmetric (qcs)
quantization to the number of bits specified by quant_bits.
"""
weight_qspec = get_linear_weight_qcs_qspec(quant_bits)
# Configure activation quantization based on is_dynamic
if not is_dynamic:
# Weight-only quantization: no activation quantization
act_quantization_spec = None
output_activation_spec = None
else:
# Dynamic quantization: per-token input quantization, no output quantization
# Auto-calculate activation ranges if not provided
if act_qmin is None or act_qmax is None:
act_range = bits_to_range(act_bits)
act_qmin = act_qmin if act_qmin is not None else act_range[0]
act_qmax = act_qmax if act_qmax is not None else act_range[1]

act_observer_or_fake_quant_ctr = PlaceholderObserver
act_quantization_spec = QuantizationSpec(
dtype=torch.int8,
quant_min=act_qmin,
quant_max=act_qmax,
qscheme=torch.per_tensor_affine,
is_dynamic=True,
observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr,
)
output_activation_spec = None

return QuantizationConfig(
input_activation=None,
output_activation=None,
weight=weight_qspec,
input_activation=act_quantization_spec,
output_activation=output_activation_spec,
weight=weight_quantization_spec,
bias=None,
is_qat=False,
)
@@ -99,12 +136,11 @@ def transform_for_annotation(
return _convert_scalars_to_attrs(model)

def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
# currently only support static quant on Vulkan
model = self._annotate_for_static_quantization_config(model)
model = self._annotate_for_quantization_config(model)
propagate_annotation(model)
return model

def _annotate_all_static_patterns(
def _annotate_all_patterns(
self,
model: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
@@ -117,10 +153,10 @@ def _annotate_all_static_patterns(
OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
return model

def _annotate_for_static_quantization_config(
def _annotate_for_quantization_config(
self, model: torch.fx.GraphModule
) -> torch.fx.GraphModule:
self._annotate_all_static_patterns(
self._annotate_all_patterns(
model,
self.global_config,
)
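A minimal usage sketch of the new unified `get_symmetric_quantization_config` API with `VulkanQuantizer`, following the call sites updated elsewhere in this PR; the variable names and the reference to the pt2e prepare/convert flow are illustrative assumptions, not part of the diff.

```python
# Sketch (assumption, not part of the diff): configuring VulkanQuantizer
# with the unified get_symmetric_quantization_config API.
from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
    VulkanQuantizer,
    get_symmetric_quantization_config,
)

# Weight-only static quantization: 4-bit per-channel symmetric linear weights.
weight_only_4bit = get_symmetric_quantization_config(is_dynamic=False, weight_bits=4)

# Dynamic quantization: 8-bit dynamic activations plus 8-bit per-channel weights.
dynamic_8a8w = get_symmetric_quantization_config(is_dynamic=True, act_bits=8, weight_bits=8)

quantizer = VulkanQuantizer()
quantizer.set_global(weight_only_4bit)
# The configured quantizer is then handed to the usual pt2e prepare/convert flow.
```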
25 changes: 21 additions & 4 deletions backends/vulkan/quantizer/vulkan_quantizer_utils.py
@@ -6,7 +6,7 @@

# pyre-strict

from typing import Callable, Optional
from typing import Callable, Optional, Tuple

import torch
from torch.fx import Node
@@ -27,9 +27,26 @@
"OP_TO_ANNOTATOR",
"propagate_annotation",
"_convert_scalars_to_attrs",
"bits_to_range",
]


def bits_to_range(bits: int) -> Tuple[int, int]:
"""
Calculate quantization range for given number of bits.

Args:
bits: Number of quantization bits

Returns:
Tuple of (qmin, qmax) for the given bit width
"""
return (
-(2 ** (bits - 1)),
(2 ** (bits - 1) - 1),
)


AnnotatorType = Callable[
[
torch.fx.GraphModule,
@@ -48,7 +65,7 @@ def decorator(annotator: AnnotatorType) -> None:
return decorator


def _is_annotated(nodes: list[Node]):
def _is_annotated(nodes: list[Node]) -> bool:
"""
Given a list of nodes (that represents an operator pattern),
check if any of the node is annotated, return True if any of the node
@@ -63,7 +80,7 @@ def _is_annotated(nodes: list[Node]):
return annotated


def _mark_nodes_as_annotated(nodes: list[Node]):
def _mark_nodes_as_annotated(nodes: list[Node]) -> None:
for node in nodes:
if node is not None:
if "quantization_annotation" not in node.meta:
@@ -119,7 +136,7 @@ def _annotate_linear(
return annotated_partitions


def _is_share_obs_or_fq_op(op: Callable) -> bool:
def _is_share_obs_or_fq_op(op: Callable[..., torch.Tensor]) -> bool:
return op in [
torch.ops.aten.relu.default,
torch.ops.aten.hardtanh.default,
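The new `bits_to_range` helper is small enough to sanity-check directly; a sketch of the expected values for the two supported weight widths (the asserts are illustrative, not tests included in the PR).

```python
# Sketch: expected ranges from the new bits_to_range helper,
# i.e. (-(2 ** (bits - 1)), 2 ** (bits - 1) - 1).
from executorch.backends.vulkan.quantizer.vulkan_quantizer_utils import bits_to_range

assert bits_to_range(8) == (-128, 127)  # full int8 range for 8-bit weights/activations
assert bits_to_range(4) == (-8, 7)      # 4-bit symmetric range stored in int8
```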
10 changes: 7 additions & 3 deletions backends/vulkan/test/test_vulkan_passes.py
@@ -7,7 +7,7 @@
from executorch.backends.vulkan._passes import FuseQuantizedOpsTransform

from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
get_linear_weight_only_qcs_xnn_qconfig,
get_symmetric_quantization_config,
VulkanQuantizer,
)

@@ -101,7 +101,9 @@ def test_fuse_int8pack_mm(self):
sample_inputs = model.get_sample_inputs()

quantizer = VulkanQuantizer()
quantizer.set_global(get_linear_weight_only_qcs_xnn_qconfig(8))
quantizer.set_global(
get_symmetric_quantization_config(is_dynamic=False, weight_bits=8)
)

edge_manager = quantize_and_lower_module(
model,
@@ -129,7 +131,9 @@ def test_fuse_linear_qcs4w(self):
sample_inputs = model.get_sample_inputs()

quantizer = VulkanQuantizer()
quantizer.set_global(get_linear_weight_only_qcs_xnn_qconfig(4))
quantizer.set_global(
get_symmetric_quantization_config(is_dynamic=False, weight_bits=4)
)

edge_manager = quantize_and_lower_module(
model,
14 changes: 7 additions & 7 deletions extension/llm/export/quantizer_lib.py
@@ -12,7 +12,7 @@

import torch
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
get_symmetric_quantization_config as get_symmetric_quantization_config_xnnpack,
XNNPACKQuantizer,
)

@@ -127,11 +127,11 @@ def check_embedding_byte_registered():
"At the moment only per channel weight quantization is supported."
)
if quant_params.quantize_linear.is_qc4:
operator_config_dynamic = get_symmetric_quantization_config(
operator_config_dynamic = get_symmetric_quantization_config_xnnpack(
is_per_channel=True, is_dynamic=True, weight_qmin=-8, weight_qmax=7
)
else:
operator_config_dynamic = get_symmetric_quantization_config(
operator_config_dynamic = get_symmetric_quantization_config_xnnpack(
is_per_channel=True, is_dynamic=True
)
dynamic_quantizer.set_global(operator_config_dynamic)
@@ -247,13 +247,13 @@ def get_coreml_quantizer(pt2e_quantize: str):
raise NotImplementedError("4-bit Core ML quantizer is still under development")

elif pt2e_quantize == "coreml_baseline_8a_c8w":
config = get_symmetric_quantization_config(
config = get_symmetric_quantization_config_xnnpack(
is_per_channel=True, is_dynamic=False
)
quantizer = XNNPACKQuantizer().set_global(config)

elif pt2e_quantize == "coreml_baseline_8a_c4w":
config = get_symmetric_quantization_config(
config = get_symmetric_quantization_config_xnnpack(
is_per_channel=True, is_dynamic=False, weight_qmin=-8, weight_qmax=7
)
quantizer = XNNPACKQuantizer().set_global(config)
@@ -266,12 +266,12 @@

def get_vulkan_quantizer(pt2e_quantize: str):
from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
get_linear_weight_only_qcs_xnn_qconfig,
get_symmetric_quantization_config as get_symmetric_quantization_config_vulkan,
VulkanQuantizer,
)

if pt2e_quantize == "vulkan_8w":
config = get_linear_weight_only_qcs_xnn_qconfig(8)
config = get_symmetric_quantization_config_vulkan(is_dynamic=False, weight_bits=8)
else:
raise ValueError(f"Unsupported Vulkan quantizer specification {pt2e_quantize}")

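Since both backends now expose a `get_symmetric_quantization_config`, the aliased imports keep the two call signatures distinct; a short sketch of the pattern (the two config calls mirror ones that appear in this file, the variable names are illustrative).

```python
# Sketch: the two same-named config getters imported under distinct aliases.
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config as get_symmetric_quantization_config_xnnpack,
)
from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
    get_symmetric_quantization_config as get_symmetric_quantization_config_vulkan,
)

# XNNPACK: per-channel dynamic config (as used for the dynamic quantizer path).
xnnpack_config = get_symmetric_quantization_config_xnnpack(is_per_channel=True, is_dynamic=True)

# Vulkan: weight-only 8-bit config (as used for the "vulkan_8w" option).
vulkan_config = get_symmetric_quantization_config_vulkan(is_dynamic=False, weight_bits=8)
```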