Skip to content

Commit 2f4b704

Browse files
authored
[ET-VK] Enable IntxWeightOnlyConfig (#13466)
## Motivation Be able to test Vulkan lowering via optimum-executorch. ## Context Very similar to the below PR, Int4 weight only quantization is currently enabled in Vulkan via a custom source transform quantizer that replaces linear layers with a custom linear layer that calls a custom weight only quantized linear op. This diff aims to make it so that no Vulkan specific source transforms need to be applied by adding a fusion pattern for weight only quantized linear. ## Changes * Introduce a fusable graph pattern for weight only quantized linear * Add fusion logic for weight only quantized linear in the fuse patterns pass * Add `4w` qmode to the export llama script Differential Revision: [D80293302](https://our.internmc.facebook.com/intern/diff/D80293302/) [ghstack-poisoned]
1 parent cf669e3 commit 2f4b704

File tree

7 files changed

+421
-3
lines changed

7 files changed

+421
-3
lines changed

backends/vulkan/patterns/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ runtime.python_library(
99
"__init__.py",
1010
"pattern_registry.py",
1111
"rope.py",
12+
"quantized_linear.py",
1213
],
1314
visibility = [
1415
"//executorch/backends/...",

backends/vulkan/patterns/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from typing import List
88

9+
import executorch.backends.vulkan.patterns.quantized_linear # noqa
10+
911
import executorch.backends.vulkan.patterns.rope # noqa
1012

1113
import torch
Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from functools import lru_cache
8+
from typing import Callable, List, Optional
9+
10+
import executorch.backends.vulkan.utils as utils
11+
12+
import torch
13+
import torch.nn.functional as F
14+
15+
from executorch.backends.transforms.utils import get_param_tensor, is_param_node
16+
17+
from executorch.backends.vulkan.patterns.pattern_registry import (
18+
register_pattern_graph,
19+
register_pattern_replacement,
20+
)
21+
22+
from executorch.exir import EdgeCompileConfig, ExportedProgram, to_edge
23+
from executorch.exir.dialects._ops import ops as exir_ops
24+
25+
from torch.export import export
26+
from torch.fx.passes.utils.matcher_utils import InternalMatch
27+
28+
from torchao.quantization.granularity import PerGroup
29+
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
30+
from torchao.utils import unwrap_tensor_subclass
31+
32+
33+
class TorchAOWeightOnlyQuantizedLinearPattern(torch.nn.Module):
    """
    Reference module for the quantized linear pattern produced when quantizing
    linear layers using `torchao.quantization.quant_api.quantize_()` with
    IntxWeightOnlyConfig.

    The module holds a single `torch.nn.Linear`. Calling `apply_quantization()`
    quantizes it in place so that the module can subsequently be exported.
    """

    def __init__(
        self,
        in_features: int = 512,
        out_features: int = 256,
        bias: bool = False,
        group_size: int = 64,
        weight_bits: int = 4,
        granularity_class: Optional[Callable] = None,
    ) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, bias=bias)
        self.group_size = group_size
        self.weight_bits = weight_bits

        # Only a bit width of exactly 4 selects int4; every other value falls
        # back to int8.
        # pyre-ignore[16]
        self.weight_dtype = torch.int4 if weight_bits == 4 else torch.int8

        # The granularity object is built from the group size; PerGroup is used
        # when the caller does not supply a factory.
        make_granularity = PerGroup if granularity_class is None else granularity_class
        self.quant_granularity = make_granularity(group_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped linear layer to `x`."""
        return self.linear(x)

    def apply_quantization(self):
        """
        Quantize this module in place with an IntxWeightOnlyConfig built from
        `weight_dtype` and `quant_granularity`, unwrap tensor subclasses, and
        return `self`.
        """
        quantize_(
            self,
            IntxWeightOnlyConfig(
                weight_dtype=self.weight_dtype,
                granularity=self.quant_granularity,
            ),
        )
        unwrap_tensor_subclass(self)
        return self
75+
76+
77+
# NOTE(review): lru_cache is applied outside register_pattern_graph, so if the
# registry stores the function it receives, callers going through the registry
# would get the uncached version - confirm against pattern_registry.py.
@lru_cache(maxsize=None)
@register_pattern_graph("torchao_wo_quantized_linear")
def get_torchao_wo_quantized_linear_graphs() -> List[torch.fx.GraphModule]:
    """
    Build the edge-dialect graph modules that serve as match patterns for a
    torchao weight-only quantized linear layer.

    For every configuration below, the reference module is quantized and
    exported once with a 3D input of shape (batch_size, seq_len, in_features)
    and, when batch_size is 1, once more with a 2D input of shape
    (seq_len, in_features).
    """
    # Each entry: (batch_size, seq_len, keyword arguments for the pattern module)
    shape_configs = [
        # gemv pattern
        (
            1,
            1,
            {
                "in_features": 128,
                "out_features": 128,
                "bias": False,
                "group_size": 64,
                "weight_bits": 4,
                "granularity_class": PerGroup,
            },
        ),
        # gemm pattern
        (
            1,
            8,
            {
                "in_features": 128,
                "out_features": 128,
                "bias": False,
                "group_size": 64,
                "weight_bits": 4,
                "granularity_class": PerGroup,
            },
        ),
    ]

    def _export_quantized(sample, module_kwargs):
        # Create the pattern module, quantize it, and lower it to an edge graph
        # module. IR validity checking is disabled for the edge conversion.
        module = TorchAOWeightOnlyQuantizedLinearPattern(**module_kwargs)
        module = module.apply_quantization()
        edge_program = to_edge(
            export(module, (sample,)),
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )
        return edge_program.exported_program().graph_module

    pattern_graphs = []
    for batch_size, seq_len, module_kwargs in shape_configs:
        in_features = module_kwargs["in_features"]
        for dtype in (torch.float32,):
            sample_inputs = [
                torch.randn(batch_size, seq_len, in_features, dtype=dtype)
            ]
            if batch_size == 1:
                sample_inputs.append(torch.randn(seq_len, in_features, dtype=dtype))

            for sample in sample_inputs:
                pattern_graphs.append(_export_quantized(sample, module_kwargs))

    return pattern_graphs
132+
133+
134+
def pack_4bit_weight_tensor(inp: torch.Tensor) -> torch.Tensor:
    """
    Given an 8-bit weight tensor containing values quantized to 4 bits, create a
    packed weight tensor by packing 2 4-bit values in one unsigned 8-bit value.

    An input weight tensor of shape (M, K) will produce a packed weight tensor of
    shape (M, K / 2). If K is odd, one column of zeros is appended first, so the
    packed width is (K + 1) / 2. Within each output byte the even-indexed column
    occupies the high nibble and the odd-indexed column the low nibble; each value
    is stored with an offset of +8 so that [-8, 7] maps to [0, 15].

    The packing implemented here is the same as the packing produced by
    backends/vulkan/_passes/int4_weight_only_quantizer.py

    Raises:
        AssertionError: if any value lies outside [-8, 7], or if the tensor is
            not 2D after squeezing singleton dimensions.
    """
    # Assert we got a properly quantized tensor. An empty tensor has no min/max
    # (and nothing that could be out of range), so the check is skipped for it.
    if inp.numel() > 0:
        min_val, max_val = inp.min().item(), inp.max().item()
        assert (
            max_val <= 7 and min_val >= -8
        ), f"pack_4bit_weight_tensor: [min,max] out of [-8, 7] range, got [{min_val}, {max_val}]"

    # Assuming we have a 2d tensor
    if inp.ndim != 2:
        inp = inp.squeeze()
    assert (
        inp.ndim == 2
    ), f"pack_4bit_weight_tensor: expecting input tensor to be 2d, got {inp.ndim}"

    # Pad the input-channel dim to an even width so values pair up cleanly.
    if inp.shape[-1] % 2 != 0:
        inp = F.pad(input=inp, pad=(0, 1, 0, 0), mode="constant", value=0)

    # Apply the zero-point offset BEFORE converting to uint8 so that negative
    # values never pass through an unsigned conversion.
    shifted = (inp + 8).to(dtype=torch.uint8)

    # Pack each pair of 4-bit values into a single 8-bit value.
    return (shifted[:, ::2] << 4) | shifted[:, 1::2]
171+
172+
173+
def make_combined_scales_and_zeros_tensor(
    scales: torch.Tensor, zeros: torch.Tensor
) -> torch.Tensor:
    """
    Combine a scales tensor and a zeros tensor into a single tensor.

    The scales and zeros tensors are expected to be 2D tensors of shape
    (OUTPUT_CHANNELS, NUM_GROUPS). The combined tensor will have the shape
    (NUM_GROUPS, OUTPUT_CHANNELS, 2): index 0 of the last dim holds the scale and
    index 1 holds the zero point multiplied by the scale and negated.

    This is the scales and zeros format produced by
    backends/vulkan/_passes/int4_weight_only_quantizer.py, which in turn is the
    scales and zeros format expected by the _weight_int4pack_mm op in ATen.
    """
    # Swap the first two dims, then add a new axis at position 2 to concatenate on.
    scales_col = torch.unsqueeze(torch.transpose(scales, 0, 1), 2)
    zeros_col = torch.unsqueeze(torch.transpose(zeros, 0, 1), 2)

    # The second slot stores -(zero * scale) rather than the raw zero point.
    neg_scaled_zeros = zeros_col * scales_col * -1

    return torch.cat([scales_col, neg_scaled_zeros], dim=2)
193+
194+
195+
def identify_wo_quantized_linear_io_nodes(  # noqa: C901
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    match: InternalMatch,
) -> Optional[List[torch.fx.Node]]:
    """
    Extract the input/output nodes of a matched weight-only quantized linear
    pattern.

    Returns a list of five nodes in the order
    [input tensor, quantized weight, scales, zeros, output tensor], or None if
    the match does not have the expected structure: no dequant node, a
    weight/scales/zeros argument that is not a parameter node, a placeholder
    count other than 4, no placeholder outside the dequant node's arguments, or
    a returning-node count other than 1.
    """
    # First, find the dequant node among the matched graph nodes.
    dequant_node = next(
        (n for n in match.nodes_map.values() if utils.is_dequant_node(n)),
        None,
    )
    if dequant_node is None:
        return None

    # The weight, scales and zeros are read from fixed positions in the dequant
    # node's args.
    quantized_weight = dequant_node.args[0]
    quant_scales = dequant_node.args[2]
    quant_zeros = dequant_node.args[3]

    # Each of the three must be an fx node that is a parameter of the program.
    if not (
        isinstance(quantized_weight, torch.fx.Node)
        and is_param_node(ep, quantized_weight)
    ):
        return None
    if not (
        isinstance(quant_scales, torch.fx.Node) and is_param_node(ep, quant_scales)
    ):
        return None
    if not (isinstance(quant_zeros, torch.fx.Node) and is_param_node(ep, quant_zeros)):
        return None

    placeholders = match.placeholder_nodes
    if len(placeholders) != 4:
        return None

    # The activation input is the first placeholder that is not consumed by the
    # dequant node.
    in_tensor_node = next(
        (p for p in placeholders if p not in dequant_node.args),
        None,
    )
    if in_tensor_node is None:
        return None

    returning = match.returning_nodes
    if len(returning) != 1:
        return None

    out_tensor_node = returning[0]
    if not isinstance(out_tensor_node, torch.fx.Node):
        return None

    return [
        in_tensor_node,
        quantized_weight,
        quant_scales,
        quant_zeros,
        out_tensor_node,
    ]
254+
255+
256+
# wo = "weight only"
257+
@register_pattern_replacement("torchao_wo_quantized_linear")
def create_wo_quantized_linear_custom_op(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    match: InternalMatch,
):
    """
    Replace a matched weight-only quantized linear pattern with a single
    `et_vk.linear_weight_int4` node.

    The quantized weight parameter is repacked with `pack_4bit_weight_tensor`
    and written back to the program state dict under the same name. The scales
    and zeros are merged with `make_combined_scales_and_zeros_tensor` and
    registered on `graph_module` as a new parameter named
    "<weight node name>_scales_zeros". All uses of the pattern's output node are
    then redirected to the new node.

    Returns without changing anything if the match's I/O nodes cannot be
    identified or the weight parameter does not resolve to a tensor.
    """
    io_nodes = identify_wo_quantized_linear_io_nodes(ep, graph_module, match)
    if io_nodes is None:
        return

    assert len(io_nodes) == 5
    in_tensor, quantized_weight, quant_scales, quant_zeros, out_tensor = io_nodes

    quantized_weight_tensor = get_param_tensor(ep, quantized_weight)
    if not isinstance(quantized_weight_tensor, torch.Tensor):
        return

    # Resolve and validate every constant, and compute all derived tensors,
    # BEFORE mutating the program. A failure here then leaves the exported
    # program untouched instead of holding a repacked weight with no consumer.
    quant_scales_tensor = get_param_tensor(ep, quant_scales)
    quant_zeros_tensor = get_param_tensor(ep, quant_zeros)

    assert quant_scales_tensor is not None
    assert quant_zeros_tensor is not None

    # Group size is derived from the unpacked weight width and the number of
    # scale columns.
    group_size = quantized_weight_tensor.shape[1] // quant_scales_tensor.shape[1]

    packed_quantized_weight_tensor = pack_4bit_weight_tensor(quantized_weight_tensor)
    combined_scales_zeros_tensor = make_combined_scales_and_zeros_tensor(
        quant_scales_tensor, quant_zeros_tensor
    )

    # Swap the weight data for its packed form and keep the node's fake-tensor
    # metadata in sync (every other column, uint8).
    utils.update_program_state_dict(
        ep, quantized_weight.name, packed_quantized_weight_tensor
    )
    quantized_weight.meta["val"] = quantized_weight.meta["val"][:, ::2].to(torch.uint8)

    combined_scales_zeros_name = f"{quantized_weight.name}_scales_zeros"
    graph_module.register_parameter(
        combined_scales_zeros_name, torch.nn.Parameter(combined_scales_zeros_tensor)
    )

    with graph_module.graph.inserting_before(out_tensor):
        combined_scales_zeros = graph_module.graph.get_attr(combined_scales_zeros_name)
        # NOTE(review): the trailing constant 1 is passed through as in the
        # original call; its meaning is defined by the op schema (presumably
        # inner_k_tiles - confirm).
        wo_qlinear = graph_module.graph.create_node(
            "call_function",
            exir_ops.edge.et_vk.linear_weight_int4.default,
            args=(in_tensor, quantized_weight, group_size, combined_scales_zeros, 1),
        )

    # Carry the output metadata over so downstream passes see the same value.
    if hasattr(out_tensor, "meta") and "val" in out_tensor.meta:
        wo_qlinear.meta["val"] = out_tensor.meta["val"]

    out_tensor.replace_all_uses_with(wo_qlinear)

0 commit comments

Comments
 (0)