Commit 1882ea1

[ET-VK] Implement linear_dq8ta_q4gsw (#14098)
Title says it all! Builds upon the support for quantized linear introduced in the previous diffs to enable dynamically quantized linear. Also included in this diff is a cleanup of the glslh files used across the quantized linear implementations. Differential Revision: [D81931060](https://our.internmc.facebook.com/intern/diff/D81931060/)
1 parent d79ef5b commit 1882ea1

File tree

58 files changed, +1828 -1569 lines


.github/workflows/pull.yml

Lines changed: 3 additions & 0 deletions
@@ -933,10 +933,13 @@ jobs:
          PYTHON_EXECUTABLE=python bash backends/vulkan/test/custom_ops/build_and_run.sh add
          ./cmake-out/backends/vulkan/test/custom_ops/q8csw_linear
          ./cmake-out/backends/vulkan/test/custom_ops/q8csw_conv2d
+         ./cmake-out/backends/vulkan/test/custom_ops/q4gsw_linear
+         ./cmake-out/backends/vulkan/test/custom_ops/choose_qparams_per_row

          # Run e2e testing for selected operators. More operators will be tested via this
          # route in the future.
          python -m unittest backends/vulkan/test/test_vulkan_delegate.py -k "*pt2e*"
+         python -m unittest backends/vulkan/test/test_vulkan_delegate.py -k "*torchao*"

  nxp-build-test:
    name: nxp-build-test

backends/vulkan/custom_ops_lib.py

Lines changed: 30 additions & 0 deletions
@@ -293,6 +293,19 @@ def linear_q4gsw(
     return out


+def linear_dq8ca_q4gsw(
+    x: torch.Tensor,
+    input_scale: torch.Tensor,
+    input_zero_point: torch.Tensor,
+    weights: torch.Tensor,
+    weight_sums: torch.Tensor,
+    weight_scales: torch.Tensor,
+    group_size: int,
+    bias: Optional[torch.Tensor] = None,
+):
+    return linear_q4gsw(x, weights, weight_scales, group_size)
+
+
 name = "linear_q4gsw"
 lib.define(
     f"""
@@ -307,6 +320,23 @@ def linear_q4gsw(
 lib.impl(name, linear_q4gsw, "CompositeExplicitAutograd")
 linear_qc4w_op = getattr(getattr(torch.ops, namespace), name)

+name = "linear_dq8ca_q4gsw"
+lib.define(
+    f"""
+    {name}(
+        Tensor input,
+        Tensor input_scales,
+        Tensor input_zp,
+        Tensor weights,
+        Tensor weight_sums,
+        Tensor weight_scales,
+        int group_size,
+        Tensor? bias = None) -> Tensor
+    """
+)
+lib.impl(name, linear_dq8ca_q4gsw, "CompositeExplicitAutograd")
+linear_dq8ca_q4gsw_op = getattr(getattr(torch.ops, namespace), name)
+
 ########################
 ## linear_qta8a_qga4w ##
 ########################

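For intuition, here is a minimal, self-contained sketch of what the newly registered op computes: per-row (per-token) dynamically quantized 8-bit activations multiplied against 4-bit group-wise quantized weights. This is not the code from the diff; the shapes, the unpacked weight layout, and the omission of the activation zero point / weight_sums handling are simplifications chosen for illustration.

import torch

def reference_dq8ca_q4gsw(x, weight_q, weight_scales, group_size):
    # x: [M, K] float activations
    # weight_q: [N, K] weights already in the signed 4-bit range [-8, 7]
    # weight_scales: [N, K // group_size] per-group scales
    N, K = weight_q.shape
    # Dynamic per-row (per-token) symmetric quantization of the activations to int8.
    x_scale = x.abs().amax(dim=-1, keepdim=True) / 127.0
    x_q = torch.clamp(torch.round(x / x_scale), -128, 127)
    # Dequantize the weights group by group.
    w = weight_q.float().view(N, K // group_size, group_size)
    w = (w * weight_scales.unsqueeze(-1)).view(N, K)
    # Dequantize the activations and run the matmul in float.
    return (x_q * x_scale) @ w.t()

# Example usage with assumed shapes: M=4 tokens, K=64 input channels,
# N=32 output channels, 4 groups of 16.
x = torch.randn(4, 64)
weight_q = torch.randint(-8, 8, (32, 64)).float()
weight_scales = torch.rand(32, 4)
out = reference_dq8ca_q4gsw(x, weight_q, weight_scales, group_size=16)
print(out.shape)  # torch.Size([4, 32])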
backends/vulkan/op_registry.py

Lines changed: 19 additions & 3 deletions
@@ -190,8 +190,8 @@ def register_torchao_choose_qparams_affine():
     return OpFeatures(
         inputs_storage=utils.CONTIGUOUS_ANY,
         outputs_storage=[
-            utils.CONTIGUOUS_BUFFER,  # scales
-            utils.CONTIGUOUS_BUFFER,  # zero_points
+            utils.WIDTH_PACKED_TEXTURE,  # scales
+            utils.WIDTH_PACKED_TEXTURE,  # zero_points
         ],
         supports_resize=True,
     )
@@ -341,7 +341,23 @@ def register_quantized_linear_ops():
     return OpFeatures(
         inputs_storage=utils.CONTIGUOUS_ANY,
         supports_prepacking=True,
-        supports_resize=False,
+    )
+
+
+@update_features(exir_ops.edge.et_vk.linear_dq8ca_q4gsw.default)
+def register_linear_dqa_qw_ops():
+    return OpFeatures(
+        inputs_storage=[
+            utils.CONTIGUOUS_ANY,  # input
+            utils.WIDTH_PACKED_TEXTURE,  # input_scale
+            utils.WIDTH_PACKED_TEXTURE,  # input_zero_point
+            utils.NO_STORAGE,  # weight (prepacked)
+            utils.NO_STORAGE,  # weight_sums (prepacked)
+            utils.NO_STORAGE,  # weight_scales (prepacked)
+            utils.NO_STORAGE,  # group_size (scalar)
+            utils.NO_STORAGE,  # bias (prepacked)
+        ],
+        supports_prepacking=True,
     )


backends/vulkan/patterns/quantized_linear.py

Lines changed: 139 additions & 4 deletions
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import operator
+
 from typing import Optional

 import executorch.backends.vulkan.utils as utils
@@ -117,8 +119,19 @@ def __init__(self, mm_node: torch.fx.Node) -> None:
             self.match_found = True
             return

-        self.input_scales_node = self.quantize_input_node.args[1]
-        self.input_zeros_node = self.quantize_input_node.args[2]
+        scales_arg_idx = 1
+        zeros_arg_idx = 2
+
+        # torchao op has a slightly different function schema
+        if (
+            self.quantize_input_node.target
+            == exir_ops.edge.torchao.quantize_affine.default
+        ):
+            scales_arg_idx = 2
+            zeros_arg_idx = 3
+
+        self.input_scales_node = self.quantize_input_node.args[scales_arg_idx]
+        self.input_zeros_node = self.quantize_input_node.args[zeros_arg_idx]

         assert dq_node is not None
         self.all_nodes.extend(
@@ -164,6 +177,27 @@ def is_input_static_per_tensor_quantized(self) -> bool:
         # are scalars.
         return isinstance(self.input_scales_node, float)

+    def is_input_dynamic_perchannel_quantized(self) -> bool:
+        if self.quantize_input_node is None:
+            return False
+
+        if not isinstance(self.input_scales_node, torch.fx.Node):
+            return False
+
+        # For dynamic quantization, input scale node should be a getitem operator
+        # retrieving the output of a choose_qparams op
+        if self.input_scales_node.target != operator.getitem:
+            return False
+
+        # The getitem node should be retrieving from a choose_qparams op
+        if not utils.is_choose_qparams_node(self.input_scales_node.args[0]):
+            return False
+
+        scales_shape = self.input_scales_node.meta["val"].shape
+        input_shape = self.fp_input_node.meta["val"].shape
+
+        return input_shape[-2] == scales_shape[-1]
+

 linear_anchor_nodes = {
     exir_ops.edge.aten.linear.default,
@@ -230,6 +264,34 @@ def pack_4bit_weight_tensor(weight_tensor: torch.Tensor) -> torch.Tensor:
     return weight_tensor[::, 1::2] << 4 | weight_tensor[::, ::2]


+def compute_per_group_sums(weight_tensor: torch.Tensor, group_size: int):
+    """
+    Compute the sum of weights per quantization group.
+
+    Args:
+        weight_tensor (torch.Tensor): Tensor of shape [out_channels, in_channels], dtype int8.
+        group_size (int): Number of input channels per quantization group.
+
+    Returns:
+        torch.Tensor: Tensor of shape [num_groups, out_channels], where num_groups = in_channels // group_size.
+    """
+    out_channels, in_channels = weight_tensor.shape
+    num_groups = in_channels // group_size
+    # Reshape to [out_channels, num_groups, group_size]
+    reshaped = weight_tensor.view(out_channels, num_groups, group_size)
+    # Sum over the group_size dimension to get [out_channels, num_groups]
+    sums = reshaped.sum(dim=2)
+    # Transpose to [num_groups, out_channels]
+    sums = sums.transpose(0, 1).contiguous()
+    # Pad the out_channels dim (dim=1) to be a multiple of 8 if needed
+    out_channels = sums.shape[1]
+    if out_channels % 8 != 0:
+        num_pad = 8 - (out_channels % 8)
+        sums = F.pad(sums, (0, num_pad))
+
+    return sums.to(torch.int32).contiguous()
+
+
 ##
 ## Pattern Replacement
 ##
@@ -281,6 +343,73 @@ def make_linear_q4gsw_op(
     match.output_node.replace_all_uses_with(linear_q4gsw_node)


+def make_linear_dq8ca_q4gsw_op(
+    ep: ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+    match: QuantizedLinearMatch,
+    weight_tensor: torch.Tensor,
+    weight_scales_tensor: torch.Tensor,
+):
+    num_groups = weight_scales_tensor.shape[-1]
+    in_channels = weight_tensor.shape[-1]
+    group_size = in_channels // num_groups
+
+    # Compute per quant group sums before packing the weight tensor
+    sum_per_quant_group = compute_per_group_sums(weight_tensor, group_size)
+
+    weight_tensor = pack_4bit_weight_tensor(weight_tensor)
+    # Use this function for convenience to update the state dict with the packed
+    # weight tensor. Alignment will already have been done in the above function.
+    weight_tensor = utils.align_width_and_update_state_dict(
+        ep, match.weight_node, weight_tensor, align_to=1, force_update=True
+    )
+
+    # Also transpose the weight scales tensor to shape [num_groups, N]
+    weight_scales_tensor = weight_scales_tensor.transpose(0, 1).contiguous()
+    utils.align_width_and_update_state_dict(
+        ep,
+        match.weight_scales_node,
+        weight_scales_tensor,
+        align_to=1,
+        force_update=True,
+    )
+
+    first_graph_node = list(graph_module.graph.nodes)[0]
+    with graph_module.graph.inserting_before(first_graph_node):
+        weight_tensor_name = utils.get_tensor_name(ep, match.weight_node)
+        # Pre-compute the weight sums which are needed to apply activation zero point
+        # when using integer accumulation.
+        sums_name = weight_tensor_name + "_sums"
+        # Sanitize the name
+        sums_name = sums_name.replace(".", "_")
+
+        weight_sums_node = create_constant_placeholder(
+            exp_program=ep,
+            graph=graph_module.graph,
+            kind=InputKind.CONSTANT_TENSOR,
+            name=sums_name,
+            data=sum_per_quant_group,
+        )
+
+    with graph_module.graph.inserting_before(match.output_node):
+        qlinear_node = graph_module.graph.create_node(
+            "call_function",
+            exir_ops.edge.et_vk.linear_dq8ca_q4gsw.default,
+            args=(
+                match.fp_input_node,
+                match.input_scales_node,
+                match.input_zeros_node,
+                match.weight_node,
+                weight_sums_node,
+                match.weight_scales_node,
+                group_size,
+            ),
+        )
+
+    qlinear_node.meta["val"] = match.output_node.meta["val"]
+    match.output_node.replace_all_uses_with(qlinear_node)
+
+
 def make_linear_q8ta_q8csw_custom_op(
     ep: ExportedProgram,
     graph_module: torch.fx.GraphModule,
@@ -354,10 +483,16 @@ def replace_quantized_linear_patterns(
             make_linear_q4gsw_op(
                 ep, graph_module, match, weight_tensor, weight_scales_tensor
             )
+        elif (
+            match.is_input_dynamic_perchannel_quantized()
+            and match.is_weight_pergroup_quantized()
+            and utils.is_in_4bit_range(weight_tensor)
+        ):
+            make_linear_dq8ca_q4gsw_op(
+                ep, graph_module, match, weight_tensor, weight_scales_tensor
+            )
         elif (
             match.is_input_static_per_tensor_quantized()
             and match.is_weight_perchannel_quantized()
         ):
             make_linear_q8ta_q8csw_custom_op(ep, graph_module, match, weight_tensor)
-
-        # No-op for unsupported quant patterns
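A note on the new is_input_dynamic_perchannel_quantized check above: for dynamically quantized activations the scales come from a choose_qparams op, with one scale per token (row) of the matmul input, so the last dim of the scales must line up with the second-to-last dim of the input. A tiny illustration of that shape relation, with assumed shapes and an assumed per-token scale layout:

import torch

x = torch.randn(2, 128, 64)              # [..., seq_len, in_channels]
# choose_qparams-style per-token scales: one scale per row of x (assumed layout).
scales = x.abs().amax(dim=-1) / 127.0    # shape [2, 128]

# The condition checked by the pattern matcher in the diff above.
assert x.shape[-2] == scales.shape[-1]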

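To make the compute_per_group_sums shape convention concrete, here is a small usage sketch. It assumes the function from the diff above is in scope, together with torch.nn.functional imported as F; the shapes are chosen only for illustration.

import torch

# 6 output channels, 8 input channels, groups of 4 -> 2 groups.
weight = torch.randint(-8, 8, (6, 8), dtype=torch.int8)
sums = compute_per_group_sums(weight, group_size=4)

print(sums.shape)  # torch.Size([2, 8]): [num_groups, out_channels padded up to a multiple of 8]
print(sums.dtype)  # torch.int32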
backends/vulkan/runtime/graph/ops/DispatchNode.cpp

Lines changed: 7 additions & 0 deletions
@@ -44,6 +44,13 @@ void DispatchNode::encode(ComputeGraph* graph) {
   if (!shader_) {
     return;
   }
+
+  // If any global wg size element is 0, then skip encoding this shader
+  if (global_workgroup_size_[0] == 0 || global_workgroup_size_[1] == 0 ||
+      global_workgroup_size_[2] == 0) {
+    return;
+  }
+
   api::Context* const context = graph->context();
   vkapi::PipelineBarrier pipeline_barrier{};

backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl

Lines changed: 3 additions & 0 deletions
@@ -136,6 +136,9 @@ vec4 matmul_naive_k_dim_packed_row_dim_packed(const ivec3 out_lpos) {
     const vec4 mat1_tex = texelFetch(mat1_tensor, mat1_pos, 0);

     for (int r = 0; r < 4; ++r) {
+      if (4 * i + r >= mat2_sizes.y) {
+        continue;
+      }
       // On-demand construction of mat2_pos appears to provide the lowest
       // latency. Surprisingly, this doesn't translate to mat1_pos.
       ivec3 mat2_pos = ivec3(0);
