backends/xnnpack/operators/node_visitor.py (8 changes: 4 additions & 4 deletions)
@@ -232,7 +232,7 @@ def get_per_channel_dtype(
     if quant_params.dtype == torch.int32:
         return XNNDatatype.xnn_datatype_qcint32
     elif quant_params.dtype == torch.int8:
-        if quant_params.is_per_channel_group:
+        if quant_params.per_channel_group:
             # 4-bit per channel group quantized weights
             # No 8-bit support yet
             assert (
@@ -282,7 +282,7 @@ def get_quant_params(
         buffer_idx = len(xnn_graph.constant_data)
         num_scales = scale.numel()

-        if quant_params.is_per_channel_group:
+        if quant_params.per_channel_group:
             scale = scale.to(torch.bfloat16)

         num_bytes = scale.untyped_storage().nbytes()
@@ -300,7 +300,7 @@ def get_quant_params(
             scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT
         )

-        if quant_params.is_per_channel_group:
+        if quant_params.per_channel_group:
             return PerChannelGroupQuant(
                 scale=[],
                 channel_dim=quant_params.axis,
@@ -335,7 +335,7 @@ def _check_per_channel_group_params(
     ) -> None:
         # Make sure things are lining up for the per_channel_group quantization case.
         # This has to be done this late because we don't have clean access to the actual tensor.
-        assert quant_params.is_per_channel_group, "Not per_channel_group quantization"
+        assert quant_params.per_channel_group, "Not per_channel_group quantization"
         # Linear weights will be in [oc, ic], and per_channel quantization must be on axis 0.
         num_groups = cast(torch.Tensor, quant_params.scale).shape[1]
         assert (
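All four node_visitor.py changes are the same rename: serialization now keys off the per_channel_group decision computed in quant_params.py rather than the raw is_per_channel_group condition. For intuition on what the per-channel-group path serializes, a minimal standalone sketch follows. The shapes, the amax-based scale computation, and the division by 7.0 are illustrative assumptions, not ExecuTorch code; the point is the [out_channels, in_channels / group_size] scale layout, plus the bfloat16 cast the diff above applies before embedding scales as constant data.

import torch

# Hypothetical linear weight: 8 output channels, 16 input channels, groups of 4.
oc, ic, group_size = 8, 16, 4
weight = torch.randn(oc, ic)

# One scale per (output channel, group) pair: shape [oc, ic / group_size].
grouped = weight.reshape(oc, ic // group_size, group_size)
scale = grouped.abs().amax(dim=-1) / 7.0  # symmetric int4 range, qmax = 7
assert scale.shape == (oc, ic // group_size)

# Casting scales to bfloat16 halves the serialized constant data, mirroring
# the scale.to(torch.bfloat16) branch in get_quant_params above.
scale_bf16 = scale.to(torch.bfloat16)
assert scale_bf16.untyped_storage().nbytes() * 2 == scale.untyped_storage().nbytes()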
backends/xnnpack/operators/quant_params.py (54 changes: 48 additions & 6 deletions)
@@ -89,19 +89,36 @@ def __init__(
         # Groupwise quantization for weight
         self.per_channel_group = False
         self.group_size = group_size
+
+        tensor = q_input.meta["val"]
+
         if self.group_size > 0:
             assert (
                 self.per_channel is True
             ), "Only per channel quantization supports groupwise quantization"
-            assert (
-                self.axis == 0, "Only axis 0 is supported for per channel groupwise quant"
-            )
             assert (
                 cast(torch.Tensor, scale).ndim == 2
             ), "Scale must be 2D for per channel groupwise quant"
-            self.per_channel_group = True
+            assert group_size > 0, "Group size must be greater than 0"
+            self.is_per_channel_group = self.per_channel and self.group_size > 0
+
+            # Assumed scale shape - [out_channels, in_channels / group_size]
+            input_channels = cast(torch.Tensor, scale).shape[1] * self.group_size
+            # 2D weight tensor shape - [out_channels, in_channels]
+            assert (
+                tensor.shape[1] == input_channels
+            ), "Invalid input channels for groupwise quant"
+            # Prefer per_channel over per_channel_group when group_size == input_channels, for non-int4 cases only.
+            # The int4 case needs more fixes to map qb4w to qc4w; incorrect scales are being passed down to XNNPACK.
+            self.per_channel_group = self.group_size <= input_channels if self.is_qc4w else self.group_size < input_channels
+
+        if not self.per_channel_group:
+            if cast(torch.Tensor, scale).ndim == 2:
+                # TODO: don't reshape scale for per_channel cases
+                assert (
+                    cast(torch.Tensor, scale).shape[1] == 1
+                ), "Invalid scale shape for per channel quantization"
+                scale = cast(torch.Tensor, scale).squeeze(1)

-        if per_channel and not self.is_per_channel_group:
-            tensor = q_input.meta["val"]
+        if per_channel and not self.per_channel_group:
             assert (
                 tensor.shape[self.axis] == cast(torch.Tensor, self.scale).shape[0]
             ), f"Invalid size of per channel quantization scales, axis: {self.axis}, scale size: {self.scale.shape}, tensor shape: {tensor.shape}"
@@ -110,6 +127,31 @@ def __init__(
                 tensor.shape[self.axis] == cast(torch.Tensor, self.zp).shape[0]
             ), f"Invalid size of per channel quantization zero-points, axis: {self.axis}, zp size: {self.zp.shape}, tensor shape: {tensor.shape}"

+    def __str__(self) -> str:
+        """String representation of QuantParams for debugging and logging."""
+        assert isinstance(self.scale, float) or isinstance(self.scale, torch.Tensor)
+        scale_str = f"{self.scale}" if isinstance(self.scale, float) else f"tensor{tuple(self.scale.shape)}"
+        assert isinstance(self.zp, float) or isinstance(self.zp, torch.Tensor)
+        zp_str = f"{self.zp}" if isinstance(self.zp, float) else f"tensor{tuple(self.zp.shape)}"
+
+        return (
+            f"QuantParams("
+            f"per_channel={self.per_channel}, "
+            f"per_channel_group={self.per_channel_group}, "
+            f"scale={scale_str}, "
+            f"zp={zp_str}, "
+            f"axis={self.axis}, "
+            f"dtype={self.dtype}, "
+            f"qmin={self.qmin}, "
+            f"qmax={self.qmax}, "
+            f"is_dynamic={self.is_dynamic}, "
+            f"is_input={self.is_input}, "
+            f"is_output={self.is_output}, "
+            f"group_size={self.group_size}, "
+            f"is_qc4w={self.is_qc4w}"
+            f")"
+        )
+
     def quantize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
         # Do nothing if already quantized by the Quantizer
         if tensor.dtype == self.dtype:
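The added __str__ makes QuantParams printable in logs and assertion messages. A sketch of the resulting format, using stand-in values rather than a real QuantParams instance (constructing one requires a quantized graph node); only a subset of the fields is shown:

import torch

# Stand-ins for the fields __str__ reads; a real instance gets these from the graph.
scale, zp = torch.ones(64, 4), 0.0

scale_str = f"{scale}" if isinstance(scale, float) else f"tensor{tuple(scale.shape)}"
zp_str = f"{zp}" if isinstance(zp, float) else f"tensor{tuple(zp.shape)}"

print(
    f"QuantParams(per_channel=True, per_channel_group=True, "
    f"scale={scale_str}, zp={zp_str}, axis=0, dtype={torch.int8}, ...)"
)
# -> QuantParams(per_channel=True, per_channel_group=True,
#    scale=tensor((64, 4)), zp=0.0, axis=0, dtype=torch.int8, ...)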