67 changes: 20 additions & 47 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -13,7 +13,6 @@
# limitations under the License.

from functools import wraps
- from math import ceil
from typing import Optional

import torch
@@ -28,6 +27,7 @@
from compressed_tensors.quantization.utils import (
calculate_range,
compute_dynamic_scales_and_zp,
+     strategy_cdiv,
)
from torch.nn import Module
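
The newly imported strategy_cdiv replaces the hand-rolled ceil/divisibility logic removed below. A minimal sketch of what such a helper plausibly does, inferred only from its call site (strategy_cdiv(x.size(-1), group_size, args.strategy, strict=True)); the real implementation lives in compressed_tensors.quantization.utils and may differ:

    import math

    def strategy_cdiv_sketch(value: int, divisor: int, strategy, strict: bool = False) -> int:
        # hypothetical stand-in for compressed_tensors' strategy_cdiv
        if strict and value % divisor != 0:
            raise ValueError(
                f"{strategy} quantization requires dimension {value} "
                f"to be divisible by group_size {divisor}"
            )
        # ceiling division: number of groups needed to cover `value` columns
        return math.ceil(value / divisor)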

@@ -257,53 +257,25 @@ def _process_quantization(
global_scale=global_scale,
)
# restore original shape
-         output = x_blocks.transpose(1, 2).reshape(original_shape)
+         x = x_blocks.transpose(1, 2).reshape(original_shape)
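
For context, this block branch appears to store x_blocks as (num_row_blocks, num_col_blocks, block_height, block_width); transposing dims 1 and 2 re-interleaves the blocks so a plain reshape recovers the original matrix. A standalone round-trip sketch, with the block layout assumed rather than shown in this diff:

    import torch

    x = torch.arange(24.0).reshape(4, 6)            # original (rows, cols)
    bh, bw = 2, 3                                    # assumed block shape
    nr, nc = x.shape[0] // bh, x.shape[1] // bw
    # blocked layout: (nr, nc, bh, bw)
    x_blocks = x.reshape(nr, bh, nc, bw).transpose(1, 2)
    # restore original shape, mirroring the line above
    restored = x_blocks.transpose(1, 2).reshape(x.shape)
    assert torch.equal(restored, x)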
elif args.strategy in (
QuantizationStrategy.GROUP,
QuantizationStrategy.TENSOR_GROUP,
):

output_dtype = dtype if dtype is not None else x.dtype
-         output = torch.zeros_like(x).to(output_dtype)
-         columns = output.shape[-1]
-
-         # TODO: make validation step for inputs
-
-         while scale.ndim < 2:
-             # pad scale and zero point dims for slicing
-             scale = scale.unsqueeze(1)
-             zero_point = zero_point.unsqueeze(1) if zero_point is not None else None
-
-         if columns >= group_size:
-             if columns % group_size != 0:
-                 raise ValueError(
-                     "tensor column shape must be divisble "
-                     f"by the given group_size {group_size}"
-                 )
-
-         # support column-order (default) quantization as well as other orderings
-         # such as activation ordering. Below checks if g_idx has been initialized
-         is_column_order = g_idx is None or -1 in g_idx
-         if is_column_order:
-             num_groups = int(ceil(columns / group_size))
-             group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
-
-         else:
-             group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
-             group_sizes = group_sizes[torch.argsort(group_indices)]
-
+         # activation ordering
+         if g_idx is not None:
perm = torch.argsort(g_idx)
x = x.index_select(-1, perm)
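
Here g_idx assigns each column to a quantization group (activation ordering); torch.argsort(g_idx) yields a permutation that makes each group's columns contiguous, so the grouped reshape below lines up. A small illustration with assumed values:

    import torch

    g_idx = torch.tensor([1, 0, 1, 0])        # assumed: column -> group id
    perm = torch.argsort(g_idx)                # e.g. tensor([1, 3, 0, 2])
    x = torch.tensor([[10.0, 11.0, 12.0, 13.0]])
    x_sorted = x.index_select(-1, perm)        # group-0 columns first: [[11., 13., 10., 12.]]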

-         # Maintain all dimensions except the last dim, which is divided by group_size
-         reshaped_dims = (
-             ceil(x.shape[-1] / group_size),
-             group_size,
-         )
+         # reshape into groups
+         num_groups = strategy_cdiv(x.size(-1), group_size, args.strategy, strict=True)
+         reshaped_dims = (num_groups, group_size)
x = x.unflatten(-1, reshaped_dims)
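
After unflatten, the last dimension splits into (num_groups, group_size), so a per-group scale unsqueezed to (..., num_groups, 1) broadcasts across each group in the _quantize call below. A schematic, assuming an int4-style symmetric range:

    import torch

    x = torch.randn(2, 8)                               # (rows, columns)
    group_size = 4
    num_groups = x.size(-1) // group_size               # 2; strict divisibility assumed
    x_g = x.unflatten(-1, (num_groups, group_size))     # (2, 2, 4)
    scale = x_g.abs().amax(dim=-1, keepdim=True) / 7.0  # per-group scale, (2, 2, 1)
    x_q = torch.clamp(torch.round(x_g / scale), -8, 7)  # scale broadcasts over group_size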

if do_quantize:
-             output = _quantize(
+             x = _quantize(
x=x,
scale=scale.unsqueeze(-1),
zero_point=zero_point.unsqueeze(-1) if zero_point is not None else None,
@@ -315,24 +287,25 @@ def _process_quantization(
)

if do_dequantize:
-             input = output if do_quantize else x
-             output = _dequantize(
-                 x_q=input,
+             x = _dequantize(
+                 x_q=x,
scale=scale.unsqueeze(-1),
zero_point=zero_point.unsqueeze(-1) if zero_point is not None else None,
global_scale=global_scale,
)
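
When both do_quantize and do_dequantize are set, the value is quantized and immediately dequantized (fake quantization), as used in QAT-style forward passes. A symmetric-only sketch of that round trip, with zero_point and global_scale omitted for brevity:

    import torch

    def fake_quantize_groups(x_g, scale, q_min=-8.0, q_max=7.0):
        # quantize: scale down, round, clamp to the integer range
        x_q = torch.clamp(torch.round(x_g / scale.unsqueeze(-1)), q_min, q_max)
        # dequantize: scale back up; output stays floating point
        return x_q * scale.unsqueeze(-1)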

-         output = output.flatten(start_dim=-2)
-         output = output.to(output_dtype)
+         # undo reshape into groups
+         x = x.flatten(-2, -1)
+         x = x.to(output_dtype)

-         if not is_column_order:
+         # undo activation ordering
+         if g_idx is not None:
              inv_perm = torch.argsort(perm)
-             output = output.index_select(-1, inv_perm)
+             x = x.index_select(-1, inv_perm)
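
Because perm is itself a permutation, torch.argsort(perm) is its inverse, so applying both returns columns to their original order:

    import torch

    perm = torch.tensor([1, 3, 0, 2])
    inv_perm = torch.argsort(perm)             # tensor([2, 0, 3, 1])
    x = torch.randn(3, 4)
    assert torch.equal(x.index_select(-1, perm).index_select(-1, inv_perm), x)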

else: # covers channel, token and tensor strategies
if do_quantize:
-             output = _quantize(
+             x = _quantize(
x=x,
scale=scale,
zero_point=zero_point,
@@ -343,14 +316,14 @@
global_scale=global_scale,
)
if do_dequantize:
-             output = _dequantize(
-                 output if do_quantize else x,
+             x = _dequantize(
+                 x_q=x,
scale=scale,
zero_point=zero_point,
global_scale=global_scale,
)
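
In this channel/token/tensor fallback, scale already broadcasts against x directly (for example, shape (rows, 1) for a per-channel scale), so no grouping reshape is needed. A sketch assuming an int4-style symmetric range:

    import torch

    x = torch.randn(4, 8)
    scale = x.abs().amax(dim=-1, keepdim=True) / 7.0   # per-channel scale, shape (4, 1)
    x_q = torch.clamp(torch.round(x / scale), -8, 7)   # broadcasts with no reshape
    x_dq = x_q * scale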

-     return output
+     return x
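
Putting the group path together: a condensed, symmetric-only sketch of the control flow after this change, not the library's implementation (_quantize/_dequantize also handle zero points, global scales, and dtypes, all elided here):

    import torch

    def group_fake_quantize(x, scale, group_size, g_idx=None, q_min=-8.0, q_max=7.0):
        # permute columns into group order when activation ordering is used
        if g_idx is not None:
            perm = torch.argsort(g_idx)
            x = x.index_select(-1, perm)
        # reshape the last dim into (num_groups, group_size)
        num_groups = x.size(-1) // group_size      # strict divisibility assumed
        x = x.unflatten(-1, (num_groups, group_size))
        # quantize then dequantize with a (rows, num_groups, 1) scale
        x = torch.clamp(torch.round(x / scale.unsqueeze(-1)), q_min, q_max)
        x = x * scale.unsqueeze(-1)
        # undo the reshape and the permutation
        x = x.flatten(-2, -1)
        if g_idx is not None:
            x = x.index_select(-1, torch.argsort(perm))
        return x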


def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):