Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion backends/vulkan/_passes/int4_weight_only_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,31 @@
import torch
import torch.nn.functional as F

from torchao.quantization.GPTQ.GPTQ import _check_linear_int4_k
from torchao.quantization.unified import Quantizer
from torchao.quantization.utils import groupwise_affine_quantize_tensor


# TODO: replace this local copy with
# `from torchao.quantization.GPTQ.GPTQ import _check_linear_int4_k`
# once the diff train catches up.
def _check_linear_int4_k(k, group_size=1, inner_k_tiles=None):
"""
Check if the dimensions are compatible with int4 quantization.

Args:
k: The dimension size to check
group_size: The group size for quantization
inner_k_tiles: The inner k tiles size

Returns:
bool: Whether the dimensions are compatible
"""
k_divisible_by_group_size = k % group_size == 0
if inner_k_tiles is not None:
k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0
return k_divisible_by_group_size and k_divisible_by_16_times_inner_k_tiles
return k_divisible_by_group_size


# This module is copied from torchao.quantization.GPTQ.WeightOnlyInt4Linear with
# changes at the annotated lines.
class VkWeightOnlyInt4Linear(torch.nn.Module):
Expand Down
Loading