From 90a7527d80e700fa87ba78e05e3ae6009b395359 Mon Sep 17 00:00:00 2001
From: Scott Roy
Date: Wed, 4 Jun 2025 11:03:41 -0700
Subject: [PATCH] Temp fix to unblock diff train (#11361)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/11361

Temp fix to unblock diff train

Reviewed By: lucylq

Differential Revision: D75966594
---
 .../_passes/int4_weight_only_quantizer.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/backends/vulkan/_passes/int4_weight_only_quantizer.py b/backends/vulkan/_passes/int4_weight_only_quantizer.py
index cd05738120d..34ff5937822 100644
--- a/backends/vulkan/_passes/int4_weight_only_quantizer.py
+++ b/backends/vulkan/_passes/int4_weight_only_quantizer.py
@@ -7,11 +7,31 @@
 import torch
 import torch.nn.functional as F
 
-from torchao.quantization.GPTQ.GPTQ import _check_linear_int4_k
 from torchao.quantization.unified import Quantizer
 from torchao.quantization.utils import groupwise_affine_quantize_tensor
 
 
+# TODO: import from torchao.quantization.GPTQ.GPTQ import _check_linear_int4_k
+# Once diff train catches up
+def _check_linear_int4_k(k, group_size=1, inner_k_tiles=None):
+    """
+    Check if the dimensions are compatible with int4 quantization.
+
+    Args:
+        k: The dimension size to check
+        group_size: The group size for quantization
+        inner_k_tiles: The inner k tiles size
+
+    Returns:
+        bool: Whether the dimensions are compatible
+    """
+    k_divisible_by_group_size = k % group_size == 0
+    if inner_k_tiles is not None:
+        k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0
+        return k_divisible_by_group_size and k_divisible_by_16_times_inner_k_tiles
+    return k_divisible_by_group_size
+
+
 # This module is copied from torchao.quantization.GPTQ.WeightOnlyInt4Linear with
 # changes at the annotated lines.
 class VkWeightOnlyInt4Linear(torch.nn.Module):