diff --git a/src/nncf/openvino/quantization/quantize_model.py b/src/nncf/openvino/quantization/quantize_model.py index 4ac077a17d7..e329a7f3154 100644 --- a/src/nncf/openvino/quantization/quantize_model.py +++ b/src/nncf/openvino/quantization/quantize_model.py @@ -376,6 +376,7 @@ def compress_weights_impl( scale_estimation: bool, gptq: bool, lora_correction: bool, + codebook_estimation: bool, backup_mode: BackupMode, compression_format: CompressionFormat, advanced_parameters: Optional[AdvancedCompressionParameters] = None, @@ -397,6 +398,7 @@ def compress_weights_impl( scale_estimation, gptq, lora_correction, + codebook_estimation, backup_mode, compression_format, advanced_parameters, diff --git a/src/nncf/quantization/algorithms/weight_compression/algorithm.py b/src/nncf/quantization/algorithms/weight_compression/algorithm.py index 50a5399a5c8..d02de7e33ab 100644 --- a/src/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/src/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -41,6 +41,7 @@ from nncf.quantization.advanced_parameters import convert_to_dict_recursively from nncf.quantization.algorithms.algorithm import Algorithm from nncf.quantization.algorithms.weight_compression.awq import AWQ +from nncf.quantization.algorithms.weight_compression.codebook_estimation import CodebookEstimation from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters from nncf.quantization.algorithms.weight_compression.constants import CB4_QUANTILES from nncf.quantization.algorithms.weight_compression.gptq import GPTQ @@ -86,6 +87,7 @@ def get_weight_compression_configuration( scale_estimation: Optional[bool] = None, gptq: Optional[bool] = None, lora_correction: Optional[bool] = None, + codebook_estimation: Optional[bool] = None, ignored_scope: Optional[IgnoredScope] = None, sensitivity_metric: Optional[SensitivityMetric] = None, backup_mode: Optional[BackupMode] = None, @@ -111,6 +113,7 @@ def get_weight_compression_configuration( "scale_estimation": scale_estimation or False, "gptq": gptq or False, "lora_correction": lora_correction or False, + "codebook_estimation": codebook_estimation or False, "ignored_scope": ignored_scope or IgnoredScope(), "sensitivity_metric": ( ( @@ -137,6 +140,7 @@ def check_user_compression_configuration( scale_estimation: Optional[bool], gptq: Optional[bool], lora_correction: Optional[bool], + codebook_estimation: Optional[bool], ignored_scope: Optional[IgnoredScope], sensitivity_metric: Optional[SensitivityMetric], backup_mode: Optional[BackupMode], @@ -167,6 +171,7 @@ def check_user_compression_configuration( "gptq": gptq, "lora_correction": lora_correction, "backup_mode": backup_mode, + "codebook_estimation": codebook_estimation, } unsupported_for_int8 = [name for name, value in unsupported_options.items() if value is not None] if unsupported_for_int8: @@ -280,6 +285,7 @@ def __init__( scale_estimation: bool, gptq: bool, lora_correction: bool, + codebook_estimation: bool, backup_mode: BackupMode = BackupMode.INT8_ASYM, compression_format: CompressionFormat = CompressionFormat.DQ, advanced_parameters: Optional[AdvancedCompressionParameters] = None, @@ -339,6 +345,7 @@ def __init__( self._scale_estimation = scale_estimation self._gptq = gptq self._lora_correction = lora_correction + self._codebook_estimation = codebook_estimation self._backup_mode = backup_mode self._compression_format = compression_format self._advanced_parameters = ( @@ -379,6 +386,9 @@ def __init__( scale_estimation_params.weight_penalty, ) + if self._codebook_estimation: + self._codebook_estimation_algo = CodebookEstimation() + self._data_aware_mixed_precision = ( self._sensitivity_metric != SensitivityMetric.WEIGHT_QUANTIZATION_ERROR and self._ratio != 1.0 ) @@ -387,6 +397,7 @@ def __init__( or self._scale_estimation or self._lora_correction or self._gptq + or self._codebook_estimation ) @property @@ -938,6 +949,15 @@ def apply( lora_correction_algo = None description = "Applying Weight Compression" + if self._codebook_estimation: + precomputed_compressed_weights = self._codebook_estimation_algo.apply( + model=model, + graph=graph, + all_weight_params=all_weight_params, + statistics=statistics, + backend_entity=self._backend_entity, + ) + if self._gptq: del statistics model, precomputed_compressed_weights = self._gptq_algo.apply( diff --git a/src/nncf/quantization/algorithms/weight_compression/codebook_estimation.py b/src/nncf/quantization/algorithms/weight_compression/codebook_estimation.py new file mode 100644 index 00000000000..33f1bd81321 --- /dev/null +++ b/src/nncf/quantization/algorithms/weight_compression/codebook_estimation.py @@ -0,0 +1,376 @@ +# Copyright (c) 2025 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from dataclasses import dataclass +from typing import Optional, TypeVar + +import nncf +from nncf.common.graph.graph import NNCFGraph +from nncf.common.logging.track_progress import track +from nncf.common.utils.backend import BackendType +from nncf.common.utils.backend import get_backend +from nncf.experimental.common.tensor_statistics.statistics import WCTensorStatistic +from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats +from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend +from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig +from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters +from nncf.quantization.algorithms.weight_compression.constants import CB4_QUANTILES +from nncf.quantization.algorithms.weight_compression.parameters import CompressedWeight +from nncf.quantization.algorithms.weight_compression.weight_lowering import _calculate_normalized_weight +from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_float_quantization_params +from nncf.quantization.algorithms.weight_compression.weight_lowering import float_quantize_dequantize_weight +from nncf.quantization.algorithms.weight_compression.weight_lowering import reshape_weight_for_grouped_quantization +from nncf.tensor import Tensor +from nncf.tensor import TensorDataType +from nncf.tensor import functions as fns + +TModel = TypeVar("TModel") + + +class CodebookEstimation: + """ + Codebook estimation algorithm implementation. + """ + + def __init__( + self, + ): + """ + Initializes the CodebookEstimation algorithm. + """ + super().__init__() + + @property + def available_backends(self) -> list[BackendType]: + return [BackendType.OPENVINO] + + def _set_backend_entity(self, model: TModel) -> None: + """ + Creates a helper class with a backed-specific logic of the algorithm. + + :param model: Backend-specific input model. + """ + model_backend = get_backend(model) + if model_backend == BackendType.OPENVINO: + from nncf.quantization.algorithms.weight_compression.openvino_backend import OVWeightCompressionAlgoBackend + + self._backend_entity = OVWeightCompressionAlgoBackend(model) + else: + msg = ( + "Cannot return backend-specific Codebook Estimation entity because" + f" {model_backend.value} is not supported!" + ) + raise nncf.UnsupportedBackendError(msg) + + def apply( + self, + model: TModel, + graph: NNCFGraph, + all_weight_params: list[WeightCompressionParameters], + statistics: dict[str, WCTensorStatistic], + backend_entity: Optional[WeightCompressionAlgoBackend] = None, + ) -> dict[str, CompressedWeight]: + """ + Estimates better codebook. + Minimizes difference between floating point MatMul and + MatMul with compressed weights. + The algorithm computes codebook and indexes for MatMul compression. + + :param model: Model for applying algorithm. + :param graph: Model graph. + :param all_weight_params: List of all weight parameters. + :param statistics: Input activation statistics for each node. + :param statistic_points: Statistic points with collected statistics values. + :param dataset: A representative dataset for the calibration process. + :param backend_entity: Weight compression algorithm backend. + :return: A dictionary that maps weight names to CompressedWeight with codebook, codebook indexes and scale. + """ + self._backend_entity = backend_entity + if self._backend_entity is None: + self._set_backend_entity(model) + res = dict() + + for wp in track(all_weight_params, description="Applying Codebook Estimation"): + weight_name = wp.weight_name + node_name = wp.node_with_weight.node_name + config = wp.compression_config + + if config.num_bits != 4: # or node_name not in statistics: + res[weight_name] = CompressedWeight() + continue + + stats = statistics[node_name] + + weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph) + if len(weight_data) != 1: # not supported by the algorithm + continue + _, weight_port_id = weight_data[0] + + weight = self._backend_entity.get_weight(wp.node_with_weight, weight_port_id, model, graph) + + codebook, scale, indexes = self.calculate_codebook(stats, weight, wp.reduction_axes, config, wp) + res[weight_name] = CompressedWeight(indexes, scale, None, codebook) + config.codebook_values = codebook + + return res + + @staticmethod + def calculate_codebook( + statistics: WCTensorStatistic, + weight: Tensor, + reduction_axes: tuple[int, ...], + config: WeightCompressionConfig, + wp: WeightCompressionParameters, + ) -> Tensor: + reduction_axis = reduction_axes[0] + weight = deepcopy(weight.astype(TensorDataType.float32)) + + s, X = process_stats(statistics, 128) + + if reduction_axis == 0: + weight = fns.transpose(weight) + reduction_axis = 1 + + if config.group_size != -1: + weight, reduction_axes = reshape_weight_for_grouped_quantization(weight, reduction_axes, config.group_size) + + orig_shape = weight.shape + + importance = fns.ones_like(weight) + importance = importance * s + + scale = calculate_float_quantization_params(weight, reduction_axes, config, signed=True) + norm_weight = _calculate_normalized_weight(weight, scale) + + codebook, indexes, variants = weights_clusterization_k_means(norm_weight, importance) + + indexes = indexes.reshape(orig_shape) + + best_codebook = codebook.as_openvino_tensor().astype(TensorDataType.f8e4m3) + + fp_outs = fns.matmul(weight, X) + diff = float("inf") + + variants[0] = fns.tensor(CB4_QUANTILES, backend=weight.backend, dtype=weight.dtype) + variants[1] = fns.tensor(list(range(-8, 8)), backend=weight.backend, dtype=weight.dtype) + + for var in variants: + var = var.as_openvino_tensor().astype(TensorDataType.f8e4m3) + config.codebook_values = Tensor(var) + qw = float_quantize_dequantize_weight(weight, config, wp.reduction_axes) + q_outs = fns.matmul(qw, X) + + cur_diff = fns.mean(fns.abs(fp_outs - q_outs)).item() + if cur_diff < diff: + diff = cur_diff + best_codebook = var + + return Tensor(best_codebook), None, None + + +def round_to_left(quantiles, values): + center_of_quantiles = 0.5 * (quantiles[1:] + quantiles[:-1]) + return fns.searchsorted(center_of_quantiles, values, side="left", sorter=None) + + +@dataclass +class KMeansAlgoData: + centroids: Tensor + hist: Tensor + weighted_hist: Tensor | None = None + + frequencies: Tensor | None = None + weights: Tensor | None = None + + +class KMeansWeighted: + def __init__(self, n_clusters=8, max_iter=300): + self.n_clusters = n_clusters + self.max_iter = max_iter + self.variants = [] + self.centroids = None + + @staticmethod + def get_init(values, frequencies, n_clusters): + step = 1.0 / (n_clusters - 1) + denum = fns.sum(frequencies) + quants = [i * step for i in range(n_clusters)] + n_frequencies = frequencies / denum + n_frequencies = fns.cumsum(n_frequencies, axis=0) + + res = fns.zeros((n_clusters,), backend=values.backend, dtype=values.dtype) + for i in range(n_clusters): + if i == 0: + res[i] = values[0] + elif i == n_clusters - 1: + res[i] = values[-1] + else: + prev_val = values[fns.nonzero(n_frequencies <= quants[i])[0][-1].item()].item() + next_val = values[fns.nonzero(n_frequencies <= quants[i + 1])[0][-1].item()].item() + res[i] = (prev_val + next_val) / 2 + + # avoid close centroids + th = 0.05 + for i in range(1, n_clusters - 1): + if (res[i] - res[i + 1]).abs() / max(res[i].abs(), res[i + 1].abs()) < th: + res[i] = (res[i - 1] + res[i + 1]) / 2 + + return res + + @staticmethod + def create_histogramm(data, granularity=0.01): + centers = [] + step = granularity + + data_range = (data.min().item(), data.max().item()) + prev = data_range[0] + + while prev < data_range[1]: + centers.append(prev + step / 2) + prev += step + + centers = fns.tensor(centers, backend=data.backend) + centroid_idxs = round_to_left(centers, data) + + res = [[], [], []] + for i in range(centers.size): + idxs = fns.nonzero(centroid_idxs == i) + if len(idxs[0]) == 0: + continue + res[0].append(centers[i]) + res[1].append(fns.sum(data[idxs])) + res[2].append(len(idxs[0])) + + res[0] = fns.tensor(res[0], backend=data.backend) # centers of histogram bins + res[1] = fns.tensor(res[1], backend=data.backend) # sum of values in each bin + res[2] = fns.tensor(res[2], backend=data.backend) # count of values in each bin + + return res + + @staticmethod + def add_weighted_data_and_weights(res, data, importance): + res[1].append(fns.sum(fns.multiply(data, importance)).item()) + res[2].append(fns.sum(importance).item()) + + @staticmethod + def create_histogramm_sorted(data_, importance, granularity=0.01): + centers = [] + ranges = [] + step = data_.max().item() * granularity / 3.5 + + sorted_idx = fns.argsort(data_) + data = data_[sorted_idx] + importance = importance[sorted_idx] + + data_range = (data.min().item(), data.max().item()) + prev = data_range[0] + + while prev < data_range[1]: + centers.append(prev + step / 2) + prev += step + + if len(centers) > 1: + ranges.append(0.5 * (centers[-2] + centers[-1])) + ranges.append(centers[-1]) + + centers = fns.tensor(centers, backend=data_.backend, dtype=data_.dtype) + ranges = fns.tensor(ranges, backend=data_.backend, dtype=data_.dtype) + + ranges_idxs = round_to_left(data, ranges) + + res = [[], [], []] + for i in range(centers.size): + res[0].append(centers[i]) + if i == 0: + KMeansWeighted.add_weighted_data_and_weights( + res, data[: ranges_idxs[1].item()], importance[: ranges_idxs[1].item()] + ) + elif i == centers.size - 1: + KMeansWeighted.add_weighted_data_and_weights( + res, data[ranges_idxs[-2].item() :], importance[ranges_idxs[-2].item() :] + ) + else: + idx = 2 * i + KMeansWeighted.add_weighted_data_and_weights( + res, + data[ranges_idxs[idx - 1].item() : ranges_idxs[idx + 1].item()], + importance[ranges_idxs[idx - 1].item() : ranges_idxs[idx + 1].item()], + ) + + res[0] = centers + res[1] = fns.tensor(res[1], backend=data_.backend, dtype=data_.dtype) + res[2] = fns.tensor(res[2], backend=data_.backend, dtype=data_.dtype) + + return res + + def fit(self, X_train, importance, init, fixed=None): + if self.max_iter == 1: + self.centroids = deepcopy(init) + return + if fixed is None: + fixed = [0, len(init) // 2, len(init) - 1] + + self.hist = KMeansWeighted.create_histogramm_sorted(X_train, importance) + + init_by_hist = self.get_init(self.hist[0], self.hist[2], self.n_clusters) + init_by_hist[0] = init[0] + init_by_hist[-1] = init[-1] + zero_idx = fns.argmin(fns.abs(init_by_hist[:]), axis=0).item() + init_by_hist[zero_idx] = 0.0 # to have zero in codebook + fixed[1] = zero_idx + init = init_by_hist + + self.centroids = deepcopy(init) + + iteration = 0 + prev_centroids = self.centroids + while iteration < self.max_iter: + prev_centroids = deepcopy(self.centroids) + + if iteration % 5 == 0: + self.variants.append(deepcopy(self.centroids)) + + centroid_idxs = round_to_left(self.centroids, self.hist[0]) + for i in range(self.n_clusters): + idxs = fns.nonzero(centroid_idxs == i) + self.centroids[i] = fns.sum(self.hist[1][idxs]).item() / fns.sum(self.hist[2][idxs]).item() + + for idx in fixed: + self.centroids[idx] = init[idx] + iteration += 1 + if fns.any(fns.all(fns.abs(self.centroids - prev_centroids) < 0.00001)): + break + + self.variants.append(deepcopy(self.centroids)) + + def evaluate(self, X): + centroid_idxs = round_to_left(self.centroids, X) + return deepcopy(self.centroids).flatten(), centroid_idxs + + +def weights_clusterization_k_means(weight, importance, n_centroids=2**4): + orig_shape = weight.shape + weight = weight.flatten() + importance = importance.flatten() + + n_init = [0, 0] + n_init[0] = weight.min() + n_init[-1] = weight.max() + + kmeans = KMeansWeighted(n_centroids, max_iter=70) + + kmeans.fit(weight, importance, n_init, fixed=[0, 7, 15]) + codebook, indexes = kmeans.evaluate(weight) + + indexes = fns.reshape(indexes, orig_shape) + + return codebook, indexes, kmeans.variants diff --git a/src/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/src/nncf/quantization/algorithms/weight_compression/openvino_backend.py index ca6f83573c7..eebd7600cd6 100644 --- a/src/nncf/quantization/algorithms/weight_compression/openvino_backend.py +++ b/src/nncf/quantization/algorithms/weight_compression/openvino_backend.py @@ -259,9 +259,7 @@ def _create_compression_subgraph( n_quants = compressed_weight.codebook.size - 1 compression_dtype = ov.Type.u16 if n_quants > 255 else (ov.Type.u8 if n_quants > 15 else ov.Type.u4) converted_const = create_ov_codebook_subgraph( - codebook=compressed_weight.codebook - if compression_config.mode == CompressWeightsMode.CODEBOOK - else compressed_weight.codebook.as_openvino_tensor().astype(TensorDataType.f8e4m3), + compressed_weight.codebook.as_openvino_tensor(), indexes=compressed_weight.tensor, dtype=compression_dtype, name=const_node_name, diff --git a/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py index e185e65bcb5..e74320af17e 100644 --- a/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py +++ b/src/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -227,6 +227,7 @@ def calculate_quantization_params( # all weight in group has importance based on corresponding input activations importance = fns.ones_like(original_weight) + importance = importance * s target, zero_mask = get_target_zero_mask(compressed_weights, zp) diff --git a/src/nncf/quantization/algorithms/weight_compression/weight_lowering.py b/src/nncf/quantization/algorithms/weight_compression/weight_lowering.py index 0e0783cf468..1d462b019b6 100644 --- a/src/nncf/quantization/algorithms/weight_compression/weight_lowering.py +++ b/src/nncf/quantization/algorithms/weight_compression/weight_lowering.py @@ -78,7 +78,7 @@ def reshape_weight_for_grouped_quantization( def calculate_float_quantization_params( - weight: Tensor, reduction_axes: ReductionAxes, config: WeightCompressionConfig + weight: Tensor, reduction_axes: ReductionAxes, config: WeightCompressionConfig, signed: bool = False ) -> Tensor: """ Calculates the scale for nf4 or mxfp4/mxfp8_e4m3 quantization. @@ -93,7 +93,13 @@ def calculate_float_quantization_params( if weight.dtype != TensorDataType.float32: weight = weight.astype(TensorDataType.float32) - scale = fns.max(fns.abs(weight), axis=reduction_axes, keepdims=True) + if signed: + scale_neg = fns.min(weight, axis=reduction_axes, keepdims=True) + scale_pos = fns.max(weight, axis=reduction_axes, keepdims=True) + scale = fns.where(fns.abs(scale_neg) >= fns.abs(scale_pos), scale_neg, scale_pos) + else: + scale = fns.max(fns.abs(weight), axis=reduction_axes, keepdims=True) + FP_MAX_VALS = { CompressWeightsMode.MXFP4: 6.0, CompressWeightsMode.MXFP8_E4M3: 448.0, @@ -341,6 +347,13 @@ def compress_weight( ) if not config.is_integer: + if ( + precomputed_compressed_weight is not None + and precomputed_compressed_weight.tensor is not None + and precomputed_compressed_weight.codebook is not None + ): + return precomputed_compressed_weight + compressed_weight, scale, indexes = do_float_quantization(weight, config, reduction_axes, precomputed_scale) if indexes is not None: return CompressedWeight( diff --git a/src/nncf/quantization/quantize_model.py b/src/nncf/quantization/quantize_model.py index 8d89bd19821..f0cc32e69c1 100644 --- a/src/nncf/quantization/quantize_model.py +++ b/src/nncf/quantization/quantize_model.py @@ -436,6 +436,7 @@ def compress_weights( scale_estimation: Optional[bool] = None, gptq: Optional[bool] = None, lora_correction: Optional[bool] = None, + codebook_estimation: Optional[bool] = None, backup_mode: Optional[BackupMode] = None, compression_format: CompressionFormat = CompressionFormat.DQ, advanced_parameters: Optional[AdvancedCompressionParameters] = None, @@ -580,6 +581,7 @@ def compress_weights( options = { "gptq": gptq, "lora_correction": lora_correction, + "codebook_estimation": codebook_estimation, } unsupported_options = [name for name, value in options.items() if value is not None] if unsupported_options: @@ -606,7 +608,7 @@ def compress_weights( elif backend == BackendType.OPENVINO: from nncf.openvino.quantization.quantize_model import compress_weights_impl as ov_compress_weights_impl - if any((scale_estimation, gptq, lora_correction)) and dataset is None: + if any((scale_estimation, gptq, lora_correction, codebook_estimation)) and dataset is None: msg = "Scale estimation, GPTQ or Lora Correction algorithm is defined, but dataset is None." raise nncf.ParameterNotSupportedError(msg) @@ -645,6 +647,7 @@ def compress_weights( options = { "gptq": gptq, "lora_correction": lora_correction, + "codebook_estimation": codebook_estimation, } unsupported_options = [name for name, value in options.items() if value is not None] if unsupported_options: @@ -670,6 +673,7 @@ def compress_weights( scale_estimation, gptq, lora_correction, + codebook_estimation, ignored_scope, sensitivity_metric, backup_mode, @@ -686,6 +690,7 @@ def compress_weights( scale_estimation, gptq, lora_correction, + codebook_estimation, ignored_scope, sensitivity_metric, backup_mode, diff --git a/src/nncf/tensor/functions/__init__.py b/src/nncf/tensor/functions/__init__.py index f72986ba7be..18a4fdd61ca 100644 --- a/src/nncf/tensor/functions/__init__.py +++ b/src/nncf/tensor/functions/__init__.py @@ -16,6 +16,7 @@ from nncf.tensor.functions.numeric import allclose as allclose from nncf.tensor.functions.numeric import any as any from nncf.tensor.functions.numeric import arange as arange +from nncf.tensor.functions.numeric import argmin as argmin from nncf.tensor.functions.numeric import argsort as argsort from nncf.tensor.functions.numeric import as_tensor_like as as_tensor_like from nncf.tensor.functions.numeric import astype as astype @@ -53,6 +54,7 @@ from nncf.tensor.functions.numeric import minimum as minimum from nncf.tensor.functions.numeric import moveaxis as moveaxis from nncf.tensor.functions.numeric import multiply as multiply +from nncf.tensor.functions.numeric import nonzero as nonzero from nncf.tensor.functions.numeric import ones_like as ones_like from nncf.tensor.functions.numeric import percentile as percentile from nncf.tensor.functions.numeric import power as power diff --git a/src/nncf/tensor/functions/numeric.py b/src/nncf/tensor/functions/numeric.py index 6a413e81fb5..d4f0c947e4f 100644 --- a/src/nncf/tensor/functions/numeric.py +++ b/src/nncf/tensor/functions/numeric.py @@ -324,6 +324,16 @@ def where(condition: Tensor, x: Union[Tensor, float], y: Union[Tensor, float]) - """ +@tensor_dispatcher +def nonzero(condition: Tensor) -> tuple[Tensor, ...]: + """ + Return the indices of the elements that are non-zero. + + :param condition: The input tensor. + :return: A tensor containing the indices of the non-zero elements. + """ + + @tensor_dispatcher def zeros_like(a: Tensor) -> Tensor: """ @@ -662,6 +672,17 @@ def argsort(a: Tensor, axis: int = -1, descending: bool = False, stable: bool = """ +@tensor_dispatcher +def argmin(a: Tensor, axis: None) -> Tensor: + """ + Returns the indices of the minimum values along an axis. + + :param a: The tensor for which to find the minimum values. + :param axis: Axis or tuple of axes along which to find the minimum values. + :return: Indices of the minimum values along an axis. + """ + + @tensor_dispatcher def diag(a: Tensor, k: int = 0) -> Tensor: """ diff --git a/src/nncf/tensor/functions/numpy_numeric.py b/src/nncf/tensor/functions/numpy_numeric.py index f3bdc755552..d6562f5b268 100644 --- a/src/nncf/tensor/functions/numpy_numeric.py +++ b/src/nncf/tensor/functions/numpy_numeric.py @@ -193,6 +193,13 @@ def _( return np.where(condition, x, y) +@numeric.nonzero.register +def _( + condition: T_NUMPY, +) -> tuple[T_NUMPY_ARRAY, ...]: + return np.nonzero(condition) + + @numeric.zeros_like.register def _(a: T_NUMPY) -> T_NUMPY_ARRAY: return np.zeros_like(a) @@ -314,16 +321,16 @@ def _(a: T_NUMPY) -> T_NUMBER: return a.item() -@numeric.cumsum.register -def _(a: T_NUMPY, axis: int) -> T_NUMPY: - return np.cumsum(a, axis=axis) - - @numeric.sum.register def _(a: T_NUMPY, axis: T_AXIS = None, keepdims: bool = False) -> T_NUMPY_ARRAY: return np.array(np.sum(a, axis=axis, keepdims=keepdims)) +@numeric.cumsum.register +def _(a: T_NUMPY, axis: int) -> T_NUMPY: + return np.cumsum(a, axis=axis) + + @numeric.multiply.register def _(x1: T_NUMPY, x2: Union[T_NUMPY, float]) -> T_NUMPY_ARRAY: return np.multiply(x1, x2) @@ -368,6 +375,11 @@ def _(a: T_NUMPY, axis: int = -1, descending: bool = False, stable: bool = False return np.argsort(a, axis=axis, kind="stable" if stable else None) +@numeric.argmin.register +def _(a: T_NUMPY, axis: None) -> T_NUMPY: + return np.argmin(a, axis=axis) + + @numeric.diag.register def _(a: T_NUMPY, k: int = 0) -> T_NUMPY_ARRAY: return np.diag(a, k=k) diff --git a/src/nncf/tensor/functions/torch_numeric.py b/src/nncf/tensor/functions/torch_numeric.py index f7894217918..9000be673f3 100644 --- a/src/nncf/tensor/functions/torch_numeric.py +++ b/src/nncf/tensor/functions/torch_numeric.py @@ -341,16 +341,16 @@ def _(a: torch.Tensor) -> T_NUMBER: return a.item() -@numeric.cumsum.register -def _(a: torch.Tensor, axis: int) -> torch.Tensor: - return torch.cumsum(a, dim=axis) - - @numeric.sum.register def _(a: torch.Tensor, axis: T_AXIS = None, keepdims: bool = False) -> torch.Tensor: return torch.sum(a, dim=axis, keepdim=keepdims) +@numeric.cumsum.register +def _(a: torch.Tensor, axis: int) -> torch.Tensor: + return torch.cumsum(a, dim=axis) + + @numeric.multiply.register def _(x1: torch.Tensor, x2: Union[torch.Tensor, float]) -> torch.Tensor: return torch.multiply(x1, x2)