Commit ea5581d

[ET-VK] Add custom VkInt4WeightOnlyQuantizer for vulkan
## Context

This diff adds the `VkInt4WeightOnlyQuantizer` class, which enables 4-bit groupwise quantization of linear layers via source transformation. The quantizer and the `VkWeightOnlyInt4Linear` module it swaps in are copied from `torchao.quantization.GPTQ` (`Int4WeightOnlyQuantizer` and `WeightOnlyInt4Linear`, respectively) with some minor changes, as annotated in the implementation.

Note that the pt2e quantization flow does not yet support groupwise quantization, so source transformation is currently the only way to perform it.

Differential Revision: [D64406457](https://our.internmc.facebook.com/intern/diff/D64406457/)

ghstack-source-id: 248113801
Pull Request resolved: #6234
1 parent 179bf40 commit ea5581d
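
As an illustration of the source-transformation flow described above, here is a minimal usage sketch (not part of the diff); the toy model and layer sizes are hypothetical, chosen so that each `in_features` is divisible by the group size and by `inner_k_tiles * 16`:

```python
# Minimal usage sketch; the toy model and its shapes are illustrative only.
import torch

from executorch.backends.vulkan._passes import VkInt4WeightOnlyQuantizer

model = torch.nn.Sequential(
    torch.nn.Linear(256, 128, bias=False),  # bias must be False for int4 linear
    torch.nn.ReLU(),
    torch.nn.Linear(128, 64, bias=False),
)

# Source transformation: each eligible nn.Linear is replaced with a
# VkWeightOnlyInt4Linear whose forward calls torch.ops.et_vk.linear_weight_int4,
# and the state dict is rewritten with 4-bit weights plus scales_and_zeros.
quantizer = VkInt4WeightOnlyQuantizer(groupsize=32)
model = quantizer.quantize(model)

print(type(model[0]).__name__)  # VkWeightOnlyInt4Linear
```

In eager mode the replaced layers dispatch through the `linear_weight_int4` reference implementation registered in `custom_ops_defs.py` below.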

File tree: 5 files changed (+317 lines, -0 lines)


backends/vulkan/_passes/TARGETS

Lines changed: 15 additions & 0 deletions

@@ -41,6 +41,20 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "int4_weight_only_quantizer",
+    srcs = [
+        "int4_weight_only_quantizer.py",
+    ],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        ":custom_ops_defs",
+        "//pytorch/ao:torchao",
+    ]
+)
+
 runtime.python_library(
     name = "vulkan_passes",
     srcs = [
@@ -50,6 +64,7 @@ runtime.python_library(
         "//executorch/backends/...",
     ],
     deps = [
+        ":int4_weight_only_quantizer",
        ":remove_local_scalar_dense",
     ]
 )
backends/vulkan/_passes/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -1,7 +1,11 @@
+from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
+    VkInt4WeightOnlyQuantizer,
+)
 from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
     RemoveLocalScalarDenseOpsTransform,
 )
 
 __all__ = [
+    "VkInt4WeightOnlyQuantizer",
     "RemoveLocalScalarDenseOpsTransform",
 ]

backends/vulkan/_passes/custom_ops_defs.py

Lines changed: 44 additions & 0 deletions

@@ -9,6 +9,10 @@
 namespace = "et_vk"
 lib = torch.library.Library(namespace, "DEF")
 
+#####################
+## conv_with_clamp ##
+#####################
+
 
 def conv_with_clamp_impl(
     input,
@@ -47,6 +51,10 @@ def conv_with_clamp_impl(
 lib.impl(name, conv_with_clamp_impl, "CompositeExplicitAutograd")
 conv_with_clamp_op = getattr(getattr(torch.ops, namespace), name)
 
+#########################
+## conv_with_clamp.out ##
+#########################
+
 
 def conv_with_clamp_out_impl(
     input,
@@ -84,6 +92,10 @@ def conv_with_clamp_out_impl(
 )
 lib.impl(name, conv_with_clamp_out_impl, "CompositeExplicitAutograd")
 
+#################
+## grid_priors ##
+#################
+
 
 # The dimension of x should be larger than 1
 def grid_priors_impl(
@@ -125,3 +137,35 @@ def grid_priors_out_impl(
     f"{name}(Tensor self, int stride, float offset, *, Tensor(a!) out) -> Tensor(a!)"
 )
 lib.impl(name, grid_priors_out_impl, "CompositeExplicitAutograd")
+
+########################
+## linear_weight_int4 ##
+########################
+
+
+def linear_weight_int4_impl(
+    x: torch.Tensor,
+    weights_4x8: torch.Tensor,
+    groupsize: int,
+    scales_and_zeros: torch.Tensor,
+    inner_k_tiles: int,
+):
+    original_x_size = x.size()
+    out_features = weights_4x8.size(0)
+    x = x.reshape(-1, original_x_size[-1])
+    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+        weights_4x8, inner_k_tiles
+    )
+    out = torch.ops.aten._weight_int4pack_mm(
+        x, weight_int4pack, groupsize, scales_and_zeros
+    )
+    out_shape = original_x_size[:-1] + (out_features,)
+    return out.reshape(out_shape)
+
+
+name = "linear_weight_int4"
+lib.define(
+    f"{name}(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros, int inner_k_tiles) -> Tensor"
+)
+lib.impl(name, linear_weight_int4_impl, "CompositeExplicitAutograd")
+linear_weight_int4_op = getattr(getattr(torch.ops, namespace), name)
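
The `weights_4x8` argument above stores two 4-bit values per `uint8` element, which is why the unpacked weight buffer registered by the quantizer (next file) has shape `(out_features, in_features // 2)`. The sketch below illustrates the idea only; the nibble order is an assumption chosen for demonstration, and the actual layout is defined by torchao's packing routine and the Vulkan kernel.

```python
# Illustrative only: the nibble order is an assumption; the real layout is
# determined by torchao's packing and the Vulkan compute shader.
import torch

vals = torch.tensor([[1, 2, 3, 4]], dtype=torch.uint8)  # four 4-bit values
packed = (vals[:, 1::2] << 4) | vals[:, 0::2]            # shape (1, 2): two nibbles per byte

low = packed & 0x0F           # recover the even-indexed values
high = (packed >> 4) & 0x0F   # recover the odd-indexed values
restored = torch.stack([low, high], dim=-1).flatten(1)
assert torch.equal(restored, vals)
```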

backends/vulkan/_passes/int4_weight_only_quantizer.py

Lines changed: 253 additions & 0 deletions (new file)

@@ -0,0 +1,253 @@
import logging
from typing import Any, Callable, Dict, Optional, Type

import torch
import torch.nn.functional as F

from torchao.quantization.GPTQ import _check_linear_int4_k
from torchao.quantization.unified import Quantizer
from torchao.quantization.utils import groupwise_affine_quantize_tensor


# This module is copied from torchao.quantization.GPTQ.WeightOnlyInt4Linear with
# changes at the annotated lines.
class VkWeightOnlyInt4Linear(torch.nn.Module):
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        # TODO: remove dtype field, not used
        bias=False,
        device=None,
        dtype=None,
        groupsize: int = 128,
        inner_k_tiles: int = 8,
        precision: torch.dtype = torch.bfloat16,
        scales_precision: torch.dtype = torch.bfloat16,
    ) -> None:
        super().__init__()
        self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
        if self.padding:
            from torchao.quantization.utils import find_multiple

            self.origin_in_features = in_features
            in_features = find_multiple(in_features, (1024,))

        self.in_features = in_features
        self.out_features = out_features
        assert not bias, "require bias=False"
        self.device = device
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
        self.precision = precision
        self.scales_precision = scales_precision

        if dtype is not None:
            raise ValueError("Please specify 'precision' instead of 'dtype'")

        assert out_features % 8 == 0, "require out_features % 8 == 0"
        assert (
            in_features % (inner_k_tiles * 16) == 0
        ), "require in_features % (innerKTiles * 16) == 0"
        # In the original implementation, the weight buffer is registered with the packed
        # sizes, i.e. the result of calling the _convert_weight_to_int4pack operator.
        # However, the Vulkan implementation does not expect the weights to be packed,
        # therefore the weight tensor is registered with the unpacked sizes instead.
        # Note that in_features is divided by 2 because each `uint8` tensor element
        # contains 2 4-bit packed values.
        self.register_buffer(
            "weight",
            torch.empty(
                (out_features, in_features // 2),
                dtype=torch.uint8,
                device=device,
            ),
        )
        self.dtype = dtype
        self.register_buffer(
            "scales_and_zeros",
            torch.empty(
                (in_features // groupsize, out_features, 2),
                dtype=self.scales_precision,
                device=device,
            ),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.padding:
            input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
        # The forward method is replaced. In the original implementation, the forward
        # method is torchao.quantization.GPTQ.linear_forward_int4; here a Vulkan custom
        # operator is called instead.
        return torch.ops.et_vk.linear_weight_int4(
            input,
            self.weight,
            self.groupsize,
            self.scales_and_zeros,
            self.inner_k_tiles,
        )


# This function is copied from torchao.quantization.GPTQ._replace_linear_int4
# with small changes at the annotated locations.
def _vk_replace_linear_int4(
    module: torch.nn.Module,
    groupsize: int,
    inner_k_tiles: Optional[int],
    padding_allowed: bool,
    skip_layer_func: Optional[Callable] = None,
    precision: torch.dtype = torch.bfloat16,
    scales_precision: torch.dtype = torch.bfloat16,
    # Use custom vulkan linear layer as default
    linear_class: Type[torch.nn.Module] = VkWeightOnlyInt4Linear,
    copy_weights: bool = False,
    # Serves the same purpose as `tensor_dim_limit` in
    # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
    feature_limit: int = 16384,
):
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear) and (
            skip_layer_func is None or not skip_layer_func(child.weight)
        ):
            # Add an additional condition that the out/in features must not exceed the
            # `feature_limit` argument.
            if (
                _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
                or padding_allowed
            ) and (
                child.out_features < feature_limit and child.in_features < feature_limit
            ):
                new_linear = linear_class(
                    child.in_features,
                    child.out_features,
                    bias=False,
                    device=child.weight.device,
                    groupsize=groupsize,
                    inner_k_tiles=inner_k_tiles,
                    precision=precision,
                    scales_precision=scales_precision,
                )
                if copy_weights and child.weight.device != torch.device("meta"):
                    new_linear.weight = child.weight
                setattr(module, name, new_linear)
        else:
            _vk_replace_linear_int4(
                child,
                groupsize,
                inner_k_tiles,
                padding_allowed,
                skip_layer_func,
                precision,
                scales_precision,
                linear_class,
                copy_weights,
            )


# This module is copied from torchao.quantization.GPTQ.Int4WeightOnlyQuantizer
# with some changes at the annotated lines.
class VkInt4WeightOnlyQuantizer(Quantizer):
    def __init__(
        self,
        groupsize: int = 256,
        padding_allowed: bool = True,
        inner_k_tiles: Optional[int] = 8,
        device: torch.device = torch.device("cpu"),  # noqa
        precision: torch.dtype = torch.float32,
        feature_limit: int = 16384,
    ) -> None:
        super().__init__()
        assert inner_k_tiles in [2, 4, 8]
        assert groupsize in [32, 64, 128, 256]

        self.inner_k_tiles = inner_k_tiles
        self.groupsize: int = groupsize
        self.padding_allowed: bool = padding_allowed
        self.device: torch.device = device
        self.precision: torch.dtype = precision
        # Serves the same purpose as `tensor_dim_limit` in
        # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
        self.feature_limit = feature_limit

    @torch.no_grad()
    def _create_quantized_state_dict(
        self, model: torch.nn.Module
    ) -> Dict[str, torch.Tensor]:
        cur_state_dict = model.state_dict()
        for fqn, mod in model.named_modules():
            # Add additional check to make sure features do not exceed feature limit
            if isinstance(mod, torch.nn.Linear) and (
                mod.out_features < self.feature_limit
                and mod.in_features < self.feature_limit
            ):
                assert not mod.bias
                out_features = mod.out_features
                in_features = mod.in_features
                logging.info(f"linear: {fqn}, in={in_features}, out={out_features}")

                assert (
                    in_features % self.groupsize == 0
                ), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0"

                weight = mod.weight.data
                if not _check_linear_int4_k(
                    in_features, self.groupsize, self.inner_k_tiles
                ):
                    if self.padding_allowed:
                        import torch.nn.functional as F

                        from torchao.quantization.utils import find_multiple

                        logging.warn(
                            f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
                        )
                        padded_in_features = find_multiple(in_features, (1024,))
                        weight = F.pad(
                            weight, pad=(0, padded_in_features - in_features)
                        )
                    else:
                        logging.warn(
                            f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
                            + "and that groupsize and inner_k_tiles*16 evenly divide into it"
                        )
                        continue
                (w_int4x8, scales_and_zeros) = groupwise_affine_quantize_tensor(
                    weight,
                    4,  # n_bit
                    self.groupsize,
                    self.precision,  # dtype for scales_and_zeros
                )
                # In the original implementation, w_int4x8 is packed via calling the
                # _convert_weight_to_int4pack operator before storing the weight. However
                # the Vulkan implementation does not expect the weights to be packed, so
                # the w_int4x8 tensor is stored as the weight instead.
                cur_state_dict[f"{fqn}.weight"] = w_int4x8.to(self.device)
                cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(
                    self.device
                )
        return cur_state_dict

    def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
        _vk_replace_linear_int4(
            model,
            self.groupsize,
            self.inner_k_tiles,
            self.padding_allowed,
            skip_layer_func=None,
            precision=self.precision,
            scales_precision=self.precision,
        )
        return model

    def quantize(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        state_dict = self._create_quantized_state_dict(model)
        model = self._convert_for_runtime(model)
        model.load_state_dict(state_dict, strict=False)
        return model
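
For reference, here is a hedged sketch of the groupwise quantization step that `_create_quantized_state_dict` performs for each linear layer, assuming a torchao version compatible with this commit; the layer sizes are illustrative only.

```python
# Hedged sketch: the shapes/dtypes of the returned tensors depend on the
# installed torchao version; the layer sizes below are illustrative only.
import torch
from torchao.quantization.utils import groupwise_affine_quantize_tensor

out_features, in_features, groupsize = 8, 64, 32
weight = torch.randn(out_features, in_features)

w_int4x8, scales_and_zeros = groupwise_affine_quantize_tensor(
    weight,
    4,              # n_bit
    groupsize,
    torch.float32,  # dtype for scales_and_zeros
)

# One (scale, zero_point) pair per group of `groupsize` input channels per row;
# this is what VkWeightOnlyInt4Linear registers as its scales_and_zeros buffer.
print(w_int4x8.shape, w_int4x8.dtype)
print(scales_and_zeros.shape)  # expected: (in_features // groupsize, out_features, 2)
```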

backends/vulkan/partitioner/supported_ops.py

Lines changed: 1 addition & 0 deletions

@@ -83,6 +83,7 @@ def __contains__(self, op):
     exir_ops.edge.aten.mm.default,
     exir_ops.edge.aten.addmm.default,
     exir_ops.edge.aten.linear.default,
+    exir_ops.edge.et_vk.linear_weight_int4.default,
     # Reduction
     exir_ops.edge.aten._log_softmax.default,
     exir_ops.edge.aten._softmax.default,