Commit fcd9a21

Update on "[ET-VK] Add custom VkInt4WeightOnlyQuantizer for vulkan"
## Context

This diff adds the `VkInt4WeightOnlyQuantizer` class which enables 4-bit quantization of linear layers via source transformation. This quantizer class is copied from `torchao.quantization.GPTQ.WeightOnlyInt4Linear` with some minor changes as annotated in the implementation.

Note that the pt2e quantization flow does not yet support groupwise quantization, so source transformation is the only way to perform groupwise quantization at the moment.

Differential Revision: [D64406457](https://our.internmc.facebook.com/intern/diff/D64406457/)

[ghstack-poisoned]
2 parents 0d56e9a + d62c427 commit fcd9a21
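
For readers unfamiliar with the term, the groupwise 4-bit quantization mentioned in the commit message can be illustrated with a minimal pure-Python sketch. This is not the code in this commit and does not reproduce torchao's `groupwise_affine_quantize_tensor`; the function names, the min/max scheme, and the list-based layout below are assumptions chosen only to show the idea that each group of `group_size` consecutive weights gets its own scale and offset.

```python
from typing import List, Tuple

Q_MAX = 15  # 4 bits give integer codes 0..15


def quantize_groupwise_int4(
    row: List[float], group_size: int
) -> Tuple[List[int], List[float], List[float]]:
    """Hypothetical helper: asymmetric 4-bit quantization of one weight row.

    Returns the integer codes plus one (scale, minimum) pair per group.
    """
    if group_size <= 0 or len(row) % group_size != 0:
        raise ValueError("row length must be a positive multiple of group_size")

    codes: List[int] = []
    scales: List[float] = []
    mins: List[float] = []

    for start in range(0, len(row), group_size):
        group = row[start:start + group_size]
        lo, hi = min(group), max(group)
        scale = (hi - lo) / Q_MAX
        if scale == 0.0:
            scale = 1.0  # constant group: any scale works, avoid dividing by zero
        scales.append(scale)
        mins.append(lo)
        for w in group:
            q = round((w - lo) / scale)
            codes.append(max(0, min(Q_MAX, q)))  # clamp into the 4-bit range
    return codes, scales, mins


def dequantize_groupwise_int4(
    codes: List[int], scales: List[float], mins: List[float], group_size: int
) -> List[float]:
    """Hypothetical helper: map codes back to floats with each group's parameters."""
    return [
        q * scales[i // group_size] + mins[i // group_size]
        for i, q in enumerate(codes)
    ]
```

Because the scale is computed per group, an outlier in one group does not degrade the resolution of the others, and the reconstruction error of each weight is bounded by about half of its group's scale. A real implementation would additionally pack two 4-bit codes into each byte, which this sketch leaves out.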

2 files changed: +5 -5 lines changed

backends/vulkan/_passes/int4_weight_only_quantizer.py

Lines changed: 4 additions & 0 deletions

@@ -4,6 +4,10 @@
 import torch
 import torch.nn.functional as F
 
+from executorch.backends.vulkan._passes.custom_ops_defs import (  # noqa
+    linear_weight_int4_op,
+)
+
 from torchao.quantization.GPTQ import _check_linear_int4_k
 from torchao.quantization.unified import Quantizer
 from torchao.quantization.utils import groupwise_affine_quantize_tensor

backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp

Lines changed: 1 addition & 5 deletions

@@ -209,11 +209,7 @@ void add_q_4w_linear_node(
   ubos.append(graph.strides_ubo(mat2));
   ubos.append(graph.strides_ubo(scales_and_zeros));
 
-  auto out_sizes = graph.sizes_of(out);
-  uint32_t N = utils::val_at(-1, out_sizes);
-  uint32_t M = utils::val_at(-2, out_sizes);
-
-  utils::uvec3 global_wg_size = {N, M, 1};
+  utils::uvec3 global_wg_size = graph.logical_limits_of(out);
   utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);
 
   graph.execute_nodes().emplace_back(new DispatchNode(
