Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
ada9c0a
Lora correction input transpose support
rk119 Feb 17, 2025
90381d7
Merge branch 'openvinotoolkit:develop' into support_transposed_input
rk119 Feb 19, 2025
a18ba70
OV backend act transpose support
rk119 Feb 24, 2025
64f769e
Merge branch 'develop' into support_transposed_input
rk119 Feb 24, 2025
d69ff3f
pre-commit fix
rk119 Feb 24, 2025
c78b033
Merge branch 'develop' into support_transposed_input
rk119 Feb 27, 2025
d493314
Brute force solution
rk119 Feb 27, 2025
202815b
Minor fix
rk119 Feb 27, 2025
1dc14ba
pre-commit fix
rk119 Feb 27, 2025
50eddb4
attempt fix
rk119 Feb 27, 2025
9dccf21
Add doc string
rk119 Mar 2, 2025
f5d2a1f
Merge branch 'develop' into support_transposed_input
rk119 Mar 2, 2025
3fde78a
Merge branch 'develop' into support_transposed_input
rk119 Mar 4, 2025
aa062dc
Implement get_activation_channel_axis
rk119 Mar 5, 2025
051a183
fix test
rk119 Mar 5, 2025
7f7f468
Fix error
rk119 Mar 5, 2025
b82b0c4
Merge branch 'develop' into support_transposed_input
rk119 Mar 13, 2025
e19804c
Fix OV NNCF Graph Builder add edges
rk119 Mar 13, 2025
8016399
Update
rk119 Mar 18, 2025
5489669
Merge branch 'develop' into support_transposed_input
rk119 Mar 18, 2025
298b587
Update
rk119 Mar 18, 2025
9a7cb02
Update
rk119 Mar 19, 2025
93c9cef
Merge branch 'develop' into support_transposed_input
rk119 Mar 19, 2025
341c4a8
Update
rk119 Mar 19, 2025
cda3fed
Update
rk119 Mar 20, 2025
a324587
Merge branch 'develop' into support_transposed_input
rk119 Mar 20, 2025
cc0d909
Merge branch 'develop' into support_transposed_input
rk119 Mar 26, 2025
2b9a6c2
Merge branch 'develop' into HEAD
rk119 Apr 8, 2025
b104c4d
Merge branch 'develop' into support_transposed_input
rk119 Apr 9, 2025
cb8c0d3
Merge branch 'develop' into support_transposed_input
rk119 May 4, 2025
f854405
Update torch_backend.py
rk119 May 4, 2025
78b2993
Update backend.py
rk119 May 4, 2025
670a5b4
Update openvino_backend.py
rk119 May 4, 2025
ab9c054
Update torch_fx_backend.py
rk119 May 4, 2025
19febb6
update
rk119 May 4, 2025
d808f70
pre-commit fix
rk119 May 4, 2025
8897321
Update nncf/quantization/algorithms/weight_compression/scale_estimati…
rk119 May 5, 2025
31a0c40
Update nncf/quantization/algorithms/weight_compression/gptq.py
rk119 May 5, 2025
a495284
Update nncf/quantization/algorithms/weight_compression/algorithm.py
rk119 May 5, 2025
90cb1c7
pre-commit fix
rk119 May 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions nncf/quantization/algorithms/weight_compression/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,9 +774,15 @@ def get_statistic_points(
)
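# Reduce activations across all dimensions except the hidden size dimension. The hidden size dimension is
# the last one by default; the backend reports its position (e.g. -2 for a transposed input) via
# get_input_hidden_dim.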
n_dims = len(graph.get_output_edges_by_port_id(node, output_port_id)[0].tensor_shape)
output_edge = graph.get_output_edges_by_port_id(node, output_port_id)[0]
n_dims = len(output_edge.tensor_shape)
reduction_axes = tuple(
i
for i in range(n_dims)
if i != n_dims + self._backend_entity.get_input_hidden_dim(output_edge.to_node)
)
stat_collector = self._backend_entity.mean_statistic_collector(
reduction_axes=tuple(range(n_dims - 1)), subset_size=self._subset_size
reduction_axes=reduction_axes, subset_size=self._subset_size
)
statistic_container.add_statistic_point(
StatisticPoint(
Expand Down
10 changes: 10 additions & 0 deletions nncf/quantization/algorithms/weight_compression/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,16 @@ def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) ->
:return: Backend-specific callable to filter statistic containers according to its statistic point.
"""

@staticmethod
def get_input_hidden_dim(input_node: NNCFNode) -> int:
    """
    Provides the position of the hidden dimension within the shape of the input node.

    :param input_node: The input node.
    :return: The index of the hidden dimension in the shape of the input node. This
        implementation always reports -1, i.e. the last dimension.
    """
    hidden_dim_index = -1
    return hidden_dim_index


class AWQAlgoBackend(WeightCompressionAlgoBackend):
@staticmethod
Expand Down
19 changes: 12 additions & 7 deletions nncf/quantization/algorithms/weight_compression/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,19 +170,19 @@ def _calculate_hessian(self, node: NNCFNode, inputs: List[Tensor]) -> Tensor:
if node.metatype in self._backend_entity.convolution_metatypes:
msg = "Convolution metatypes are not supported"
raise nncf.UnsupportedModelError(msg)
if node.layer_attributes.input_attributes["transpose"]:
msg = "Transposed input is not supported"
raise nncf.UnsupportedModelError(msg)

hidden_dim = self._backend_entity.get_input_hidden_dim(node)
hessian = fns.zeros(
(inputs[0].shape[-1], inputs[0].shape[-1]), backend=inputs[0].backend, dtype=TensorDataType.float32
(inputs[0].shape[hidden_dim], inputs[0].shape[hidden_dim]),
backend=inputs[0].backend,
dtype=TensorDataType.float32,
)

for inp in inputs:
batch_size = 1 if len(inp.shape) == 2 else inp.shape[0]
if node.metatype in self._backend_entity.matmul_metatypes:
if len(inp.shape) == 3:
inp = inp.reshape((-1, inp.shape[-1]))
inp = inp.reshape((-1, inp.shape[hidden_dim]))
inp = fns.transpose(inp)
hessian *= nsamples / (nsamples + batch_size)
nsamples += batch_size
Expand Down Expand Up @@ -267,8 +267,13 @@ def _quantize_weights(
scales.append(scale)
else:
if self._scale_estimation and block_compression_config.num_bits == 4:
activations = [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs]
wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations)
transpose = self._backend_entity.get_input_hidden_dim(wc_params.node_with_weight) == -2
activations = (
[inp[:, (i1 + i) : (i1 + i + group_size), ...] for inp in inputs]
if transpose
else [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs]
)
wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations, transpose)
scale, zero_point = ScaleEstimation.calculate_quantization_params(
wc_statistics,
weight_tensor[:, (i1 + i) : (i1 + i + group_size)],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,6 @@ def mean_statistic_collector(

@staticmethod
def get_activation_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
if node.layer_attributes.input_attributes["transpose"]:
msg = "Transposed input is not supported"
raise nncf.UnsupportedModelError(msg)
constant_ports = node.layer_attributes.get_const_port_ids()
activation_ports = [
e.input_port_id for e in nncf_graph.get_input_edges(node) if e.input_port_id not in constant_ports
Expand Down Expand Up @@ -204,7 +201,12 @@ def insert_adapters(
A_W = opset.constant(lora_A.data)
B_W = opset.constant(lora_B.data)

A_MM = opset.matmul(input_node, A_W, transpose_a=False, transpose_b=True)
A_MM = opset.matmul(
input_node,
A_W,
transpose_a=wc_params.node_with_weight.layer_attributes.input_attributes["transpose"],
transpose_b=True,
)
B_MM = opset.matmul(A_MM, B_W, transpose_a=False, transpose_b=True)

node_output_port = mm_node.output(0)
Expand Down Expand Up @@ -364,6 +366,12 @@ def filter_func(point: StatisticPoint) -> bool:

return filter_func

@staticmethod
def get_input_hidden_dim(node: NNCFNode) -> int:
    """
    Returns the index of the hidden dimension in the shape of the node's input.

    :param node: The node whose input is inspected. May be None.
    :return: -2 if the node's input attributes carry a truthy "transpose" flag, otherwise -1.
        -1 is also returned when the node, its layer attributes, its input attributes or the
        "transpose" entry are missing.
    """
    if node is None or node.layer_attributes is None:
        return -1
    input_attributes = node.layer_attributes.input_attributes
    # Not every node records input attributes or a "transpose" entry; treat such inputs as
    # non-transposed instead of failing with TypeError/KeyError on the subscript.
    if not input_attributes:
        return -1
    return -2 if input_attributes.get("transpose") else -1


class OVAWQAlgoAlgoBackend(AWQAlgoBackend, OVWeightCompressionAlgoBackend):
@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def calculate_quantization_params(
return result_scale, zp

@staticmethod
def activations_to_wc_statistics(activations: List[Tensor]) -> WCTensorStatistic:
def activations_to_wc_statistics(activations: List[Tensor], transpose: bool) -> WCTensorStatistic:
"""
Mimic the activation reducing logic from WeightCompression.get_statistic_points.

Expand All @@ -376,7 +376,9 @@ def activations_to_wc_statistics(activations: List[Tensor]) -> WCTensorStatistic
shapes = []
for act in activations:
shapes.append(act.shape)
reduction_shape = tuple(range(act.ndim - 1))
reduction_shape = (
tuple(i for i in range(act.ndim) if i != act.ndim - 2) if transpose else tuple(range(act.ndim - 1))
)
mean_values.append(fns.mean(act, axis=reduction_shape))
wc_statistics = WCTensorStatistic(mean_values, shapes)
return wc_statistics
Expand Down
23 changes: 11 additions & 12 deletions tests/openvino/native/quantization/test_weights_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1440,21 +1440,20 @@ def test_compression_with_different_algo_combinations(input_shape, kwargs):
)
def test_compression_with_transposed_activations(kwargs):
dataset_size = 4
model = LMLinearModel(transpose_a=True, transpose_b=False).ov_model
model = LMLinearModel(transpose_a=True, transpose_b=True).ov_model
input_data = [np.ones(inp.shape) for inp in model.inputs] * dataset_size
dataset = Dataset(input_data)

with pytest.raises(nncf.UnsupportedModelError):
compress_weights(
model,
mode=CompressWeightsMode.INT4_SYM,
ratio=1.0,
group_size=8,
subset_size=2,
dataset=dataset,
all_layers=True,
**kwargs,
)
compress_weights(
model,
mode=CompressWeightsMode.INT4_SYM,
ratio=1.0,
group_size=8,
subset_size=2,
dataset=dataset,
all_layers=True,
**kwargs,
)


class TestOVTemplateWeightCompression(TemplateWeightCompression):
Expand Down