
Commit a313d0e

Update base for Update on "[Ez] Enable Vulkan 4-bit weight only quantization in export_llama"

As title. Differential Revision: [D64406456](https://our.internmc.facebook.com/intern/diff/D64406456/)

[ghstack-poisoned]

1 parent 55a743b

File tree: 2 files changed (+5, -5 lines)

backends/vulkan/_passes/int4_weight_only_quantizer.py (4 additions, 0 deletions)

```diff
@@ -4,6 +4,10 @@
 import torch
 import torch.nn.functional as F
 
+from executorch.backends.vulkan._passes.custom_ops_defs import (  # noqa
+    linear_weight_int4_op,
+)
+
 from torchao.quantization.GPTQ import _check_linear_int4_k
 from torchao.quantization.unified import Quantizer
 from torchao.quantization.utils import groupwise_affine_quantize_tensor
```
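The quantizer file imports `groupwise_affine_quantize_tensor` from torchao, which quantizes each weight row in fixed-size groups, storing a scale and zero point per group. The sketch below illustrates the general idea only; it is not torchao's implementation, and the function names and the asymmetric [0, 15] mapping are illustrative assumptions:

```python
import torch


def groupwise_quantize_4bit(w: torch.Tensor, group_size: int = 32):
    """Illustrative groupwise affine 4-bit quantization (not torchao's code).

    Each row of `w` is split into groups of `group_size` elements; every
    group gets its own scale and zero point so that its values map onto
    the 16 representable 4-bit levels [0, 15].
    """
    out_features, in_features = w.shape
    assert in_features % group_size == 0
    wg = w.reshape(out_features, in_features // group_size, group_size)
    w_min = wg.amin(dim=-1, keepdim=True)
    w_max = wg.amax(dim=-1, keepdim=True)
    # 4 bits -> 16 levels -> 15 quantization steps across the group's range.
    scales = (w_max - w_min).clamp(min=1e-6) / 15.0
    zeros = w_min
    q = ((wg - zeros) / scales).round().clamp(0, 15).to(torch.uint8)
    return q.reshape(out_features, in_features), scales, zeros


def groupwise_dequantize_4bit(q, scales, zeros, group_size: int = 32):
    """Inverse mapping: recover an approximation of the original weights."""
    out_features, in_features = q.shape
    qg = q.reshape(out_features, in_features // group_size, group_size).float()
    return (qg * scales + zeros).reshape(out_features, in_features)
```

With this scheme the worst-case reconstruction error per element is half a quantization step, i.e. `scale / 2` for that element's group, which is why smaller group sizes trade storage for accuracy.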

backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp (1 addition, 5 deletions)

```diff
@@ -209,11 +209,7 @@ void add_q_4w_linear_node(
   ubos.append(graph.strides_ubo(mat2));
   ubos.append(graph.strides_ubo(scales_and_zeros));
 
-  auto out_sizes = graph.sizes_of(out);
-  uint32_t N = utils::val_at(-1, out_sizes);
-  uint32_t M = utils::val_at(-2, out_sizes);
-
-  utils::uvec3 global_wg_size = {N, M, 1};
+  utils::uvec3 global_wg_size = graph.logical_limits_of(out);
   utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);
 
   graph.execute_nodes().emplace_back(new DispatchNode(
```
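The removed lines built the global workgroup size by hand from the last two output dimensions, dispatching one invocation per element of the M×N output; the replacement asks the graph for the output tensor's logical limits instead, which keeps the dispatch consistent with however the tensor is laid out. A tiny sketch of the old hand-rolled computation (hypothetical helper, Python for illustration):

```python
def global_wg_size_from_out_sizes(out_sizes):
    """Old approach: derive the compute dispatch grid from the output shape.

    N is the innermost dimension (columns), M the next one (rows);
    the grid covers one invocation per element of the M x N output.
    """
    n = out_sizes[-1]
    m = out_sizes[-2]
    return (n, m, 1)
```

Centralizing this in `logical_limits_of(out)` means a single source of truth for the extents, rather than each op recomputing them from `sizes_of`.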

0 commit comments