
Commit 85e2e27

Author: ssjia (committed)
[ET-VK] Implement linear_dq8ca_q4gsw
Pull Request resolved: #14068

Title says it all! This builds upon the support for quantized linear introduced in the previous diffs to enable dynamically quantized linear. Also included in this diff is a cleanup of the glslh files used across the quantized linear implementations.

ghstack-source-id: 308270099
@exported-using-ghexport
Differential Revision: D81931060 (https://our.internmc.facebook.com/intern/diff/D81931060/)
1 parent 6c12956 commit 85e2e27
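Reading the operator name as it is used throughout the diff, linear_dq8ca_q4gsw is a linear layer whose activations are dynamically quantized to 8 bits with one scale and zero point per token, and whose weights are 4-bit group-wise symmetric quantized. The following is only an illustrative floating-point reference of that computation under standard affine-quantization conventions; it is not code from this commit, the helper name, shapes, and quantization ranges are assumptions, and the real operator runs as a Vulkan compute shader that consumes scales and zero points produced by a choose_qparams op.

import torch

def ref_linear_dq8ca_q4gsw(x, w_q, w_scales, group_size):
    # Illustrative reference only (assumed semantics, not the Vulkan kernel).
    # x:        [M, K] float activations
    # w_q:      [N, K] integer weights already restricted to the signed 4-bit range
    # w_scales: [N, num_groups] per-group weight scales (symmetric, zero point 0)
    M, K = x.shape
    N = w_q.shape[0]
    num_groups = K // group_size

    # Dynamically quantize the input per row (per token) to asymmetric int8.
    x_min = x.amin(dim=-1, keepdim=True)
    x_max = x.amax(dim=-1, keepdim=True)
    x_scale = (x_max - x_min).clamp(min=1e-6) / 255.0
    x_zp = (-128 - x_min / x_scale).round()
    x_q = (x / x_scale + x_zp).round().clamp(-128, 127)

    # Per-group weight sums with shape [num_groups, N]; precomputing these is
    # what lets the activation zero point be folded into an integer accumulation.
    w_sums = w_q.view(N, num_groups, group_size).sum(dim=2).t().float()

    out = torch.zeros(M, N)
    for g in range(num_groups):
        ks = slice(g * group_size, (g + 1) * group_size)
        int_acc = x_q[:, ks] @ w_q[:, ks].t().float()  # integer-valued accumulation
        corrected = int_acc - x_zp * w_sums[g]         # subtract zp * sum(w) per group
        out += corrected * w_scales[:, g]              # apply the per-group weight scale
    return out * x_scale                               # apply the per-row activation scale

# Example: 4 tokens, K = 64 input channels, N = 16 outputs, groups of 16.
x = torch.randn(4, 64)
w_q = torch.randint(-8, 8, (16, 64), dtype=torch.int8)
w_scales = torch.rand(16, 4) * 0.1
y = ref_linear_dq8ca_q4gsw(x, w_q, w_scales, group_size=16)  # approximates x @ dequantized_w.T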


58 files changed: +1824 / -1569 lines

(Only a subset of the changed files is shown below; large commits have some content hidden by default.)

.github/workflows/pull.yml (2 additions & 0 deletions)

@@ -933,6 +933,8 @@ jobs:
       PYTHON_EXECUTABLE=python bash backends/vulkan/test/custom_ops/build_and_run.sh add
       ./cmake-out/backends/vulkan/test/custom_ops/q8csw_linear
       ./cmake-out/backends/vulkan/test/custom_ops/q8csw_conv2d
+      ./cmake-out/backends/vulkan/test/custom_ops/q4gsw_linear
+      ./cmake-out/backends/vulkan/test/custom_ops/choose_qparams_per_row

       # Run e2e testing for selected operators. More operators will be tested via this
       # route in the future.

backends/vulkan/custom_ops_lib.py (30 additions & 0 deletions)

@@ -293,6 +293,19 @@ def linear_q4gsw(
     return out


+def linear_dq8ca_q4gsw(
+    x: torch.Tensor,
+    input_scale: torch.Tensor,
+    input_zero_point: torch.Tensor,
+    weights: torch.Tensor,
+    weight_sums: torch.Tensor,
+    weight_scales: torch.Tensor,
+    group_size: int,
+    bias: Optional[torch.Tensor] = None,
+):
+    return linear_q4gsw(x, weights, weight_scales, group_size)
+
+
 name = "linear_q4gsw"
 lib.define(
     f"""
@@ -307,6 +320,23 @@ def linear_q4gsw(
 lib.impl(name, linear_q4gsw, "CompositeExplicitAutograd")
 linear_qc4w_op = getattr(getattr(torch.ops, namespace), name)

+name = "linear_dq8ca_q4gsw"
+lib.define(
+    f"""
+    {name}(
+        Tensor input,
+        Tensor input_scales,
+        Tensor input_zp,
+        Tensor weights,
+        Tensor weight_sums,
+        Tensor weight_scales,
+        int group_size,
+        Tensor? bias = None) -> Tensor
+    """
+)
+lib.impl(name, linear_dq8ca_q4gsw, "CompositeExplicitAutograd")
+linear_dq8ca_q4gsw_op = getattr(getattr(torch.ops, namespace), name)
+
 ########################
 ## linear_qta8a_qga4w ##
 ########################
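As a quick sanity check on the registration above, the new operator should be resolvable through torch.ops once this module has been imported, and its schema should match the lib.define string. A minimal sketch, assuming the library registers under the et_vk namespace (consistent with the exir_ops.edge.et_vk references elsewhere in this commit):

import torch
import executorch.backends.vulkan.custom_ops_lib  # noqa: F401, importing registers the ops

# Resolve the operator and print its schema. The CompositeExplicitAutograd
# reference registered above simply falls back to linear_q4gsw, ignoring the
# dynamic input quantization parameters.
op = torch.ops.et_vk.linear_dq8ca_q4gsw.default
print(op._schema)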

backends/vulkan/op_registry.py (19 additions & 3 deletions)

@@ -190,8 +190,8 @@ def register_torchao_choose_qparams_affine():
     return OpFeatures(
         inputs_storage=utils.CONTIGUOUS_ANY,
         outputs_storage=[
-            utils.CONTIGUOUS_BUFFER,  # scales
-            utils.CONTIGUOUS_BUFFER,  # zero_points
+            utils.WIDTH_PACKED_TEXTURE,  # scales
+            utils.WIDTH_PACKED_TEXTURE,  # zero_points
         ],
         supports_resize=True,
     )
@@ -341,7 +341,23 @@ def register_quantized_linear_ops():
     return OpFeatures(
         inputs_storage=utils.CONTIGUOUS_ANY,
         supports_prepacking=True,
-        supports_resize=False,
+    )
+
+
+@update_features(exir_ops.edge.et_vk.linear_dq8ca_q4gsw.default)
+def register_linear_dqa_qw_ops():
+    return OpFeatures(
+        inputs_storage=[
+            utils.CONTIGUOUS_ANY,  # input
+            utils.WIDTH_PACKED_TEXTURE,  # input_scale
+            utils.WIDTH_PACKED_TEXTURE,  # input_zero_point
+            utils.NO_STORAGE,  # weight (prepacked)
+            utils.NO_STORAGE,  # weight_sums (prepacked)
+            utils.NO_STORAGE,  # weight_scales (prepacked)
+            utils.NO_STORAGE,  # group_size (scalar)
+            utils.NO_STORAGE,  # bias (prepacked)
+        ],
+        supports_prepacking=True,
     )

backends/vulkan/patterns/quantized_linear.py (136 additions & 4 deletions)

@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import operator
+
 from typing import Optional

 import executorch.backends.vulkan.utils as utils
@@ -117,8 +119,19 @@ def __init__(self, mm_node: torch.fx.Node) -> None:
             self.match_found = True
             return

-        self.input_scales_node = self.quantize_input_node.args[1]
-        self.input_zeros_node = self.quantize_input_node.args[2]
+        scales_arg_idx = 1
+        zeros_arg_idx = 2
+
+        # torchao op has a slightly different function schema
+        if (
+            self.quantize_input_node.target
+            == exir_ops.edge.torchao.quantize_affine.default
+        ):
+            scales_arg_idx = 2
+            zeros_arg_idx = 3
+
+        self.input_scales_node = self.quantize_input_node.args[scales_arg_idx]
+        self.input_zeros_node = self.quantize_input_node.args[zeros_arg_idx]

         assert dq_node is not None
         self.all_nodes.extend(
@@ -164,6 +177,24 @@ def is_input_static_per_tensor_quantized(self) -> bool:
         # are scalars.
         return isinstance(self.input_scales_node, float)

+    def is_input_dynamic_perchannel_quantized(self) -> bool:
+        if self.quantize_input_node is None:
+            return False
+
+        # For dynamic quantization, input scale node should be a getitem operator
+        # retrieving the output of a choose_qparams op
+        if self.input_scales_node.target != operator.getitem:
+            return False
+
+        # The getitem node should be retrieving from a choose_qparams op
+        if not utils.is_choose_qparams_node(self.input_scales_node.args[0]):
+            return False
+
+        scales_shape = self.input_scales_node.meta["val"].shape
+        input_shape = self.fp_input_node.meta["val"].shape
+
+        return input_shape[-2] == scales_shape[-1]
+

 linear_anchor_nodes = {
     exir_ops.edge.aten.linear.default,
@@ -230,6 +261,34 @@ def pack_4bit_weight_tensor(weight_tensor: torch.Tensor) -> torch.Tensor:
     return weight_tensor[::, 1::2] << 4 | weight_tensor[::, ::2]


+def compute_per_group_sums(weight_tensor: torch.Tensor, group_size: int):
+    """
+    Compute the sum of weights per quantization group.
+
+    Args:
+        weight_tensor (torch.Tensor): Tensor of shape [out_channels, in_channels], dtype int8.
+        group_size (int): Number of input channels per quantization group.
+
+    Returns:
+        torch.Tensor: Tensor of shape [num_groups, out_channels], where num_groups = in_channels // group_size.
+    """
+    out_channels, in_channels = weight_tensor.shape
+    num_groups = in_channels // group_size
+    # Reshape to [out_channels, num_groups, group_size]
+    reshaped = weight_tensor.view(out_channels, num_groups, group_size)
+    # Sum over group_size dimension to get [out_channels, num_groups]
+    sums = reshaped.sum(dim=2)
+    # Transpose to [num_groups, out_channels]
+    sums = sums.transpose(0, 1).contiguous()
+    # Pad out_channels dim (dim=1) to be a multiple of 8 if needed
+    out_channels = sums.shape[1]
+    if out_channels % 8 != 0:
+        num_pad = 8 - (out_channels % 8)
+        sums = F.pad(sums, (0, num_pad))
+
+    return sums.to(torch.int32).contiguous()
+
+
 ##
 ## Pattern Replacement
 ##
@@ -281,6 +340,73 @@ def make_linear_q4gsw_op(
     match.output_node.replace_all_uses_with(linear_q4gsw_node)


+def make_linear_dq8ca_q4gsw_op(
+    ep: ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+    match: QuantizedLinearMatch,
+    weight_tensor: torch.Tensor,
+    weight_scales_tensor: torch.Tensor,
+):
+    num_groups = weight_scales_tensor.shape[-1]
+    in_channels = weight_tensor.shape[-1]
+    group_size = in_channels // num_groups
+
+    # Compute per quant group sums before packing the weight tensor
+    sum_per_quant_group = compute_per_group_sums(weight_tensor, group_size)
+
+    weight_tensor = pack_4bit_weight_tensor(weight_tensor)
+    # Use this function for convenience to update the state dict with the packed
+    # weight tensor. Alignment will already have been done in the above function.
+    weight_tensor = utils.align_width_and_update_state_dict(
+        ep, match.weight_node, weight_tensor, align_to=1, force_update=True
+    )
+
+    # Also transpose the weight scales tensor to shape [num_groups, N]
+    weight_scales_tensor = weight_scales_tensor.transpose(0, 1).contiguous()
+    utils.align_width_and_update_state_dict(
+        ep,
+        match.weight_scales_node,
+        weight_scales_tensor,
+        align_to=1,
+        force_update=True,
+    )
+
+    first_graph_node = list(graph_module.graph.nodes)[0]
+    with graph_module.graph.inserting_before(first_graph_node):
+        weight_tensor_name = utils.get_tensor_name(ep, match.weight_node)
+        # Pre-compute the weight sums which are needed to apply activation zero point
+        # when using integer accumulation.
+        sums_name = weight_tensor_name + "_sums"
+        # Sanitize the name
+        sums_name = sums_name.replace(".", "_")
+
+        weight_sums_node = create_constant_placeholder(
+            exp_program=ep,
+            graph=graph_module.graph,
+            kind=InputKind.CONSTANT_TENSOR,
+            name=sums_name,
+            data=sum_per_quant_group,
+        )
+
+    with graph_module.graph.inserting_before(match.output_node):
+        qlinear_node = graph_module.graph.create_node(
+            "call_function",
+            exir_ops.edge.et_vk.linear_dq8ca_q4gsw.default,
+            args=(
+                match.fp_input_node,
+                match.input_scales_node,
+                match.input_zeros_node,
+                match.weight_node,
+                weight_sums_node,
+                match.weight_scales_node,
+                group_size,
+            ),
+        )
+
+        qlinear_node.meta["val"] = match.output_node.meta["val"]
+        match.output_node.replace_all_uses_with(qlinear_node)
+
+
 def make_linear_q8ta_q8csw_custom_op(
     ep: ExportedProgram,
     graph_module: torch.fx.GraphModule,
@@ -354,10 +480,16 @@ def replace_quantized_linear_patterns(
         make_linear_q4gsw_op(
             ep, graph_module, match, weight_tensor, weight_scales_tensor
         )
+    elif (
+        match.is_input_dynamic_perchannel_quantized()
+        and match.is_weight_pergroup_quantized()
+        and utils.is_in_4bit_range(weight_tensor)
+    ):
+        make_linear_dq8ca_q4gsw_op(
+            ep, graph_module, match, weight_tensor, weight_scales_tensor
+        )
     elif (
         match.is_input_static_per_tensor_quantized()
         and match.is_weight_perchannel_quantized()
     ):
         make_linear_q8ta_q8csw_custom_op(ep, graph_module, match, weight_tensor)
-
-    # No-op for unsupported quant patterns
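To make the layout of the precomputed weight sums concrete, here is a tiny worked example of compute_per_group_sums, assuming the file above is importable as executorch.backends.vulkan.patterns.quantized_linear. The only non-obvious step is that the output-channel dimension is zero-padded up to a multiple of 8.

import torch
from executorch.backends.vulkan.patterns.quantized_linear import compute_per_group_sums

# 2 output channels, 4 input channels, group_size=2 -> 2 quantization groups.
w = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8]], dtype=torch.int8)

sums = compute_per_group_sums(w, group_size=2)
# Per-group sums are [[3, 11], [7, 15]], laid out as [num_groups, out_channels],
# and the out_channels dim is then zero-padded from 2 up to 8.
print(sums.shape)   # torch.Size([2, 8])
print(sums[:, :2])  # [[3, 11], [7, 15]], dtype=torch.int32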

backends/vulkan/runtime/graph/ops/DispatchNode.cpp (7 additions & 0 deletions)

@@ -44,6 +44,13 @@ void DispatchNode::encode(ComputeGraph* graph) {
   if (!shader_) {
     return;
   }
+
+  // If any global wg size element is 0, then skip encoding this shader
+  if (global_workgroup_size_[0] == 0 || global_workgroup_size_[1] == 0 ||
+      global_workgroup_size_[2] == 0) {
+    return;
+  }
+
   api::Context* const context = graph->context();
   vkapi::PipelineBarrier pipeline_barrier{};

backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl (3 additions & 0 deletions)

@@ -136,6 +136,9 @@ vec4 matmul_naive_k_dim_packed_row_dim_packed(const ivec3 out_lpos) {
     const vec4 mat1_tex = texelFetch(mat1_tensor, mat1_pos, 0);

     for (int r = 0; r < 4; ++r) {
+      if (4 * i + r >= mat2_sizes.y) {
+        continue;
+      }
       // On-demand construction of mat2_pos appears to provide the lowest
       // latency. Surprisingly, this doesn't translate to mat1_pos.
       ivec3 mat2_pos = ivec3(0);
