[NNCF] Enable data-aware weight compression for MatMul with transpose_b=False #3759
base: develop
Changes to the `_quantize_weights` implementation:

```diff
@@ -27,6 +27,13 @@
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
 from nncf.quantization.algorithms.weight_compression.parameters import CompressedWeight
 from nncf.quantization.algorithms.weight_compression.scale_estimation import ScaleEstimation
+from nncf.quantization.algorithms.weight_compression.utils import (
+    assign_weight_column,
+    assign_weight_slice,
+    extract_weight_column,
+    slice_weight,
+    zero_mask_columns,
+)
 from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_float_quantization_params
 from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_integer_quantization_params
 from nncf.quantization.algorithms.weight_compression.weight_lowering import float_quantize_dequantize_weight
```
```diff
@@ -212,18 +219,20 @@ def _quantize_weights(
         if wc_params.node_with_weight.metatype in self._backend_entity.convolution_metatypes:
             msg = "Convolution metatypes are not supported"
             raise RuntimeError(msg)
-        if not wc_params.node_with_weight.layer_attributes.constant_attributes[wc_params.weight_port_id]["transpose"]:
-            msg = "Transpose is not supported"
-            raise RuntimeError(msg)

         weight_tensor = self._backend_entity.get_weight(
             wc_params.node_with_weight, wc_params.weight_port_id, model, graph
         )
         weight_tensor = fns.astype(weight_tensor, TensorDataType.float32)
+        # Get transpose_b value to handle the weight shape correctly
+        transpose_b = wc_params.node_with_weight.layer_attributes.constant_attributes[wc_params.weight_port_id]["transpose"]

         dead_indices = fns.diag(hessian) == 0
         hessian[dead_indices, dead_indices] = 1
-        weight_tensor[:, dead_indices] = 0
+        # Zero out dead indices using the utility helper
+        zero_mask_columns(weight_tensor, dead_indices, transpose_b)

         scales = []
         zero_points = []
```
```diff
@@ -235,7 +244,7 @@ def _quantize_weights(
         group_size = (
             wc_params.compression_config.group_size
             if wc_params.compression_config.group_size != -1
-            else weight_tensor.shape[1]
+            else (weight_tensor.shape[1] if transpose_b else weight_tensor.shape[0])
         )
         reduction_axes = wc_params.reduction_axes
         block_compression_config = WeightCompressionConfig(
```
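For intuition, a minimal sketch (plain NumPy with hypothetical shapes and a hypothetical helper name, not the NNCF API) of what this default resolves to: in both layouts it is the size of the in_features (reduction) axis, so `group_size == -1` still means a single group spanning every input channel.

```python
import numpy as np

def default_group_size(weight: np.ndarray, transpose_b: bool) -> int:
    # group_size == -1 means one group spanning the whole reduction axis
    return weight.shape[1] if transpose_b else weight.shape[0]

w_t = np.zeros((16, 64))   # transpose_b=True layout:  [out_features, in_features]
w_f = np.zeros((64, 16))   # transpose_b=False layout: [in_features, out_features]
assert default_group_size(w_t, True) == default_group_size(w_f, False) == 64
```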
```diff
@@ -254,35 +263,38 @@ def _quantize_weights(
             i2 = min(i1 + self._block_size, columns)
             count = i2 - i1

-            weight_block = weight_tensor[:, i1:i2].clone()
+            # Extract the weight block using the utility helper
+            weight_block = slice_weight(weight_tensor, i1, i2, transpose_b).clone()
             quantized_block = fns.zeros_like(weight_block)
             error_block = fns.zeros_like(weight_block)
             loss_block = fns.zeros_like(weight_block)
             hessian_inv_block = hessian_inv[i1:i2, i1:i2]

             for i in range(count):
-                weight_col = weight_block[:, i]
+                weight_col = extract_weight_column(weight_block, i, transpose_b)
                 hessian_diag_val = hessian_inv_block[i, i]

                 if (i1 + i) % group_size == 0:
                     if not block_compression_config.is_integer:
+                        weight_slice = slice_weight(weight_tensor, i1 + i, i1 + i + group_size, transpose_b)
                         scale = calculate_float_quantization_params(
-                            weight_tensor[:, (i1 + i) : (i1 + i + group_size)], reduction_axes, block_compression_config
+                            weight_slice, reduction_axes, block_compression_config
                         )
                         scales.append(scale)
                     else:
+                        weight_slice = slice_weight(weight_tensor, i1 + i, i1 + i + group_size, transpose_b)
                         if self._scale_estimation and block_compression_config.num_bits == 4:
                             activations = [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs]
                             wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations)
                             scale, zero_point = ScaleEstimation.calculate_quantization_params(
                                 wc_statistics,
-                                weight_tensor[:, (i1 + i) : (i1 + i + group_size)],
+                                weight_slice,
                                 reduction_axes,
                                 block_compression_config,
                             )
                         else:
                             scale, zero_point = calculate_integer_quantization_params(
-                                weight_tensor[:, (i1 + i) : (i1 + i + group_size)],
+                                weight_slice,
                                 reduction_axes,
                                 block_compression_config,
                             )
```
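To make the group indexing concrete, a small illustrative sketch (hypothetical `group_slices` helper, plain NumPy) of which scale-group slices the `(i1 + i) % group_size == 0` test visits along the reduction axis:

```python
import numpy as np

def group_slices(weight, group_size, transpose_b):
    # The loop above recomputes the scale whenever (i1 + i) % group_size == 0;
    # this enumerates the resulting per-group slices along the reduction axis.
    in_features = weight.shape[1] if transpose_b else weight.shape[0]
    for start in range(0, in_features, group_size):
        if transpose_b:
            yield weight[:, start:start + group_size]
        else:
            yield weight[start:start + group_size, :]

w_f = np.arange(16, dtype=np.float32).reshape(8, 2)   # [in_features=8, out_features=2]
shapes = [g.shape for g in group_slices(w_f, group_size=4, transpose_b=False)]
assert shapes == [(4, 2), (4, 2)]                      # two groups of 4 input channels
```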
```diff
@@ -303,19 +315,34 @@ def _quantize_weights(
                         precomputed_zero_point=zero_points[-1],
                     )
                 quantized_col = fns.flatten(quantized_col)
-                quantized_block[:, i] = quantized_col
-                loss_block[:, i] = (weight_col - quantized_col) ** 2 / hessian_diag_val**2
+                assign_weight_column(quantized_block, i, quantized_col, transpose_b)
+                loss_col = (weight_col - quantized_col) ** 2 / hessian_diag_val**2
+                assign_weight_column(loss_block, i, loss_col, transpose_b)

                 error_col = (weight_col - quantized_col) / hessian_diag_val
-                weight_block[:, i:] -= fns.matmul(
-                    fns.unsqueeze(error_col, 1), fns.unsqueeze(hessian_inv_block[i, i:], 0)
-                )
-                error_block[:, i] = error_col
+                if transpose_b:
+                    weight_block[:, i:] -= fns.matmul(
+                        fns.unsqueeze(error_col, 1), fns.unsqueeze(hessian_inv_block[i, i:], 0)
+                    )
+                else:
+                    # The block is [count, out_features], so the outer product must be [count - i, 1] @ [1, out_features]
+                    weight_block[i:, :] -= fns.matmul(
+                        fns.unsqueeze(hessian_inv_block[i:, i], 1), fns.unsqueeze(error_col, 0)
+                    )
+                assign_weight_column(error_block, i, error_col, transpose_b)

-            quantized_tensor[:, i1:i2] = quantized_block
-            losses[:, i1:i2] = loss_block / 2
+            assign_weight_slice(quantized_tensor, i1, i2, quantized_block, transpose_b)
+            assign_weight_slice(losses, i1, i2, loss_block / 2, transpose_b)

-            weight_tensor[:, i2:] -= fns.matmul(error_block, hessian_inv[i1:i2, i2:])
+            # Update the remaining weights with error propagation
+            if transpose_b:
+                weight_tensor[:, i2:] -= fns.matmul(error_block, hessian_inv[i1:i2, i2:])
+            else:
+                # For transpose_b=False: error_block shape is [i2-i1, out_features] and
+                # hessian_inv[i2:, i1:i2] shape is [columns-i2, i2-i1], so
+                # hessian_inv[i2:, i1:i2] @ error_block already yields the required
+                # [columns-i2, out_features] update; no transpose of error_block is needed.
+                weight_tensor[i2:, :] -= fns.matmul(hessian_inv[i2:, i1:i2], error_block)

         quantized_tensor = quantized_tensor.reshape(weight_tensor.shape).astype(weight_tensor.dtype)
         self._backend_entity.set_weight(
```
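Since the two propagation branches are easy to get wrong, here is a minimal NumPy check (hypothetical shapes; it assumes a symmetric inverse-Hessian, without which the sub-block identity used here does not hold) that the transpose_b=False update reproduces the transpose_b=True update applied to the transposed weight:

```python
import numpy as np

rng = np.random.default_rng(0)
out_f, in_f, i1, i2 = 3, 8, 2, 5              # block covers input channels [2, 5)
b = rng.normal(size=(in_f, in_f))
h_inv = b @ b.T                               # symmetric, as assumed above
w_t = rng.normal(size=(out_f, in_f))          # transpose_b=True layout [out, in]
w_f = w_t.T.copy()                            # transpose_b=False layout [in, out]
err_t = rng.normal(size=(out_f, i2 - i1))     # error block, [out, count] orientation
err_f = err_t.T.copy()                        # the same errors, [count, out] orientation

w_t[:, i2:] -= err_t @ h_inv[i1:i2, i2:]      # transpose_b=True propagation
w_f[i2:, :] -= h_inv[i2:, i1:i2] @ err_f      # transpose_b=False propagation

assert np.allclose(w_t, w_f.T)                # both orientations agree
```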
```diff
@@ -325,13 +352,14 @@ def _quantize_weights(
         scales = fns.stack(scales, axis=1)
         if wc_params.compression_config.group_size == -1:
             scales = fns.squeeze(scales, axis=-1)
-        if wc_params.compression_config.mode in [
+
+        zero_points_tensor = None
+        if zero_points and zero_points[0] is not None and wc_params.compression_config.mode in [
             CompressWeightsMode.INT8_ASYM,
             CompressWeightsMode.INT4_ASYM,
         ]:
-            zero_points = fns.stack(zero_points, axis=1)
+            zero_points_tensor = fns.stack(zero_points, axis=1)
             if wc_params.compression_config.group_size == -1:
-                zero_points = fns.squeeze(zero_points, axis=-1)
-        else:
-            zero_points = None
-        return scales, zero_points
+                zero_points_tensor = fns.squeeze(zero_points_tensor, axis=-1)
+
+        return scales, zero_points_tensor
```
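Putting the pieces together, a self-contained toy (round-to-nearest stand-in quantizer, symmetric positive-definite inverse-Hessian, plain NumPy; an illustration of the idea, not the NNCF implementation) showing that the row-oriented loop over an [in, out] weight agrees with the original column-oriented loop over the [out, in] weight:

```python
import numpy as np

def gptq_toy(weight, h_inv, transpose_b):
    # Quantize one input channel at a time and propagate the rounding error
    # to the not-yet-quantized channels, in either weight orientation.
    w = weight.astype(np.float64).copy()
    n = h_inv.shape[0]
    for i in range(n):
        col = w[:, i] if transpose_b else w[i, :]
        q = np.round(col)                     # stand-in for the real quantizer
        err = (col - q) / h_inv[i, i]
        if transpose_b:
            w[:, i] = q
            w[:, i + 1:] -= np.outer(err, h_inv[i, i + 1:])
        else:
            w[i, :] = q
            w[i + 1:, :] -= np.outer(h_inv[i + 1:, i], err)
    return w

rng = np.random.default_rng(1)
a = rng.normal(size=(6, 6))
h_inv = a @ a.T + 6 * np.eye(6)               # symmetric positive definite
w_t = rng.normal(size=(4, 6))                 # [out_features, in_features]
assert np.allclose(gptq_toy(w_t, h_inv, True), gptq_toy(w_t.T, h_inv, False).T)
```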
New file `nncf/quantization/algorithms/weight_compression/utils.py` (@@ -0,0 +1,102 @@):

```python
# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#      http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nncf.tensor import Tensor


def slice_weight(weight: Tensor, start: int, end: int, transpose_b: bool) -> Tensor:
    """
    Return a view of the requested block without transposing the whole tensor.

    If transpose_b is True, the weight layout is [out_features, in_features]
    and weight[:, start:end] is returned (an in_features slice).
    If transpose_b is False, the layout is [in_features, out_features]
    and weight[start:end, :] is returned (an in_features slice).

    :param weight: The weight tensor to slice.
    :param start: Start index of the slice (inclusive).
    :param end: End index of the slice (exclusive).
    :param transpose_b: Whether the weight is transposed (True) or not (False).
    :return: A slice of the weight tensor.
    """
    if transpose_b:
        return weight[:, start:end]
    return weight[start:end, :]


def extract_weight_column(weight: Tensor, index: int, transpose_b: bool) -> Tensor:
    """
    Extract a single column/row from the weight based on transpose_b.

    If transpose_b is True, weight[:, index] is returned (a column).
    If transpose_b is False, weight[index, :] is returned (a row).

    :param weight: The weight tensor to extract from.
    :param index: The index of the column/row to extract.
    :param transpose_b: Whether the weight is transposed (True) or not (False).
    :return: A single column or row of the weight tensor.
    """
    if transpose_b:
        return weight[:, index]
    return weight[index, :]


def assign_weight_slice(target_weight: Tensor, start: int, end: int, block: Tensor, transpose_b: bool) -> None:
    """
    Assign a block back to target_weight in the same orientation used by slice_weight.
    The assignment is performed in place.

    :param target_weight: The target weight tensor to assign to.
    :param start: Start index of the slice (inclusive).
    :param end: End index of the slice (exclusive).
    :param block: The block of data to assign.
    :param transpose_b: Whether the weight is transposed (True) or not (False).
    """
    if transpose_b:
        target_weight[:, start:end] = block
    else:
        target_weight[start:end, :] = block


def assign_weight_column(target_weight: Tensor, index: int, column: Tensor, transpose_b: bool) -> None:
    """
    Assign a single column/row back to target_weight.
    The assignment is performed in place.

    :param target_weight: The target weight tensor to assign to.
    :param index: The index of the column/row to assign.
    :param column: The column/row data to assign.
    :param transpose_b: Whether the weight is transposed (True) or not (False).
    """
    if transpose_b:
        target_weight[:, index] = column
    else:
        target_weight[index, :] = column


def zero_mask_columns(weight: Tensor, mask: Tensor, transpose_b: bool) -> None:
    """
    Zero out columns/rows selected by a boolean mask.

    If transpose_b is True, weight[:, mask] is zeroed (columns).
    If transpose_b is False, weight[mask, :] is zeroed (rows).

    :param weight: The weight tensor to modify in place.
    :param mask: Boolean mask indicating which columns/rows to zero.
    :param transpose_b: Whether the weight is transposed (True) or not (False).
    """
    if transpose_b:
        weight[:, mask] = 0
    else:
        weight[mask, :] = 0
```
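A short usage sketch of the helpers (plain NumPy arrays, which the helpers accept, as the tests below rely on; this assumes the module keeps the `utils` name):

```python
import numpy as np

from nncf.quantization.algorithms.weight_compression.utils import (
    extract_weight_column,
    slice_weight,
    zero_mask_columns,
)

w_t = np.arange(12, dtype=np.float32).reshape(3, 4)  # transpose_b=True:  [out=3, in=4]
w_f = w_t.T.copy()                                   # transpose_b=False: [in=4, out=3]

# Both calls address the same range of input channels, [1, 3).
assert np.array_equal(slice_weight(w_t, 1, 3, True), slice_weight(w_f, 1, 3, False).T)

# Both calls address input channel 2.
assert np.array_equal(extract_weight_column(w_t, 2, True), extract_weight_column(w_f, 2, False))

# Zero input channel 0 in both layouts.
mask = np.array([True, False, False, False])
zero_mask_columns(w_t, mask, True)
zero_mask_columns(w_f, mask, False)
assert np.array_equal(w_t, w_f.T)
```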
New test file (@@ -0,0 +1,93 @@):

```python
import numpy as np
import pytest
import torch

from nncf.quantization.algorithms.weight_compression import utils


@pytest.mark.parametrize(
    "shape, transpose_b, start, end",
    [
        # transpose_b=True means the weight layout is [out_features, in_features] -> slice columns
        ((5, 8), True, 1, 4),
        ((3, 6), True, 0, 3),
        # transpose_b=False means the weight layout is [in_features, out_features] -> slice rows
        ((8, 5), False, 2, 6),
        ((6, 3), False, 0, 2),
    ],
)
def test_slice_and_assign_weight_block(shape, transpose_b, start, end):
    """
    Verify that slice_weight returns the expected sub-block and that assign_weight_slice
    writes it back in the correct orientation for both transpose_b=True and transpose_b=False.
    """
    weight = np.arange(np.prod(shape), dtype=np.int64).reshape(shape)
    block = utils.slice_weight(weight, start, end, transpose_b)

    # Expected block depending on the transpose_b semantics
    if transpose_b:
        expected_block = weight[:, start:end]
    else:
        expected_block = weight[start:end, :]

    # The returned block should match the expected slice
    np.testing.assert_array_equal(block, expected_block)

    # Prepare a new block with different values and assign it back using the helper
    new_block = np.full(expected_block.shape, fill_value=123, dtype=weight.dtype)
    utils.assign_weight_slice(weight, start, end, new_block, transpose_b)
    if transpose_b:
        np.testing.assert_array_equal(weight[:, start:end], new_block)
    else:
        np.testing.assert_array_equal(weight[start:end, :], new_block)


def test_zero_mask_columns():
    """
    Verify that zero_mask_columns zeros out the correct channels
    based on the boolean mask and the transpose_b setting.
    """
    shape = (4, 4)
    # Mask: indices 1 and 3 are True (should be zeroed)
    mask = np.array([False, True, False, True])

    # Case 1: transpose_b=True (layout [out, in] -> columns are input channels)
    weight = np.ones(shape, dtype=np.int32)
    utils.zero_mask_columns(weight, mask, transpose_b=True)

    # Columns 1 and 3 should be 0, the others 1
    expected = np.ones(shape, dtype=np.int32)
    expected[:, mask] = 0
    np.testing.assert_array_equal(weight, expected)

    # Case 2: transpose_b=False (layout [in, out] -> rows are input channels)
    weight = np.ones(shape, dtype=np.int32)
    utils.zero_mask_columns(weight, mask, transpose_b=False)

    # Rows 1 and 3 should be 0, the others 1
    expected = np.ones(shape, dtype=np.int32)
    expected[mask, :] = 0
    np.testing.assert_array_equal(weight, expected)


def test_slice_utils_pytorch_compatibility():
    """
    Ensure the helpers work with torch.Tensor objects, not just NumPy arrays.
    """
    # Layout [in, out] = [4, 2], i.e. transpose_b=False
    weight = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])

    # 1. Slicing: take the middle two rows
    block = utils.slice_weight(weight, 1, 3, transpose_b=False)
    assert torch.equal(block, torch.tensor([[3, 4], [5, 6]]))

    # 2. Assigning the slice back
    new_data = torch.tensor([[10, 10], [10, 10]])
    utils.assign_weight_slice(weight, 1, 3, new_data, transpose_b=False)

    expected = torch.tensor([[1, 2], [10, 10], [10, 10], [7, 8]])
    assert torch.equal(weight, expected)
```
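One behavioral detail worth keeping in mind, and the reason `_quantize_weights` calls `.clone()` on the result of `slice_weight`: with both NumPy and PyTorch, basic slicing returns a view, so writing into the returned block mutates the original tensor. A small illustration (hypothetical usage, NumPy arrays):

```python
import numpy as np

from nncf.quantization.algorithms.weight_compression import utils

weight = np.zeros((4, 4))
block = utils.slice_weight(weight, 0, 2, transpose_b=False)  # a view, not a copy
block[:] = 1.0
assert weight[0:2, :].sum() == 8.0   # the original tensor changed through the view

# Copy first when the block must be mutated independently, as the algorithm does
# with .clone() on NNCF tensors:
independent = utils.slice_weight(weight, 0, 2, transpose_b=False).copy()
independent[:] = 5.0
assert weight[0:2, :].sum() == 8.0   # the original tensor is untouched this time
```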
Review comment: The `utils` name violates the code style (https://github.com/openvinotoolkit/nncf/blob/develop/docs/styleguide/PyGuide.md#474-file-naming). Possible name: `tensor_slicing.py`.