2 changes: 2 additions & 0 deletions .github/workflows/pull.yml
@@ -933,6 +933,8 @@ jobs:
PYTHON_EXECUTABLE=python bash backends/vulkan/test/custom_ops/build_and_run.sh add
./cmake-out/backends/vulkan/test/custom_ops/q8csw_linear
./cmake-out/backends/vulkan/test/custom_ops/q8csw_conv2d
./cmake-out/backends/vulkan/test/custom_ops/q4gsw_linear
./cmake-out/backends/vulkan/test/custom_ops/choose_qparams_per_row

# Run e2e testing for selected operators. More operators will be tested via this
# route in the future.
30 changes: 30 additions & 0 deletions backends/vulkan/custom_ops_lib.py
@@ -293,6 +293,19 @@ def linear_q4gsw(
    return out


def linear_dq8ca_q4gsw(
    x: torch.Tensor,
    input_scale: torch.Tensor,
    input_zero_point: torch.Tensor,
    weights: torch.Tensor,
    weight_sums: torch.Tensor,
    weight_scales: torch.Tensor,
    group_size: int,
    bias: Optional[torch.Tensor] = None,
):
    return linear_q4gsw(x, weights, weight_scales, group_size)


name = "linear_q4gsw"
lib.define(
f"""
@@ -307,6 +320,23 @@ def linear_q4gsw(
lib.impl(name, linear_q4gsw, "CompositeExplicitAutograd")
linear_qc4w_op = getattr(getattr(torch.ops, namespace), name)

name = "linear_dq8ca_q4gsw"
lib.define(
f"""
{name}(
Tensor input,
Tensor input_scales,
Tensor input_zp,
Tensor weights,
Tensor weight_sums,
Tensor weight_scales,
int group_size,
Tensor? bias = None) -> Tensor
"""
)
lib.impl(name, linear_dq8ca_q4gsw, "CompositeExplicitAutograd")
linear_dq8ca_q4gsw_op = getattr(getattr(torch.ops, namespace), name)
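# Illustrative sketch, not part of this diff: a plain-PyTorch reference of the
# computation that linear_dq8ca_q4gsw is intended to fuse, assuming "dq8ca"
# denotes dynamically quantized 8-bit per-channel (per-token) activations and
# "q4gsw" denotes 4-bit group-symmetric weights. Note that the composite impl
# registered above simply falls back to linear_q4gsw and ignores the activation
# quantization params; this hypothetical helper spells out the intended math,
# using the torch import already present in this file.
def _reference_linear_dq8ca_q4gsw(
    x,  # [M, K] float activations
    input_scale,  # [M] per-row activation scales
    input_zero_point,  # [M] per-row activation zero points
    weights_int,  # [N, K] weights in int4 range (unpacked, int8 storage)
    weight_scales,  # [N, num_groups] per-group weight scales
    group_size,
):
    # Simulated dynamic quantization of the activations (quantize then dequantize)
    x_q = torch.clamp(
        torch.round(x / input_scale[:, None]) + input_zero_point[:, None], -128, 127
    )
    x_dq = (x_q - input_zero_point[:, None]) * input_scale[:, None]
    # Group-wise symmetric dequantization of the weights
    N, K = weights_int.shape
    num_groups = K // group_size
    w = weights_int.view(N, num_groups, group_size).float()
    w = w * weight_scales.view(N, num_groups, 1)
    return x_dq @ w.view(N, K).t()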

########################
## linear_qta8a_qga4w ##
########################
22 changes: 19 additions & 3 deletions backends/vulkan/op_registry.py
@@ -190,8 +190,8 @@ def register_torchao_choose_qparams_affine():
    return OpFeatures(
        inputs_storage=utils.CONTIGUOUS_ANY,
        outputs_storage=[
-            utils.CONTIGUOUS_BUFFER, # scales
-            utils.CONTIGUOUS_BUFFER, # zero_points
+            utils.WIDTH_PACKED_TEXTURE, # scales
+            utils.WIDTH_PACKED_TEXTURE, # zero_points
        ],
        supports_resize=True,
    )
@@ -341,7 +341,23 @@ def register_quantized_linear_ops():
    return OpFeatures(
        inputs_storage=utils.CONTIGUOUS_ANY,
        supports_prepacking=True,
        supports_resize=False,
    )


@update_features(exir_ops.edge.et_vk.linear_dq8ca_q4gsw.default)
def register_linear_dqa_qw_ops():
    return OpFeatures(
        inputs_storage=[
            utils.CONTIGUOUS_ANY, # input
            utils.WIDTH_PACKED_TEXTURE, # input_scale
            utils.WIDTH_PACKED_TEXTURE, # input_zero_point
            utils.NO_STORAGE, # weight (prepacked)
            utils.NO_STORAGE, # weight_sums (prepacked)
            utils.NO_STORAGE, # weight_scales (prepacked)
            utils.NO_STORAGE, # group_size (scalar)
            utils.NO_STORAGE, # bias (prepacked)
        ],
        supports_prepacking=True,
    )


143 changes: 139 additions & 4 deletions backends/vulkan/patterns/quantized_linear.py
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator

from typing import Optional

import executorch.backends.vulkan.utils as utils
@@ -117,8 +119,19 @@ def __init__(self, mm_node: torch.fx.Node) -> None:
            self.match_found = True
            return

-        self.input_scales_node = self.quantize_input_node.args[1]
-        self.input_zeros_node = self.quantize_input_node.args[2]
+        scales_arg_idx = 1
+        zeros_arg_idx = 2
+
+        # torchao op has a slightly different function schema
+        if (
+            self.quantize_input_node.target
+            == exir_ops.edge.torchao.quantize_affine.default
+        ):
+            scales_arg_idx = 2
+            zeros_arg_idx = 3
+
+        self.input_scales_node = self.quantize_input_node.args[scales_arg_idx]
+        self.input_zeros_node = self.quantize_input_node.args[zeros_arg_idx]

        assert dq_node is not None
        self.all_nodes.extend(
@@ -164,6 +177,27 @@ def is_input_static_per_tensor_quantized(self) -> bool:
        # are scalars.
        return isinstance(self.input_scales_node, float)

    def is_input_dynamic_perchannel_quantized(self) -> bool:
        if self.quantize_input_node is None:
            return False

        if not isinstance(self.input_scales_node, torch.fx.Node):
            return False

        # For dynamic quantization, input scale node should be a getitem operator
        # retrieving the output of a choose_qparams op
        if self.input_scales_node.target != operator.getitem:
            return False

        # The getitem node should be retrieving from a choose_qparams op
        if not utils.is_choose_qparams_node(self.input_scales_node.args[0]):
            return False

        scales_shape = self.input_scales_node.meta["val"].shape
        input_shape = self.fp_input_node.meta["val"].shape

        return input_shape[-2] == scales_shape[-1]
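    # Illustrative note, not part of this diff: the shape check above matches the
    # per-token dynamic quantization layout, where choose_qparams emits one scale
    # and one zero point per row of the input. For example (hypothetical shapes):
    #   input_shape  = torch.Size([1, 128, 4096])
    #   scales_shape = torch.Size([1, 128])
    #   input_shape[-2] == scales_shape[-1]  # True: 128 rows, 128 scales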


linear_anchor_nodes = {
    exir_ops.edge.aten.linear.default,
@@ -230,6 +264,34 @@ def pack_4bit_weight_tensor(weight_tensor: torch.Tensor) -> torch.Tensor:
    return weight_tensor[::, 1::2] << 4 | weight_tensor[::, ::2]
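# Illustrative note, not part of this diff: the packing above stores two 4-bit
# values per byte, odd columns in the high nibble and even columns in the low
# nibble. Assuming the (collapsed) body has already shifted values into the
# unsigned 0..15 range, a row [w0, w1, w2, w3] packs to
# [(w1 << 4) | w0, (w3 << 4) | w2]. For example:
#   row = torch.tensor([[1, 2, 3, 4]], dtype=torch.uint8)
#   row[::, 1::2] << 4 | row[::, ::2]  # -> tensor([[33, 67]], dtype=torch.uint8)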


def compute_per_group_sums(weight_tensor: torch.Tensor, group_size: int):
    """
    Compute the sum of weights per quantization group.

    Args:
        weight_tensor (torch.Tensor): Tensor of shape [out_channels, in_channels], dtype int8.
        group_size (int): Number of input channels per quantization group.

    Returns:
        torch.Tensor: Tensor of shape [num_groups, out_channels], where num_groups = in_channels // group_size.
    """
    out_channels, in_channels = weight_tensor.shape
    num_groups = in_channels // group_size
    # Reshape to [out_channels, num_groups, group_size]
    reshaped = weight_tensor.view(out_channels, num_groups, group_size)
    # Sum over group_size dimension to get [out_channels, num_groups]
    sums = reshaped.sum(dim=2)
    # Transpose to [num_groups, out_channels]
    sums = sums.transpose(0, 1).contiguous()
    # Pad out_channels dim (dim=1) to be a multiple of 8 if needed
    out_channels = sums.shape[1]
    if out_channels % 8 != 0:
        num_pad = 8 - (out_channels % 8)
        sums = F.pad(sums, (0, num_pad))

    return sums.to(torch.int32).contiguous()
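# Illustrative note, not part of this diff: these per-group sums feed the usual
# zero-point correction for integer accumulation. Within a group g of size G,
# assuming x_q[k] = round(x[k] / s_x) + zp_x and w[k] = s_w[g] * w_q[k]:
#   sum_k x[k] * w[k] ≈ s_x * s_w[g] * (sum_k x_q[k] * w_q[k] - zp_x * sum_k w_q[k])
# so precomputing sum_k w_q[k] per group lets the shader remove the activation
# zero-point contribution without dequantizing the weights.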


##
## Pattern Replacement
##
@@ -281,6 +343,73 @@ def make_linear_q4gsw_op(
        match.output_node.replace_all_uses_with(linear_q4gsw_node)


def make_linear_dq8ca_q4gsw_op(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    match: QuantizedLinearMatch,
    weight_tensor: torch.Tensor,
    weight_scales_tensor: torch.Tensor,
):
    num_groups = weight_scales_tensor.shape[-1]
    in_channels = weight_tensor.shape[-1]
    group_size = in_channels // num_groups

    # Compute per quant group sums before packing the weight tensor
    sum_per_quant_group = compute_per_group_sums(weight_tensor, group_size)

    weight_tensor = pack_4bit_weight_tensor(weight_tensor)
    # Use this function for convenience to update the state dict with the packed
    # weight tensor. Alignment will already have been done in the above function.
    weight_tensor = utils.align_width_and_update_state_dict(
        ep, match.weight_node, weight_tensor, align_to=1, force_update=True
    )

    # Also transpose the weight scales tensor to shape [num_groups, N]
    weight_scales_tensor = weight_scales_tensor.transpose(0, 1).contiguous()
    utils.align_width_and_update_state_dict(
        ep,
        match.weight_scales_node,
        weight_scales_tensor,
        align_to=1,
        force_update=True,
    )

    first_graph_node = list(graph_module.graph.nodes)[0]
    with graph_module.graph.inserting_before(first_graph_node):
        weight_tensor_name = utils.get_tensor_name(ep, match.weight_node)
        # Pre-compute the weight sums which are needed to apply activation zero point
        # when using integer accumulation.
        sums_name = weight_tensor_name + "_sums"
        # Sanitize the name
        sums_name = sums_name.replace(".", "_")

        weight_sums_node = create_constant_placeholder(
            exp_program=ep,
            graph=graph_module.graph,
            kind=InputKind.CONSTANT_TENSOR,
            name=sums_name,
            data=sum_per_quant_group,
        )

    with graph_module.graph.inserting_before(match.output_node):
        qlinear_node = graph_module.graph.create_node(
            "call_function",
            exir_ops.edge.et_vk.linear_dq8ca_q4gsw.default,
            args=(
                match.fp_input_node,
                match.input_scales_node,
                match.input_zeros_node,
                match.weight_node,
                weight_sums_node,
                match.weight_scales_node,
                group_size,
            ),
        )

        qlinear_node.meta["val"] = match.output_node.meta["val"]
        match.output_node.replace_all_uses_with(qlinear_node)
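# Illustrative note, not part of this diff: roughly, the rewrite above replaces the
# matched dynamic-quant pattern
#   scales, zp = choose_qparams(x); x_q = quantize(x, scales, zp)
#   x_dq = dequantize(x_q, ...);    out = linear(x_dq, dequantized_weight)
# with a single fused node
#   out = et_vk.linear_dq8ca_q4gsw(x, scales, zp, weight, weight_sums,
#                                  weight_scales, group_size)
# where weight_sums is the constant placeholder created above.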


def make_linear_q8ta_q8csw_custom_op(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
@@ -354,10 +483,16 @@ def replace_quantized_linear_patterns(
            make_linear_q4gsw_op(
                ep, graph_module, match, weight_tensor, weight_scales_tensor
            )
        elif (
            match.is_input_dynamic_perchannel_quantized()
            and match.is_weight_pergroup_quantized()
            and utils.is_in_4bit_range(weight_tensor)
        ):
            make_linear_dq8ca_q4gsw_op(
                ep, graph_module, match, weight_tensor, weight_scales_tensor
            )
        elif (
            match.is_input_static_per_tensor_quantized()
            and match.is_weight_perchannel_quantized()
        ):
            make_linear_q8ta_q8csw_custom_op(ep, graph_module, match, weight_tensor)

        # No-op for unsupported quant patterns
7 changes: 7 additions & 0 deletions backends/vulkan/runtime/graph/ops/DispatchNode.cpp
@@ -44,6 +44,13 @@ void DispatchNode::encode(ComputeGraph* graph) {
  if (!shader_) {
    return;
  }

  // If any global wg size element is 0, then skip encoding this shader
  if (global_workgroup_size_[0] == 0 || global_workgroup_size_[1] == 0 ||
      global_workgroup_size_[2] == 0) {
    return;
  }

  api::Context* const context = graph->context();
  vkapi::PipelineBarrier pipeline_barrier{};

@@ -136,6 +136,9 @@ vec4 matmul_naive_k_dim_packed_row_dim_packed(const ivec3 out_lpos) {
const vec4 mat1_tex = texelFetch(mat1_tensor, mat1_pos, 0);

for (int r = 0; r < 4; ++r) {
if (4 * i + r >= mat2_sizes.y) {
continue;
}
// On-demand construction of mat2_pos appears to provide the lowest
// latency. Surprisingly, this doesn't translate to mat1_pos.
ivec3 mat2_pos = ivec3(0);