
Commit fe39f9b

[ET-VK] Add custom VkInt4WeightOnlyQuantizer for vulkan
Pull Request resolved: #6234

## Context

This diff adds the `VkInt4WeightOnlyQuantizer` class, which enables 4-bit quantization of linear layers via source transformation. This quantizer class is copied from `torchao.quantization.GPTQ.WeightOnlyInt4Linear` with some minor changes, as annotated in the implementation. Note that the pt2e quantization flow does not yet support groupwise quantization, so source transformation is currently the only way to perform groupwise quantization.

ghstack-source-id: 248349848
@exported-using-ghexport

Differential Revision: [D64406457](https://our.internmc.facebook.com/intern/diff/D64406457/)
1 parent ef402c9 commit fe39f9b
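
As background for the groupwise scheme mentioned above: in groupwise quantization, each group of `groupsize` consecutive input features in a weight row gets its own scale and offset, which keeps quantization error local to small groups. The helper below is a toy illustration only, with hypothetical names; it is not the torchao kernel that this quantizer calls.

import torch


def toy_groupwise_quantize(weight: torch.Tensor, groupsize: int = 128, n_bit: int = 4):
    # Toy 4-bit groupwise affine quantization, for illustration only.
    out_features, in_features = weight.shape
    assert in_features % groupsize == 0
    # Each group of `groupsize` consecutive input features in a row gets its own
    # scale and offset.
    w = weight.reshape(out_features, in_features // groupsize, groupsize)
    w_min = w.amin(dim=-1, keepdim=True)
    w_max = w.amax(dim=-1, keepdim=True)
    qmax = 2**n_bit - 1  # 4 bits -> integer levels 0..15
    scale = (w_max - w_min).clamp(min=1e-6) / qmax
    q = ((w - w_min) / scale).round().clamp(0, qmax).to(torch.uint8)
    # Dequantization is approximately q * scale + w_min, applied per group.
    return q.reshape(out_features, in_features), scale, w_min


q, scale, w_min = toy_groupwise_quantize(torch.randn(64, 256), groupsize=128)
print(q.shape, scale.shape)  # torch.Size([64, 256]) torch.Size([64, 2, 1])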


5 files changed, +321 -0 lines changed


backends/vulkan/_passes/TARGETS

Lines changed: 15 additions & 0 deletions
@@ -41,6 +41,20 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "int4_weight_only_quantizer",
+    srcs = [
+        "int4_weight_only_quantizer.py",
+    ],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        ":custom_ops_defs",
+        "//pytorch/ao:torchao",
+    ]
+)
+
 runtime.python_library(
     name = "vulkan_passes",
     srcs = [
@@ -50,6 +64,7 @@ runtime.python_library(
         "//executorch/backends/...",
     ],
     deps = [
+        ":int4_weight_only_quantizer",
         ":remove_local_scalar_dense",
     ]
 )

backends/vulkan/_passes/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,7 +1,11 @@
+from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
+    VkInt4WeightOnlyQuantizer,
+)
 from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
     RemoveLocalScalarDenseOpsTransform,
 )
 
 __all__ = [
+    "VkInt4WeightOnlyQuantizer",
     "RemoveLocalScalarDenseOpsTransform",
 ]

backends/vulkan/_passes/custom_ops_defs.py

Lines changed: 44 additions & 0 deletions
@@ -9,6 +9,10 @@
 namespace = "et_vk"
 lib = torch.library.Library(namespace, "DEF")
 
+#####################
+## conv_with_clamp ##
+#####################
+
 
 def conv_with_clamp_impl(
     input,
@@ -47,6 +51,10 @@ def conv_with_clamp_impl(
 lib.impl(name, conv_with_clamp_impl, "CompositeExplicitAutograd")
 conv_with_clamp_op = getattr(getattr(torch.ops, namespace), name)
 
+#########################
+## conv_with_clamp.out ##
+#########################
+
 
 def conv_with_clamp_out_impl(
     input,
@@ -84,6 +92,10 @@ def conv_with_clamp_out_impl(
 )
 lib.impl(name, conv_with_clamp_out_impl, "CompositeExplicitAutograd")
 
+#################
+## grid_priors ##
+#################
+
 
 # The dimension of x should be larger than 1
 def grid_priors_impl(
@@ -125,3 +137,35 @@ def grid_priors_out_impl(
     f"{name}(Tensor self, int stride, float offset, *, Tensor(a!) out) -> Tensor(a!)"
 )
 lib.impl(name, grid_priors_out_impl, "CompositeExplicitAutograd")
+
+########################
+## linear_weight_int4 ##
+########################
+
+
+def linear_weight_int4_impl(
+    x: torch.Tensor,
+    weights_4x8: torch.Tensor,
+    groupsize: int,
+    scales_and_zeros: torch.Tensor,
+    inner_k_tiles: int,
+):
+    original_x_size = x.size()
+    out_features = weights_4x8.size(0)
+    x = x.reshape(-1, original_x_size[-1])
+    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+        weights_4x8, inner_k_tiles
+    )
+    out = torch.ops.aten._weight_int4pack_mm(
+        x, weight_int4pack, groupsize, scales_and_zeros
+    )
+    out_shape = original_x_size[:-1] + (out_features,)
+    return out.reshape(out_shape)
+
+
+name = "linear_weight_int4"
+lib.define(
+    f"{name}(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros, int inner_k_tiles) -> Tensor"
+)
+lib.impl(name, linear_weight_int4_impl, "CompositeExplicitAutograd")
+linear_weight_int4_op = getattr(getattr(torch.ops, namespace), name)
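
To make the registration above concrete, here is a small sketch of how the new operator is looked up once this module has been imported. It assumes executorch and torchao are importable and only inspects the schema, since the exact packing requirements of the underlying aten int4 kernels vary across PyTorch builds.

import torch

# Importing the module runs the lib.define / lib.impl calls above, which register
# et_vk::linear_weight_int4 with the PyTorch dispatcher.
from executorch.backends.vulkan._passes.custom_ops_defs import (  # noqa: F401
    linear_weight_int4_op,
)

# The op is now reachable through the standard torch.ops namespace.
print(torch.ops.et_vk.linear_weight_int4.default._schema)
# Prints (roughly):
#   et_vk::linear_weight_int4(Tensor self, Tensor mat2, int qGroupSize,
#                             Tensor qScaleAndZeros, int inner_k_tiles) -> Tensor
#
# Per the reference implementation above (and the buffers registered by
# VkWeightOnlyInt4Linear in the next file), the expected shapes are:
#   self           -> (..., in_features) activation
#   mat2           -> (out_features, in_features // 2) uint8, two 4-bit values per byte
#   qScaleAndZeros -> (in_features // qGroupSize, out_features, 2)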

backends/vulkan/_passes/int4_weight_only_quantizer.py

Lines changed: 257 additions & 0 deletions
@@ -0,0 +1,257 @@
+import logging
+from typing import Any, Callable, Dict, Optional, Type
+
+import torch
+import torch.nn.functional as F
+
+from executorch.backends.vulkan._passes.custom_ops_defs import ( # noqa
+    linear_weight_int4_op,
+)
+
+from torchao.quantization.GPTQ import _check_linear_int4_k
+from torchao.quantization.unified import Quantizer
+from torchao.quantization.utils import groupwise_affine_quantize_tensor
+
+
+# This module is copied from torchao.quantization.GPTQ.WeightOnlyInt4Linear with
+# changes at the annotated lines.
+class VkWeightOnlyInt4Linear(torch.nn.Module):
+    __constants__ = ["in_features", "out_features"]
+    in_features: int
+    out_features: int
+    weight: torch.Tensor
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        # TODO: remove dtype field, not used
+        bias=False,
+        device=None,
+        dtype=None,
+        groupsize: int = 128,
+        inner_k_tiles: int = 8,
+        precision: torch.dtype = torch.bfloat16,
+        scales_precision: torch.dtype = torch.bfloat16,
+    ) -> None:
+        super().__init__()
+        self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
+        if self.padding:
+            from torchao.quantization.utils import find_multiple
+
+            self.origin_in_features = in_features
+            in_features = find_multiple(in_features, (1024,))
+
+        self.in_features = in_features
+        self.out_features = out_features
+        assert not bias, "require bias=False"
+        self.device = device
+        self.groupsize = groupsize
+        self.inner_k_tiles = inner_k_tiles
+        self.precision = precision
+        self.scales_precision = scales_precision
+
+        if dtype is not None:
+            raise ValueError("Please specify 'precision' instead of 'dtype'")
+
+        assert out_features % 8 == 0, "require out_features % 8 == 0"
+        assert (
+            in_features % (inner_k_tiles * 16) == 0
+        ), "require in_features % (innerKTiles * 16) == 0"
+        # In the original implementation, the weight buffer is registered with the packed
+        # sizes, i.e. the result of calling the _convert_weight_to_int4pack operator.
+        # However, the Vulkan implementation does not expect the weights to be packed
+        # therefore the weight tensor is registered with the unpacked sizes instead.
+        # Note that in_features is divided by 2 because each `uint8` tensor element
+        # contains 2 4-bit packed values.
+        self.register_buffer(
+            "weight",
+            torch.empty(
+                (out_features, in_features // 2),
+                dtype=torch.uint8,
+                device=device,
+            ),
+        )
+        self.dtype = dtype
+        self.register_buffer(
+            "scales_and_zeros",
+            torch.empty(
+                (in_features // groupsize, out_features, 2),
+                dtype=self.scales_precision,
+                device=device,
+            ),
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self.padding:
+            input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
+        # The forward method is replaced. In the original implementation, the forward
+        # method is torchao.quantization.GPTQ.linear_forward_int4; here a Vulkan custom
+        # operator is called instead.
+        return torch.ops.et_vk.linear_weight_int4(
+            input,
+            self.weight,
+            self.groupsize,
+            self.scales_and_zeros,
+            self.inner_k_tiles,
+        )
+
+
+# This function is copied from torchao.quantization.GPTQ._replace_linear_int4
+# with small changes at the annotated locations.
+def _vk_replace_linear_int4(
+    module: torch.nn.Module,
+    groupsize: int,
+    inner_k_tiles: Optional[int],
+    padding_allowed: bool,
+    skip_layer_func: Optional[Callable] = None,
+    precision: torch.dtype = torch.bfloat16,
+    scales_precision: torch.dtype = torch.bfloat16,
+    # Use custom vulkan linear layer as default
+    linear_class: Type[torch.nn.Module] = VkWeightOnlyInt4Linear,
+    copy_weights: bool = False,
+    # Serves the same purpose as `tensor_dim_limit` in
+    # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
+    feature_limit: int = 16384,
+):
+    for name, child in module.named_children():
+        if isinstance(child, torch.nn.Linear) and (
+            skip_layer_func is None or not skip_layer_func(child.weight)
+        ):
+            # Add an additional condition that the out/in features must not exceed the
+            # `feature_limit` argument.
+            if (
+                _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
+                or padding_allowed
+            ) and (
+                child.out_features < feature_limit and child.in_features < feature_limit
+            ):
+                new_linear = linear_class(
+                    child.in_features,
+                    child.out_features,
+                    bias=False,
+                    device=child.weight.device,
+                    groupsize=groupsize,
+                    inner_k_tiles=inner_k_tiles,
+                    precision=precision,
+                    scales_precision=scales_precision,
+                )
+                if copy_weights and child.weight.device != torch.device("meta"):
+                    new_linear.weight = child.weight
+                setattr(module, name, new_linear)
+        else:
+            _vk_replace_linear_int4(
+                child,
+                groupsize,
+                inner_k_tiles,
+                padding_allowed,
+                skip_layer_func,
+                precision,
+                scales_precision,
+                linear_class,
+                copy_weights,
+            )
+
+
+# This module is copied from torchao.quantization.GPTQ.Int4WeightOnlyQuantizer
+# with some changes at the annotated lines.
+class VkInt4WeightOnlyQuantizer(Quantizer):
+    def __init__(
+        self,
+        groupsize: int = 256,
+        padding_allowed: bool = True,
+        inner_k_tiles: Optional[int] = 8,
+        device: torch.device = torch.device("cpu"), # noqa
+        precision: torch.dtype = torch.float32,
+        feature_limit: int = 16384,
+    ) -> None:
+        super().__init__()
+        assert inner_k_tiles in [2, 4, 8]
+        assert groupsize in [32, 64, 128, 256]
+
+        self.inner_k_tiles = inner_k_tiles
+        self.groupsize: int = groupsize
+        self.padding_allowed: bool = padding_allowed
+        self.device: torch.device = device
+        self.precision: torch.dtype = precision
+        # Serves the same purpose as `tensor_dim_limit` in
+        # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
+        self.feature_limit = feature_limit
+
+    @torch.no_grad()
+    def _create_quantized_state_dict(
+        self, model: torch.nn.Module
+    ) -> Dict[str, torch.Tensor]:
+        cur_state_dict = model.state_dict()
+        for fqn, mod in model.named_modules():
+            # Add additional check to make sure features do not exceed feature limit
+            if isinstance(mod, torch.nn.Linear) and (
+                mod.out_features < self.feature_limit
+                and mod.in_features < self.feature_limit
+            ):
+                assert not mod.bias
+                out_features = mod.out_features
+                in_features = mod.in_features
+                logging.info(f"linear: {fqn}, in={in_features}, out={out_features}")
+
+                assert (
+                    in_features % self.groupsize == 0
+                ), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0"
+
+                weight = mod.weight.data
+                if not _check_linear_int4_k(
+                    in_features, self.groupsize, self.inner_k_tiles
+                ):
+                    if self.padding_allowed:
+                        import torch.nn.functional as F
+
+                        from torchao.quantization.utils import find_multiple
+
+                        logging.warn(
+                            f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
+                        )
+                        padded_in_features = find_multiple(in_features, (1024,))
+                        weight = F.pad(
+                            weight, pad=(0, padded_in_features - in_features)
+                        )
+                    else:
+                        logging.warn(
+                            f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
+                            + "and that groupsize and inner_k_tiles*16 evenly divide into it"
+                        )
+                        continue
+                (w_int4x8, scales_and_zeros) = groupwise_affine_quantize_tensor(
+                    weight,
+                    4, # n_bit
+                    self.groupsize,
+                    self.precision, # dtype for scales_and_zeros
+                )
+                # In the original implementation, w_int4x8 is packed via calling the
+                # _convert_weight_to_int4pack operator before storing the weight. However
+                # the Vulkan implementation does not expect the weights to be packed, so
+                # the w_int4x8 tensor is stored as the weight instead.
+                cur_state_dict[f"{fqn}.weight"] = w_int4x8.to(self.device)
+                cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(
+                    self.device
+                )
+        return cur_state_dict
+
+    def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
+        _vk_replace_linear_int4(
+            model,
+            self.groupsize,
+            self.inner_k_tiles,
+            self.padding_allowed,
+            skip_layer_func=None,
+            precision=self.precision,
+            scales_precision=self.precision,
+        )
+        return model
+
+    def quantize(
+        self, model: torch.nn.Module, *args: Any, **kwargs: Any
+    ) -> torch.nn.Module:
+        state_dict = self._create_quantized_state_dict(model)
+        model = self._convert_for_runtime(model)
+        model.load_state_dict(state_dict, strict=False)
+        return model
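
For orientation, here is a minimal usage sketch of the quantizer as a source transformation. The `TinyModel` below is hypothetical, with sizes chosen to satisfy the asserts above (in_features divisible by groupsize and by inner_k_tiles * 16, out_features divisible by 8, bias=False); it also assumes a torchao version whose groupwise_affine_quantize_tensor packs two 4-bit values per uint8 byte, matching the unpacked weight buffer registered by VkWeightOnlyInt4Linear.

import torch

from executorch.backends.vulkan._passes import VkInt4WeightOnlyQuantizer


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # 512 % 128 == 0, 512 % (8 * 16) == 0, 256 % 8 == 0, and bias=False,
        # so this layer passes the checks in _create_quantized_state_dict.
        self.linear = torch.nn.Linear(512, 256, bias=False)

    def forward(self, x):
        return self.linear(x)


model = TinyModel()
quantizer = VkInt4WeightOnlyQuantizer(groupsize=128, inner_k_tiles=8)
model = quantizer.quantize(model)

# The nn.Linear has been swapped for VkWeightOnlyInt4Linear, whose forward calls
# torch.ops.et_vk.linear_weight_int4.
print(type(model.linear).__name__)  # VkWeightOnlyInt4Linear

Once transformed, the module can be exported and lowered as usual; the partitioner change below marks the et_vk.linear_weight_int4 op as supported so that it can be delegated to Vulkan.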

backends/vulkan/partitioner/supported_ops.py

Lines changed: 1 addition & 0 deletions
@@ -83,6 +83,7 @@ def __contains__(self, op):
     exir_ops.edge.aten.mm.default,
     exir_ops.edge.aten.addmm.default,
     exir_ops.edge.aten.linear.default,
+    exir_ops.edge.et_vk.linear_weight_int4.default,
     # Reduction
     exir_ops.edge.aten._log_softmax.default,
     exir_ops.edge.aten._softmax.default,
