Commit 9c28a88

Author: morelos
[ET-VK] Creating get_symmetric_quantization_config

Pull Request resolved: #12573

# Context

Dynamic quantization will eventually be enabled in the vulkan_quantizer (specifically 8-bit dynamic activations with 8-bit weights). Enabling this functionality requires a method similar to the one XNNPack uses to define its quantization config. This diff aligns with the XNNPack quantizer logic and migrates away from the old static quantization config logic.

# Changes

The most noticeable change is that we migrate away from `get_linear_weight_only_qcs_xnn_qconfig` and instead define a symmetric config whose parameters specify whether quantization is dynamic. We also incorporate `bits_to_range` so that the min and max quantization ranges are derived automatically rather than set during initialization. Wording that referred only to static quantization is updated, since dynamic quantization is now supported as well. Other internal call sites of the legacy config are moved to the more universal symmetric config. Since this follows the same naming scheme as XNNPack, aliases are added where the function is imported directly alongside XNNPack's.

ghstack-source-id: 299473613
@exported-using-ghexport

Differential Revision: [D78291249](https://our.internmc.facebook.com/intern/diff/D78291249/)
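For downstream call sites, the migration is a direct swap. A minimal sketch based on the call-site changes in this diff (the surrounding setup is illustrative, not part of the diff):

```python
from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
    VulkanQuantizer,
    get_symmetric_quantization_config,
)

quantizer = VulkanQuantizer()

# Before: legacy weight-only config
#   quantizer.set_global(get_linear_weight_only_qcs_xnn_qconfig(8))
# After: the universal symmetric config; is_dynamic=False keeps the old
# weight-only behavior, and the quant range is derived from weight_bits.
quantizer.set_global(
    get_symmetric_quantization_config(is_dynamic=False, weight_bits=8)
)
```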
1 parent: 8abb889

File tree

4 files changed: +104 −45 lines changed

backends/vulkan/quantizer/vulkan_quantizer.py

Lines changed: 70 additions & 34 deletions
```diff
@@ -14,11 +14,12 @@
 import torch
 from executorch.backends.vulkan.quantizer.vulkan_quantizer_utils import (
     _convert_scalars_to_attrs,
+    bits_to_range,
     OP_TO_ANNOTATOR,
     propagate_annotation,
 )
 from torch.fx import Node
-from torchao.quantization.pt2e import PerChannelMinMaxObserver
+from torchao.quantization.pt2e import PerChannelMinMaxObserver, PlaceholderObserver
 from torchao.quantization.pt2e.quantizer import (
     QuantizationConfig,
     QuantizationSpec,
@@ -28,50 +29,86 @@
 
 __all__ = [
     "VulkanQuantizer",
-    "get_linear_weight_qcs_qspec",
-    "get_linear_weight_only_qcs_xnn_qconfig",
+    "get_symmetric_quantization_config",
 ]
 
 
-def get_linear_weight_qcs_qspec(quant_bits: int) -> QuantizationSpec:
+@functools.lru_cache
+def get_symmetric_quantization_config(
+    is_dynamic: bool = False,
+    weight_bits: int = 8,
+    act_bits: int = 8,
+    act_qmin: Optional[int] = None,
+    act_qmax: Optional[int] = None,
+    weight_qmin: Optional[int] = None,
+    weight_qmax: Optional[int] = None,
+) -> QuantizationConfig:
     """
-    Return a QuantizationSpec to perform per-channel symmetric (i.e. "qcs") quantization
-    of weight tensors of linear layers to the number of bits specified by quant_bits.
+    Return a QuantizationConfig for Vulkan quantizer.
+
+    Args:
+        is_dynamic: If False, weight-only quantization. If True, dynamic quantization (activation + weight)
+        weight_bits: Number of bits for weight quantization (4 or 8)
+        act_bits: Number of bits for activation quantization (8)
+        act_qmin: Minimum quantization value for activations (auto-calculated if None)
+        act_qmax: Maximum quantization value for activations (auto-calculated if None)
+        weight_qmin: Minimum quantization value for weights (auto-calculated if None)
+        weight_qmax: Maximum quantization value for weights (auto-calculated if None)
     """
-    weight_observer = PerChannelMinMaxObserver
-    assert quant_bits in {
+    assert weight_bits in {
         8,
         4,
-    }, f"Unsupported weight quantization bits: {quant_bits}"
+    }, f"Unsupported weight quantization bits: {weight_bits}"
+
+    assert act_bits in {
+        8,
+    }, f"Unsupported activation quantization bits: {act_bits}"
 
-    quant_min = -(2 ** (quant_bits - 1))
-    quant_max = 2 ** (quant_bits - 1) - 1
-    qscheme = torch.per_channel_symmetric
+    # Auto-calculate weight ranges if not provided
+    if weight_qmin is None or weight_qmax is None:
+        weight_range = bits_to_range(weight_bits)
+        weight_qmin = weight_qmin if weight_qmin is not None else weight_range[0]
+        weight_qmax = weight_qmax if weight_qmax is not None else weight_range[1]
 
-    return QuantizationSpec(
+    # Weight quantization: per-channel symmetric for Vulkan
+    weight_quantization_spec = QuantizationSpec(
         dtype=torch.int8,
-        quant_min=quant_min,
-        quant_max=quant_max,
-        qscheme=qscheme,
+        quant_min=weight_qmin,
+        quant_max=weight_qmax,
+        qscheme=torch.per_channel_symmetric,
         ch_axis=0,
         is_dynamic=False,
-        observer_or_fake_quant_ctr=weight_observer,
+        observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
     )
 
-
-@functools.lru_cache
-def get_linear_weight_only_qcs_xnn_qconfig(quant_bits: int) -> QuantizationConfig:
-    """
-    Return a XNNPACKQuantizer QuantizationConfig class instance that specifies
-    quantizing the weight tensors of linear layers using per-channel symmetric (qcs)
-    quantization to the number of bits specified by quant_bits.
-    """
-    weight_qspec = get_linear_weight_qcs_qspec(quant_bits)
+    # Configure activation quantization based on is_dynamic
+    if not is_dynamic:
+        # Weight-only quantization: no activation quantization
+        act_quantization_spec = None
+        output_activation_spec = None
+    else:
+        # Dynamic quantization: per-token input quantization, no output quantization
+        # Auto-calculate activation ranges if not provided
+        if act_qmin is None or act_qmax is None:
+            act_range = bits_to_range(act_bits)
+            act_qmin = act_qmin if act_qmin is not None else act_range[0]
+            act_qmax = act_qmax if act_qmax is not None else act_range[1]
 
+        act_observer_or_fake_quant_ctr = PlaceholderObserver
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int8,
+            quant_min=act_qmin,
+            quant_max=act_qmax,
+            qscheme=torch.per_tensor_affine,
+            is_dynamic=True,
+            observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr,
+        )
+        output_activation_spec = None
 
     return QuantizationConfig(
-        input_activation=None,
-        output_activation=None,
-        weight=weight_qspec,
+        input_activation=act_quantization_spec,
+        output_activation=output_activation_spec,
+        weight=weight_quantization_spec,
         bias=None,
         is_qat=False,
     )
@@ -99,12 +136,11 @@ def transform_for_annotation(
         return _convert_scalars_to_attrs(model)
 
     def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
-        # currently only support static quant on Vulkan
-        model = self._annotate_for_static_quantization_config(model)
+        model = self._annotate_for_quantization_config(model)
         propagate_annotation(model)
         return model
 
-    def _annotate_all_static_patterns(
+    def _annotate_all_patterns(
         self,
         model: torch.fx.GraphModule,
         quantization_config: Optional[QuantizationConfig],
@@ -117,10 +153,10 @@ def _annotate_all_static_patterns(
             OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
         return model
 
-    def _annotate_for_static_quantization_config(
+    def _annotate_for_quantization_config(
         self, model: torch.fx.GraphModule
     ) -> torch.fx.GraphModule:
-        self._annotate_all_static_patterns(
+        self._annotate_all_patterns(
             model,
             self.global_config,
         )
```
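For context, here is a minimal sketch of how the new dynamic config could be exercised end to end. The `prepare_pt2e`/`convert_pt2e` entry points and the `torch.export` plumbing are assumptions based on the standard torchao pt2e workflow, not part of this diff:

```python
import torch

from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
    VulkanQuantizer,
    get_symmetric_quantization_config,
)

# Assumed torchao pt2e entry points (this diff only imports observers from torchao).
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()
sample_inputs = (torch.randn(1, 64),)

# 8-bit dynamic activations + 8-bit per-channel symmetric weights.
quantizer = VulkanQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_dynamic=True))

exported = torch.export.export(model, sample_inputs).module()
prepared = prepare_pt2e(exported, quantizer)  # annotates via OP_TO_ANNOTATOR
prepared(*sample_inputs)  # dynamic quant uses PlaceholderObserver, so no calibration stats are collected
quantized = convert_pt2e(prepared)
```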

backends/vulkan/quantizer/vulkan_quantizer_utils.py

Lines changed: 18 additions & 1 deletion
```diff
@@ -6,7 +6,7 @@
 
 # pyre-strict
 
-from typing import Callable, Optional
+from typing import Callable, Optional, Tuple
 
 import torch
 from torch.fx import Node
@@ -27,9 +27,26 @@
     "OP_TO_ANNOTATOR",
     "propagate_annotation",
     "_convert_scalars_to_attrs",
+    "bits_to_range",
 ]
 
 
+def bits_to_range(bits: int) -> Tuple[int, int]:
+    """
+    Calculate quantization range for given number of bits.
+
+    Args:
+        bits: Number of quantization bits
+
+    Returns:
+        Tuple of (qmin, qmax) for the given bit width
+    """
+    return (
+        -(2 ** (bits - 1)),
+        (2 ** (bits - 1) - 1),
+    )
+
+
 AnnotatorType = Callable[
     [
         torch.fx.GraphModule,
```
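A quick check of the helper's arithmetic for the two bit widths the quantizer accepts; these values follow directly from the two's-complement formula above:

```python
from executorch.backends.vulkan.quantizer.vulkan_quantizer_utils import bits_to_range

assert bits_to_range(8) == (-128, 127)  # full torch.int8 range
assert bits_to_range(4) == (-8, 7)      # matches the 4-bit qmin/qmax used elsewhere in this diff
```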

backends/vulkan/test/test_vulkan_passes.py

Lines changed: 7 additions & 3 deletions
```diff
@@ -7,7 +7,7 @@
 from executorch.backends.vulkan._passes import FuseQuantizedOpsTransform
 
 from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
-    get_linear_weight_only_qcs_xnn_qconfig,
+    get_symmetric_quantization_config,
     VulkanQuantizer,
 )
 
@@ -101,7 +101,9 @@ def test_fuse_int8pack_mm(self):
         sample_inputs = model.get_sample_inputs()
 
         quantizer = VulkanQuantizer()
-        quantizer.set_global(get_linear_weight_only_qcs_xnn_qconfig(8))
+        quantizer.set_global(
+            get_symmetric_quantization_config(is_dynamic=False, weight_bits=8)
+        )
 
         edge_manager = quantize_and_lower_module(
             model,
@@ -129,7 +131,9 @@ def test_fuse_linear_qcs4w(self):
         sample_inputs = model.get_sample_inputs()
 
         quantizer = VulkanQuantizer()
-        quantizer.set_global(get_linear_weight_only_qcs_xnn_qconfig(4))
+        quantizer.set_global(
+            get_symmetric_quantization_config(is_dynamic=False, weight_bits=4)
+        )
 
         edge_manager = quantize_and_lower_module(
             model,
```

extension/llm/export/quantizer_lib.py

Lines changed: 9 additions & 7 deletions
```diff
@@ -12,7 +12,7 @@
 
 import torch
 from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
-    get_symmetric_quantization_config,
+    get_symmetric_quantization_config as get_symmetric_quantization_config_xnnpack,
     XNNPACKQuantizer,
 )
 
@@ -127,11 +127,11 @@ def check_embedding_byte_registered():
             "At the moment only per channel weight quantization is supported."
         )
         if quant_params.quantize_linear.is_qc4:
-            operator_config_dynamic = get_symmetric_quantization_config(
+            operator_config_dynamic = get_symmetric_quantization_config_xnnpack(
                 is_per_channel=True, is_dynamic=True, weight_qmin=-8, weight_qmax=7
             )
         else:
-            operator_config_dynamic = get_symmetric_quantization_config(
+            operator_config_dynamic = get_symmetric_quantization_config_xnnpack(
                 is_per_channel=True, is_dynamic=True
             )
         dynamic_quantizer.set_global(operator_config_dynamic)
@@ -247,13 +247,13 @@ def get_coreml_quantizer(pt2e_quantize: str):
         raise NotImplementedError("4-bit Core ML quantizer is still under development")
 
     elif pt2e_quantize == "coreml_baseline_8a_c8w":
-        config = get_symmetric_quantization_config(
+        config = get_symmetric_quantization_config_xnnpack(
             is_per_channel=True, is_dynamic=False
         )
         quantizer = XNNPACKQuantizer().set_global(config)
 
     elif pt2e_quantize == "coreml_baseline_8a_c4w":
-        config = get_symmetric_quantization_config(
+        config = get_symmetric_quantization_config_xnnpack(
             is_per_channel=True, is_dynamic=False, weight_qmin=-8, weight_qmax=7
        )
        quantizer = XNNPACKQuantizer().set_global(config)
@@ -266,12 +266,14 @@
 
 def get_vulkan_quantizer(pt2e_quantize: str):
     from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
-        get_linear_weight_only_qcs_xnn_qconfig,
+        get_symmetric_quantization_config as get_symmetric_quantization_config_vulkan,
         VulkanQuantizer,
     )
 
     if pt2e_quantize == "vulkan_8w":
-        config = get_linear_weight_only_qcs_xnn_qconfig(8)
+        config = get_symmetric_quantization_config_vulkan(
+            is_dynamic=False, weight_bits=8
+        )
     else:
         raise ValueError(f"Unsupported Vulkan quantizer specification {pt2e_quantize}")
```
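Since both backends now export the same function name, any module that uses both must alias on import, as above. A short illustration combining the two (the call arguments are taken from this diff; the standalone snippet itself is illustrative):

```python
from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
    get_symmetric_quantization_config as get_symmetric_quantization_config_vulkan,
)
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config as get_symmetric_quantization_config_xnnpack,
)

# Same name, different parameter surfaces: the XNNPack config takes
# is_per_channel, while the Vulkan config is always per-channel and
# selects the bit width via weight_bits.
xnn_config = get_symmetric_quantization_config_xnnpack(
    is_per_channel=True, is_dynamic=True
)
vk_config = get_symmetric_quantization_config_vulkan(
    is_dynamic=False, weight_bits=8
)
```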
