Support 3D Weights in GPTQ Algorithm #3835
Changes from 11 commits
```diff
@@ -12,10 +12,13 @@
 import math
 from typing import Optional, TypeVar
 
+import numpy as np
+
 import nncf
 from nncf import Dataset
 from nncf.common.graph import NNCFGraph
 from nncf.common.graph import NNCFNode
+from nncf.common.logging import nncf_logger
 from nncf.common.logging.track_progress import track
 from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
 from nncf.common.utils.backend import BackendType
```
```diff
@@ -130,8 +133,39 @@ def apply(
                 raise nncf.UnsupportedModelError(msg)
 
             _, input_tensors = next(iter(inputs.items()))
-            hessian = self._calculate_hessian(node, input_tensors)
-            scale, zero_point = self._quantize_weights(model, graph, wc_params, hessian, input_tensors)
+            weight_tensor = self._backend_entity.get_weight(
+                wc_params.node_with_weight, wc_params.weight_port_id, model, graph
+            )
+            weight_tensor = fns.astype(weight_tensor, TensorDataType.float32)
+
+            is_3d_weight = len(weight_tensor.shape) == 3
+
+            node = wc_params.node_with_weight
+            hessian = self._calculate_hessian(node, input_tensors, is_3d_weight)
+            weight_tensor = fns.unsqueeze(weight_tensor, 0) if not is_3d_weight else weight_tensor
+            scales = []
+            zero_points = []
+            weights = []
+            for batch_idx in range(hessian.shape[0]):
+                batch_hessian = hessian[batch_idx]
+                batch_weight = weight_tensor[batch_idx]
+                reduction_axes = wc_params.reduction_axes
+                assert len(reduction_axes) == 1, "2D reduction axes are not currently supported in GPTQ"
+                wc_params.reduction_axes = (reduction_axes[0] - 1,) if is_3d_weight else reduction_axes
+                # In the 3D weights case, input tensors is a list of tensors of shape
+                # [batch_size, seq_len, hidden_dim], so select only the current batch's inputs.
+                input_tensor = [inp[batch_idx] for inp in input_tensors] if is_3d_weight else input_tensors
+                batch_quantized_weight, batch_scale, batch_zero_point = self._quantize_weights(
+                    wc_params, batch_hessian, batch_weight, input_tensor
+                )
+                wc_params.reduction_axes = reduction_axes
+                weights.append(batch_quantized_weight)
+                scales.append(batch_scale)
+                zero_points.append(batch_zero_point)
+            scale = fns.stack(scales, axis=0) if is_3d_weight else scales[0]
+            zero_point = fns.stack(zero_points, axis=0) if is_3d_weight and None not in zero_points else zero_points[0]
+            weight = fns.stack(weights, axis=0) if is_3d_weight else weights[0]
+            self._backend_entity.set_weight(wc_params.node_with_weight, wc_params.weight_port_id, model, graph, weight)
             res[wc_params.weight_name] = CompressedWeight(None, scale, zero_point, None)
 
         return model, res
```
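The shape of this new driver loop is easier to see outside NNCF's tensor abstraction. Below is a minimal numpy sketch of the same idea, with a hypothetical `quantize_2d` standing in for `_quantize_weights`: a 3D weight of shape `[num_experts, out_features, in_features]` is iterated over its leading dimension so each expert is quantized as an independent 2D GPTQ problem, and the per-expert results are stacked back; a 2D weight becomes a single-batch case, mirroring `fns.unsqueeze(weight_tensor, 0)`.

```python
import numpy as np

def quantize_2d(weight_2d, hessian_2d):
    # Hypothetical stand-in for _quantize_weights (the Hessian is ignored here):
    # symmetric int4-style rounding, returning (quantized weight, scale, zero point).
    scale = np.abs(weight_2d).max(axis=-1, keepdims=True) / 7.0
    quantized = np.clip(np.round(weight_2d / scale), -8, 7) * scale
    return quantized, scale, None

def quantize_maybe_3d(weight, hessians):
    # A 2D weight becomes a single-batch 3D problem, mirroring fns.unsqueeze(weight_tensor, 0).
    is_3d_weight = weight.ndim == 3
    w = weight if is_3d_weight else weight[None, ...]
    results = [quantize_2d(w[b], hessians[b]) for b in range(hessians.shape[0])]
    weights, scales, zero_points = zip(*results)
    if is_3d_weight:
        return np.stack(weights), np.stack(scales), None
    return weights[0], scales[0], zero_points[0]

# Example: 4 experts, 8 output channels, 16 input channels.
w3d = np.random.randn(4, 8, 16).astype(np.float32)
h3d = np.stack([np.eye(16, dtype=np.float32)] * 4)
qw, scale, _ = quantize_maybe_3d(w3d, h3d)
print(qw.shape, scale.shape)  # (4, 8, 16) (4, 8, 1)
```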
```diff
@@ -163,7 +197,7 @@ def get_statistic_points(
 
         return self._layerwise_engine.get_statistic_points(model, graph, filtered_nodes)
 
-    def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor]) -> Tensor:
+    def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor], is_3d_weight: bool = False) -> Tensor:
         """
         Calculates the Hessian matrix for the given node and inputs.
 
```
```diff
@@ -179,30 +213,39 @@ def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor]) -> Tensor:
         if node.layer_attributes.input_attributes["transpose"]:
             msg = "Transposed input is not supported"
             raise nncf.UnsupportedModelError(msg)
 
+        # Make the Hessian 3D, so that for 2D weights it has a single batch that can be
+        # squeezed later. For 3D weights this dimension matches the weight's batch dimension.
+        hessian_batch = 1 if not is_3d_weight else np.multiply.reduce(inputs[0].shape[:-2])
         hessian = fns.zeros(
-            (inputs[0].shape[-1], inputs[0].shape[-1]), backend=inputs[0].backend, dtype=TensorDataType.float32
+            (hessian_batch, inputs[0].shape[-1], inputs[0].shape[-1]),
+            backend=inputs[0].backend,
+            dtype=TensorDataType.float32,
         )
 
         for inp in inputs:
-            batch_size = 1 if len(inp.shape) == 2 else inp.shape[0]
+            is_3d_act = len(inp.shape) == 3
+            # In the 3D weights case, the batch size is always 1: each "batch"/expert of the
+            # activation is treated as a single 2D matmul.
+            batch_size = 1 if is_3d_weight or not is_3d_act else inp.shape[0]
             if node.metatype in self._backend_entity.matmul_metatypes:
-                if len(inp.shape) == 3:
+                # For 3D activation + 2D weight, reshape the activation to 2D to match the weight.
+                # For 3D activation + 3D weight, keep it 3D; the last two dimensions hold the
+                # activation for each batch (0-th) dimension.
+                if is_3d_act and not is_3d_weight:
                     inp = inp.reshape((-1, inp.shape[-1]))
-                inp = fns.transpose(inp)
+                inp = fns.moveaxis(inp, -1, -2)
             hessian *= nsamples / (nsamples + batch_size)
             nsamples += batch_size
             inp = fns.astype(inp, TensorDataType.float32) * math.sqrt(2 / nsamples)
-            hessian += fns.matmul(inp, fns.transpose(inp))
+            hessian += fns.matmul(inp, fns.moveaxis(inp, -1, -2))
 
         return hessian
 
     def _quantize_weights(
         self,
-        model: TModel,
-        graph: NNCFGraph,
         wc_params: WeightCompressionParameters,
         hessian: Tensor,
+        weight_tensor: Tensor,
         inputs: list[Tensor],
     ):
         """
```
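The accumulation loop maintains a running estimate of the GPTQ Hessian H = (2/n) Σᵢ xᵢxᵢᵀ without keeping all calibration samples around: rescaling the accumulated matrix by n/(n+b) before adding a new batch of b samples scaled by √(2/(n+b)) is algebraically identical to recomputing the average from scratch, and the `moveaxis(-1, -2)` swap generalizes the old `transpose` so the same expression works for both 2D and batched 3D activations. A small numpy check of the running-average equivalence (independent of NNCF's `fns` wrappers):

```python
import numpy as np

rng = np.random.default_rng(0)
batches = [rng.standard_normal((5, 16)) for _ in range(3)]  # rows are samples

# Running update, as in _calculate_hessian (activations transposed so columns are samples).
hessian, nsamples = np.zeros((16, 16)), 0
for inp in batches:
    batch_size = inp.shape[0]
    x = inp.T  # (hidden_dim, batch)
    hessian *= nsamples / (nsamples + batch_size)
    nsamples += batch_size
    x = x * np.sqrt(2.0 / nsamples)
    hessian += x @ x.T

# Direct formula: H = (2 / n) * sum_i x_i x_i^T over all samples at once.
all_x = np.concatenate(batches, axis=0).T
direct = 2.0 / all_x.shape[1] * (all_x @ all_x.T)
print(np.allclose(hessian, direct))  # True
```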
```diff
@@ -221,10 +264,11 @@ def _quantize_weights(
             msg = "Transpose is not supported"
             raise RuntimeError(msg)
 
-        weight_tensor = self._backend_entity.get_weight(
-            wc_params.node_with_weight, wc_params.weight_port_id, model, graph
-        )
-        weight_tensor = fns.astype(weight_tensor, TensorDataType.float32)
+        if len(hessian.shape) == 3 and hessian.shape[0] == 1:
+            hessian = fns.squeeze(hessian)
+            msg = "The hessian passed to quantize_weights is 3D. It should be 2D"
+            nncf_logger.warning(msg=msg)
+        assert len(hessian.shape) == 2, "Hessian should be 2D"
 
         dead_indices = fns.diag(hessian) == 0
         hessian[dead_indices, dead_indices] = 1
```
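The `dead_indices` lines at the end of this hunk guard the later Cholesky-based inversion: an input channel that is zero across all calibration samples yields a zero Hessian diagonal entry (and a zero row/column), which would make H singular. Setting those diagonal entries to 1 is harmless because the corresponding weight columns receive no error signal anyway. A tiny numpy illustration of the same indexing trick:

```python
import numpy as np

hessian = np.array([[4.0, 0.0, 1.0],
                    [0.0, 0.0, 0.0],   # channel 1 never fired during calibration
                    [1.0, 0.0, 3.0]])

dead = np.diag(hessian) == 0
hessian[dead, dead] = 1.0  # fancy indexing touches only the dead diagonal entries

np.linalg.cholesky(hessian)  # now succeeds; H is positive definite
```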
```diff
@@ -278,6 +322,7 @@ def _quantize_weights(
                 else:
                     if self._scale_estimation and block_compression_config.num_bits == 4:
                         activations = [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs]
+                        # TODO(anazir): Make it work for 3D weights
                         wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations)
                         scale, zero_point = ScaleEstimation.calculate_quantization_params(
                             wc_statistics,
```
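The `activations` slice follows GPTQ's column blocking: when quantizing weight columns `[i1 + i, i1 + i + group_size)`, only the matching activation channels feed the scale-estimation statistics. A quick shapes-only sketch of that slicing (the sample shapes below are assumptions for illustration):

```python
import numpy as np

hidden_dim, group_size = 16, 4
# Two calibration samples of shape [seq_len, hidden_dim].
inputs = [np.random.randn(32, hidden_dim) for _ in range(2)]

i1 = 8  # start of the current GPTQ column block
for i in range(0, 8, group_size):  # walk the block group by group
    # Activation channels aligned with weight columns [i1 + i, i1 + i + group_size).
    activations = [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs]
    print([a.shape for a in activations])  # [(32, 4), (32, 4)] for each group
```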
```diff
@@ -323,9 +368,6 @@ def _quantize_weights(
                 weight_tensor[:, i2:] -= fns.matmul(error_block, hessian_inv[i1:i2, i2:])
 
         quantized_tensor = quantized_tensor.reshape(weight_tensor.shape).astype(weight_tensor.dtype)
-        self._backend_entity.set_weight(
-            wc_params.node_with_weight, wc_params.weight_port_id, model, graph, quantized_tensor
-        )
 
         scales = fns.stack(scales, axis=1)
         if wc_params.compression_config.group_size == -1:
```
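The `weight_tensor[:, i2:] -= fns.matmul(error_block, hessian_inv[i1:i2, i2:])` line kept by this hunk is the heart of GPTQ: after a block of columns `[i1, i2)` is quantized, the Hessian-scaled quantization error is pushed into the not-yet-quantized columns so they can compensate. A self-contained numpy rendition of one block step, with a hypothetical coarse `quantize` helper and an explicitly dampened Hessian (both assumptions for the sketch, not NNCF code):

```python
import numpy as np

def quantize(w):
    # Hypothetical stand-in: round to a coarse grid.
    return np.round(w * 2.0) / 2.0

rng = np.random.default_rng(0)
out_ch, in_ch, blk = 8, 16, 4
weight = rng.standard_normal((out_ch, in_ch))
X = rng.standard_normal((256, in_ch))
hessian = 2.0 / X.shape[0] * (X.T @ X) + 0.01 * np.eye(in_ch)  # dampened for stability

# Upper-triangular Cholesky factor of H^-1, playing the role of hessian_inv in GPTQ.
hessian_inv = np.linalg.cholesky(np.linalg.inv(hessian)).T

i1, i2 = 0, blk
w_block = weight[:, i1:i2].copy()
q_block = np.empty_like(w_block)
err_block = np.empty_like(w_block)
for i in range(blk):  # quantize the block column by column
    d = hessian_inv[i1 + i, i1 + i]
    q_block[:, i] = quantize(w_block[:, i])
    err_block[:, i] = (w_block[:, i] - q_block[:, i]) / d
    # Compensate the remaining columns inside the block.
    w_block[:, i + 1 :] -= np.outer(err_block[:, i], hessian_inv[i1 + i, i1 + i + 1 : i2])

# Propagate the block's error into all later columns, as in the diff.
weight[:, i2:] -= err_block @ hessian_inv[i1:i2, i2:]
```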
```diff
@@ -339,4 +381,4 @@ def _quantize_weights(
             zero_points = fns.squeeze(zero_points, axis=-1)
         else:
             zero_points = None
-        return scales, zero_points
+        return quantized_tensor, scales, zero_points
```
Review discussion:

> @AlexanderDokuchaev Is it possible to use np here?

> If not, I can use the same approach as `nncf/src/nncf/quantization/algorithms/weight_compression/activation_stats.py`, line 46 in d35c32b.

> Done