diff --git a/backends/vulkan/_passes/int4_weight_only_quantizer.py b/backends/vulkan/_passes/int4_weight_only_quantizer.py
index cd05738120d..34ff5937822 100644
--- a/backends/vulkan/_passes/int4_weight_only_quantizer.py
+++ b/backends/vulkan/_passes/int4_weight_only_quantizer.py
@@ -7,11 +7,31 @@
 import torch
 import torch.nn.functional as F
 
-from torchao.quantization.GPTQ.GPTQ import _check_linear_int4_k
 from torchao.quantization.unified import Quantizer
 from torchao.quantization.utils import groupwise_affine_quantize_tensor
 
+# TODO: import _check_linear_int4_k from torchao.quantization.GPTQ.GPTQ
+# once the diff train catches up.
+def _check_linear_int4_k(k, group_size=1, inner_k_tiles=None):
+    """
+    Check if the dimensions are compatible with int4 quantization.
+
+    Args:
+        k: The dimension size to check
+        group_size: The group size for quantization
+        inner_k_tiles: The inner k tiles size
+
+    Returns:
+        bool: Whether the dimensions are compatible
+    """
+    k_divisible_by_group_size = k % group_size == 0
+    if inner_k_tiles is not None:
+        k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0
+        return k_divisible_by_group_size and k_divisible_by_16_times_inner_k_tiles
+    return k_divisible_by_group_size
+
+
 
 # This module is copied from torchao.quantization.GPTQ.WeightOnlyInt4Linear with
 # changes at the annotated lines.
 class VkWeightOnlyInt4Linear(torch.nn.Module):
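
Usage sketch (not part of the diff): a quick sanity check of the
_check_linear_int4_k helper added above. The function body is copied from the
hunk (minus the docstring) so the snippet runs standalone; the shapes used
below (4096, 4000, group_size=128, inner_k_tiles=8) are hypothetical values
chosen for illustration, not taken from this PR.

def _check_linear_int4_k(k, group_size=1, inner_k_tiles=None):
    # k must be divisible by the quantization group size.
    k_divisible_by_group_size = k % group_size == 0
    if inner_k_tiles is not None:
        # With inner-k tiling, k must also be divisible by inner_k_tiles * 16.
        k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0
        return k_divisible_by_group_size and k_divisible_by_16_times_inner_k_tiles
    return k_divisible_by_group_size

# 4096 % 128 == 0 and 4096 % (8 * 16) == 0, so both constraints pass.
assert _check_linear_int4_k(4096, group_size=128, inner_k_tiles=8)

# 4000 % 128 == 32, so the group-size constraint fails.
assert not _check_linear_int4_k(4000, group_size=128)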