Merged
Changes from 11 commits
80 changes: 61 additions & 19 deletions src/nncf/quantization/algorithms/weight_compression/gptq.py
@@ -12,10 +12,13 @@
import math
from typing import Optional, TypeVar

import numpy as np
Collaborator:

@AlexanderDokuchaev Is it possible to use np here?

Collaborator (Author):

If not, I can use the same approach as

reduce(mul, shape[:act_ch_axis] + shape[act_ch_axis % len(shape) + 1 :], 1) for shape in stats.shape_values

Collaborator (Author):

Done
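
For reference, a minimal sketch (not part of this PR; shape is a hypothetical activation shape) showing that the two approaches discussed above compute the same product of leading dimensions:

from functools import reduce
from operator import mul

import numpy as np

shape = (4, 128, 32)  # hypothetical [batch_size, seq_len, hidden_dim] activation shape
# product of all dimensions except the last two, computed both ways
assert np.multiply.reduce(shape[:-2]) == reduce(mul, shape[:-2], 1) == 4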


import nncf
from nncf import Dataset
from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.logging import nncf_logger
from nncf.common.logging.track_progress import track
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
@@ -130,8 +133,39 @@ def apply(
raise nncf.UnsupportedModelError(msg)

_, input_tensors = next(iter(inputs.items()))
hessian = self._calculate_hessian(node, input_tensors)
scale, zero_point = self._quantize_weights(model, graph, wc_params, hessian, input_tensors)
weight_tensor = self._backend_entity.get_weight(
wc_params.node_with_weight, wc_params.weight_port_id, model, graph
)
weight_tensor = fns.astype(weight_tensor, TensorDataType.float32)

is_3d_weight = len(weight_tensor.shape) == 3

node = wc_params.node_with_weight
hessian = self._calculate_hessian(node, input_tensors, is_3d_weight)
weight_tensor = fns.unsqueeze(weight_tensor, 0) if not is_3d_weight else weight_tensor
scales = []
zero_points = []
weights = []
for batch_idx in range(hessian.shape[0]):
batch_hessian = hessian[batch_idx]
batch_weight = weight_tensor[batch_idx]
reduction_axes = wc_params.reduction_axes
assert len(reduction_axes) == 1, "2D reduction axes are not currently supported in GPTQ"
wc_params.reduction_axes = (reduction_axes[0] - 1,) if is_3d_weight else reduction_axes
# Input tensors is a list of tensors with shape [batch_size, seq_len, hidden_dim] in the 3D weights case,
# so we prepare the list by selecting only the current batch's inputs
input_tensor = [inp[batch_idx] for inp in input_tensors] if is_3d_weight else input_tensors
batch_quantized_weight, batch_scale, batch_zero_point = self._quantize_weights(
wc_params, batch_hessian, batch_weight, input_tensor
)
wc_params.reduction_axes = reduction_axes
weights.append(batch_quantized_weight)
scales.append(batch_scale)
zero_points.append(batch_zero_point)
scale = fns.stack(scales, axis=0) if is_3d_weight else scales[0]
zero_point = fns.stack(zero_points, axis=0) if is_3d_weight and None not in zero_points else zero_points[0]
weight = fns.stack(weights, axis=0) if is_3d_weight else weights[0]
self._backend_entity.set_weight(wc_params.node_with_weight, wc_params.weight_port_id, model, graph, weight)
res[wc_params.weight_name] = CompressedWeight(None, scale, zero_point, None)

return model, res
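
A minimal NumPy sketch (hypothetical shapes, with a placeholder per-channel scale standing in for the real GPTQ quantization) of the per-expert loop-and-stack pattern that apply() uses above for a 3D weight:

import numpy as np

num_experts, out_dim, in_dim = 4, 8, 16  # hypothetical MoE-style sizes
weight = np.random.rand(num_experts, out_dim, in_dim).astype(np.float32)

scales = []
for expert in range(num_experts):
    # each expert is handled as an independent 2D (out_dim, in_dim) matrix
    scales.append(np.abs(weight[expert]).max(axis=-1, keepdims=True))
scale = np.stack(scales, axis=0)  # (num_experts, out_dim, 1), one scale set per expert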
@@ -163,7 +197,7 @@ def get_statistic_points(

return self._layerwise_engine.get_statistic_points(model, graph, filtered_nodes)

def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor]) -> Tensor:
def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor], is_3d_weight: bool = False) -> Tensor:
"""
Calculates the Hessian matrix for the given node and inputs.

@@ -179,30 +213,39 @@ def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor]) -> Tensor:
if node.layer_attributes.input_attributes["transpose"]:
msg = "Transposed input is not supported"
raise nncf.UnsupportedModelError(msg)

# Make the Hessian 3D, so that for 2D weights there is only one batch, which can be squeezed later.
# For 3D weights this dimension matches the weight's 0-th (batch) dimension
hessian_batch = 1 if not is_3d_weight else np.multiply.reduce(inputs[0].shape[:-2])
hessian = fns.zeros(
(inputs[0].shape[-1], inputs[0].shape[-1]), backend=inputs[0].backend, dtype=TensorDataType.float32
(hessian_batch, inputs[0].shape[-1], inputs[0].shape[-1]),
backend=inputs[0].backend,
dtype=TensorDataType.float32,
)

for inp in inputs:
batch_size = 1 if len(inp.shape) == 2 else inp.shape[0]
is_3d_act = len(inp.shape) == 3
# In the 3D weights case, the batch size is always 1: each "batch"/expert of the activation is treated as
# a single 2D matmul
batch_size = 1 if is_3d_weight or not is_3d_act else inp.shape[0]
if node.metatype in self._backend_entity.matmul_metatypes:
if len(inp.shape) == 3:
# For the 3D activation + 2D weight case, reshape the activation to 2D to match the weight.
# For 3D activation + 3D weight, it stays 3D and the last two dimensions hold the activation for each
# batch (0-th dimension)
if is_3d_act and not is_3d_weight:
inp = inp.reshape((-1, inp.shape[-1]))
inp = fns.transpose(inp)
inp = fns.moveaxis(inp, -1, -2)
hessian *= nsamples / (nsamples + batch_size)
nsamples += batch_size
inp = fns.astype(inp, TensorDataType.float32) * math.sqrt(2 / nsamples)
hessian += fns.matmul(inp, fns.transpose(inp))
hessian += fns.matmul(inp, fns.moveaxis(inp, -1, -2))

return hessian
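
For reference, a minimal NumPy-only sketch (an approximation of the loop above, with plain NumPy in place of NNCF's fns tensor API) of the running Hessian update for a stream of 2D activations:

import math

import numpy as np

def running_hessian(activations):  # activations: list of (seq_len, hidden_dim) float arrays
    dim = activations[0].shape[-1]
    hessian = np.zeros((dim, dim), dtype=np.float32)
    nsamples = 0
    for x in activations:
        hessian *= nsamples / (nsamples + 1)  # rescale the previous running average
        nsamples += 1
        xt = x.astype(np.float32).T * math.sqrt(2 / nsamples)  # (hidden_dim, seq_len)
        hessian += xt @ xt.T  # running value stays equal to 2 * X.T @ X / nsamples for all rows X seen so far
    return hessian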

def _quantize_weights(
self,
model: TModel,
graph: NNCFGraph,
wc_params: WeightCompressionParameters,
hessian: Tensor,
weight_tensor: Tensor,
inputs: list[Tensor],
):
"""
@@ -221,10 +264,11 @@ def _quantize_weights(
msg = "Transpose is not supported"
raise RuntimeError(msg)

weight_tensor = self._backend_entity.get_weight(
wc_params.node_with_weight, wc_params.weight_port_id, model, graph
)
weight_tensor = fns.astype(weight_tensor, TensorDataType.float32)
if len(hessian.shape) == 3 and hessian.shape[0] == 1:
Collaborator:

In which model does this happen?

Collaborator (Author):

For now this is just a safety check, since in the 3D case we also pass a 2D Hessian to this function. I added it when the older test called this function manually. Would it be better to remove it?

hessian = fns.squeeze(hessian)
msg = "The hessian passed to quantize_weights is 3D. It should be 2D"
nncf_logger.warning(msg=msg)
assert len(hessian.shape) == 2, "Hessian should be 2D"

dead_indices = fns.diag(hessian) == 0
hessian[dead_indices, dead_indices] = 1
@@ -278,6 +322,7 @@ def _quantize_weights(
else:
if self._scale_estimation and block_compression_config.num_bits == 4:
activations = [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs]
# TODO(anazir): Make it work for 3D weights
wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations)
scale, zero_point = ScaleEstimation.calculate_quantization_params(
wc_statistics,
@@ -323,9 +368,6 @@ def _quantize_weights(
weight_tensor[:, i2:] -= fns.matmul(error_block, hessian_inv[i1:i2, i2:])

quantized_tensor = quantized_tensor.reshape(weight_tensor.shape).astype(weight_tensor.dtype)
self._backend_entity.set_weight(
wc_params.node_with_weight, wc_params.weight_port_id, model, graph, quantized_tensor
)

scales = fns.stack(scales, axis=1)
if wc_params.compression_config.group_size == -1:
@@ -339,4 +381,4 @@
zero_points = fns.squeeze(zero_points, axis=-1)
else:
zero_points = None
return scales, zero_points
return quantized_tensor, scales, zero_points
139 changes: 108 additions & 31 deletions tests/openvino/native/quantization/test_gptq.py
@@ -16,8 +16,10 @@

import numpy as np
import openvino as ov
import pytest
import torch

from nncf import Dataset
from nncf.common.factory import build_graph
from nncf.parameters import CompressWeightsMode
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
@@ -323,53 +325,128 @@ def fasterquant(
return scale, zero, g_idx


def test_calculate_scale_linear():
# generate inputs
class Linear3DModel(torch.nn.Module):
def __init__(self, weight_3d: np.ndarray):
super().__init__()
w = torch.from_numpy(weight_3d)
self.weight = torch.nn.Parameter(w)

def forward(self, x):
# OV expects a transposed constant when applying GPTQ
return torch.matmul(x, self.weight.transpose(1, 2))


def _create_ov_model(weights: np.ndarray, input_shape: tuple, is_3d_weights: bool = False):
import openvino.runtime.opset13 as opset

param = opset.parameter(input_shape, dtype=np.float32, name="input")
const = opset.constant(weights, dtype=np.float32, name="self.weight")
matmul = opset.matmul(param, const, transpose_a=False, transpose_b=True)
result = opset.result(matmul, name="output")
return ov.Model([result], [param])


def _make_nncf_dataset(ov_model, inputs: list[np.ndarray]) -> Dataset:
input_name = ov_model.inputs[0].get_any_name()
items = [{input_name: inp} for inp in inputs]
return Dataset(items, lambda x: x)


@pytest.mark.parametrize("is_3d_weights", [False, True], ids=["2D_weights", "3D_weights"])
def test_calculate_scale_linear(is_3d_weights: bool):
np.random.seed(0)
inputs = [np.random.rand(128, 32).astype(np.float32) for _ in range(10)]
weights = np.random.rand(20, 32).astype(np.float32)

# calculate reference
with torch.no_grad():
layer = torch.nn.Linear(32, 20)
layer.weight.copy_(torch.from_numpy(weights))
hidden_dims = 32
out_dims = 20
group_size = 16
n_inputs = 10
batch_size = 4

ref_gptq = GPTQReference(layer)
for inp in inputs:
ref_gptq.add_batch(torch.from_numpy(inp))
inputs = [np.random.rand(batch_size, 128, hidden_dims).astype(np.float32) for _ in range(n_inputs)]
weights = np.random.rand(batch_size, out_dims, hidden_dims).astype(np.float32)

ref_scale, _, _ = ref_gptq.fasterquant(percdamp=0.1, group_size=16)
# Select only the first batch for the 2D case and make it 2D
inputs = inputs if is_3d_weights else [activation[0] for activation in inputs]
weights = weights if is_3d_weights else weights[0]

# convert PyTorch model to OpenVINO
ov_model = ov.convert_model(layer, example_input=inputs[0])
# Step 1: Apply a reference GPTQ implementation to the PyTorch model. This gives us the ground-truth scales and
# quantized values
with torch.no_grad():
weight = weights if is_3d_weights else np.expand_dims(weights, axis=0)
# Inputs for the 3D case form a list of shape [n_inputs, batch_size, 128, in_dim].
# For 2D it is [n_inputs, 1, 128, in_dim]; we unsqueeze every input tensor at axis 0
batched_inputs = inputs if is_3d_weights else [np.expand_dims(x, axis=0) for x in inputs]

pt_model = Linear3DModel(weights) if is_3d_weights else torch.nn.Linear(hidden_dims, out_dims, bias=False)
if not is_3d_weights:
pt_model.weight.copy_(torch.from_numpy(weights))

ref_gptqs, ref_scales = [], []
for batch in range(weight.shape[0]):
layer = torch.nn.Linear(hidden_dims, out_dims, bias=False)
layer.weight.copy_(torch.from_numpy(weight[batch]))

ref_gptq = GPTQReference(layer)
for inp in batched_inputs:
ref_gptq.add_batch(torch.from_numpy(inp[batch]))

ref_scale_for_batch, _, _ = ref_gptq.fasterquant(percdamp=0.1, group_size=group_size)
ref_gptqs.append(ref_gptq)
ref_scales.append(ref_scale_for_batch)

ref_scale = np.stack([s.detach().cpu().numpy() for s in ref_scales], axis=0)

# Step 2: Create an OV model so that we can use NNCF to compress it with our own GPTQ.
# We do not use ov.convert_model() here since we expect a transposed weight constant and
# a non-transposed activation. It is hard to create and convert such a model from PT -> OV
# because transformations like constant folding are performed automatically or are generally
# hard to translate.
ov_model = _create_ov_model(weights, inputs[0].shape, is_3d_weights)
graph = build_graph(ov_model)

# GPTQ
# Step 3: Set up and apply GPTQ as usual
gptq = GPTQ()
gptq._set_backend_entity(ov_model)

nodes = graph.get_all_nodes()
node_with_weight = graph.get_all_nodes()[1]

wrapped_inputs = [Tensor(inp) for inp in inputs]
H = gptq._calculate_hessian(nodes[1], wrapped_inputs)
H = gptq._calculate_hessian(node_with_weight, wrapped_inputs, is_3d_weight=is_3d_weights)

ref_H = ref_gptq.H.numpy()
assert np.all(np.isclose(ref_H, H.data))
reference_gptq_list = ref_gptqs if is_3d_weights else [ref_gptq]
nncf_hessian_list = H.data if is_3d_weights else np.expand_dims(H.data, axis=0)

wc_params = WeightCompressionParameters(
for batch_index, (reference_gptq, nncf_hessian) in enumerate(zip(reference_gptq_list, nncf_hessian_list)):
reference_hessian = reference_gptq.H.detach().numpy()
assert np.all(np.isclose(reference_hessian, nncf_hessian)), f"Hessian mismatch for batch {batch_index}"

nncf_dataset = _make_nncf_dataset(ov_model, inputs)
reduction_axes = (1,) if not is_3d_weights else (2,)
wc_param = WeightCompressionParameters(
weight_name="self.weight",
node_with_weight=nodes[1],
node_with_weight=node_with_weight,
weight_port_id=1,
weight_dtype=TensorDataType.float32,
weight_shape=weights.shape,
reduction_axes=(1,),
reduction_axes=reduction_axes,
)
wc_param.compression_config = WeightCompressionConfig(mode=CompressWeightsMode.INT4_SYM, group_size=group_size)
wc_params = [wc_param]

_, res = gptq.apply(ov_model, graph, nncf_dataset, wc_params)

# Step 4: Obtain the scales from our GPTQ implementation and compare them with the reference
scale_from_nncf = res.get("self.weight").scale.data
ref_scale = ref_scale.numpy() if isinstance(ref_scale, torch.Tensor) else ref_scale
ref_scale = ref_scale.reshape(scale_from_nncf.shape)

assert np.all(np.isclose(ref_scale, scale_from_nncf))
# Here we obtain the weights from the model instead of from apply() directly so that we also check
# that the weight is actually changed in the OV model.
ov_weight = gptq._backend_entity.get_weight(node_with_weight, 1, ov_model, graph)
ref_weights = np.stack(
[ref_gptq.layer.weight.detach().numpy() for ref_gptq in ref_gptqs],
axis=0,
)
wc_params.compression_config = WeightCompressionConfig(mode=CompressWeightsMode.INT4_SYM, group_size=16)

scale, _ = gptq._quantize_weights(ov_model, graph, wc_params, H, wrapped_inputs)
ref_scale = ref_scale.numpy()
scale = scale.reshape(ref_scale.shape)
assert np.all(np.isclose(ref_scale, scale.data))

torch_weight = layer.weight.detach().numpy()
ov_weight = gptq._backend_entity.get_weight(nodes[1], 1, ov_model, graph)
assert np.all(np.isclose(torch_weight, ov_weight.data))
assert np.all(np.isclose(ref_weights, ov_weight.data))