diff --git a/src/torchmetrics/functional/retrieval/ndcg.py b/src/torchmetrics/functional/retrieval/ndcg.py
index d381718c793..5601d4bffbd 100644
--- a/src/torchmetrics/functional/retrieval/ndcg.py
+++ b/src/torchmetrics/functional/retrieval/ndcg.py
@@ -68,7 +68,39 @@ def _dcg_sample_scores(target: Tensor, preds: Tensor, top_k: int, ignore_ties: b
     return cumulative_gain
 
 
-def retrieval_normalized_dcg(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
+def _handle_empty_target(action: str, device: torch.device) -> Optional[Tensor]:
+    """Return a default nDCG score when the target contains no positive labels.
+
+    Args:
+        action: policy for handling empty targets:
+            - "skip": return None (exclude from batch average)
+            - "pos": return a score of 1.0
+            - "neg": return a score of 0.0
+        device: the torch device on which to create the output tensor.
+
+    Returns:
+        A scalar tensor with the default score if action is "pos" or "neg".
+        None if action is "skip".
+
+    Raises:
+        ValueError: if ``action`` is not one of {"skip", "pos", "neg"}.
+
+    """
+    if action == "skip":
+        return None
+    if action == "pos":
+        return torch.tensor(1.0, device=device)
+    if action == "neg":
+        return torch.tensor(0.0, device=device)
+    raise ValueError(f"Invalid empty_target_action: {action}")
+
+
+def retrieval_normalized_dcg(
+    preds: Tensor,
+    target: Tensor,
+    top_k: Optional[int] = None,
+    empty_target_action: str = "skip",
+) -> Tensor:
     """Compute `Normalized Discounted Cumulative Gain`_ (for information retrieval).
 
     ``preds`` and ``target`` should be of the same shape and live on the same device.
@@ -79,6 +111,10 @@ def retrieval_normalized_dcg(preds: Tensor, target: Tensor, top_k: Optional[int]
         preds: estimated probabilities of each document to be relevant.
         target: ground truth about each document relevance.
         top_k: consider only the top k elements (default: ``None``, which considers them all)
+        empty_target_action: what to do when the target has no positives:
+            - "skip": exclude from average
+            - "pos": assign score 1.0
+            - "neg": assign score 0.0
 
     Return:
         A single-value tensor with the nDCG of the predictions ``preds`` w.r.t. the labels ``target``.
@@ -95,19 +131,37 @@ def retrieval_normalized_dcg(preds: Tensor, target: Tensor, top_k: Optional[int]
         tensor(0.6957)
 
     """
+    original_shape = preds.shape
     preds, target = _check_retrieval_functional_inputs(preds, target, allow_non_binary_target=True)
 
-    top_k = preds.shape[-1] if top_k is None else top_k
+    # reshape back if input was 2D
+    if len(original_shape) == 2:
+        preds = preds.view(original_shape)
+        target = target.view(original_shape)
+    else:
+        preds = preds.unsqueeze(0)
+        target = target.unsqueeze(0)
+
+    n_samples, n_labels = preds.shape
+    top_k = n_labels if top_k is None else top_k
+    top_k = min(top_k, n_labels)
 
     if not (isinstance(top_k, int) and top_k > 0):
         raise ValueError("`top_k` has to be a positive integer or None")
 
-    gain = _dcg_sample_scores(target, preds, top_k, ignore_ties=False)
-    normalized_gain = _dcg_sample_scores(target, target, top_k, ignore_ties=True)
+    scores = []
+    for p, t in zip(preds, target):
+        gain = _dcg_sample_scores(t, p, top_k, ignore_ties=False)
+        ideal_gain = _dcg_sample_scores(t, t, top_k, ignore_ties=True)
+
+        if ideal_gain == 0:
+            score = _handle_empty_target(empty_target_action, preds.device)
+            if score is not None:
+                scores.append(score)
+        else:
+            scores.append(gain / ideal_gain)
 
-    # filter undefined scores
-    all_irrelevant = normalized_gain == 0
-    gain[all_irrelevant] = 0
-    gain[~all_irrelevant] /= normalized_gain[~all_irrelevant]
+    if not scores:
+        return torch.tensor(0.0, device=preds.device)
 
-    return gain.mean()
+    return torch.stack(scores).mean()
diff --git a/tests/unittests/retrieval/_inputs.py b/tests/unittests/retrieval/_inputs.py
index cf6b9c40377..c01ec5c5908 100644
--- a/tests/unittests/retrieval/_inputs.py
+++ b/tests/unittests/retrieval/_inputs.py
@@ -64,6 +64,12 @@ class _Input(NamedTuple):
     ),
 )
 
+_input_retrieval_scores_2d = _Input(
+    indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE, 2)),
+    preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 2),
+    target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, 2)),
+)
+
 # with errors
 _input_retrieval_scores_no_target = _Input(
     indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)),
diff --git a/tests/unittests/retrieval/helpers.py b/tests/unittests/retrieval/helpers.py
index 6747be475ed..2398fe167cb 100644
--- a/tests/unittests/retrieval/helpers.py
+++ b/tests/unittests/retrieval/helpers.py
@@ -25,6 +25,7 @@
 from unittests._helpers import seed_all
 from unittests._helpers.testers import Metric, MetricTester
 from unittests.retrieval._inputs import _input_retrieval_scores as _irs
+from unittests.retrieval._inputs import _input_retrieval_scores_2d as _irs_2d
 from unittests.retrieval._inputs import _input_retrieval_scores_all_target as _irs_all
 from unittests.retrieval._inputs import _input_retrieval_scores_empty as _irs_empty
 from unittests.retrieval._inputs import _input_retrieval_scores_extra as _irs_extra
@@ -99,15 +100,20 @@ def _compute_sklearn_metric(
     target: Union[Tensor, array],
     indexes: Optional[np.ndarray] = None,
     metric: Optional[Callable] = None,
-    empty_target_action: str = "skip",
+    empty_target_action: Optional[str] = None,
     ignore_index: Optional[int] = None,
     reverse: bool = False,
     aggregation: Union[Literal["mean", "median", "min", "max"], Callable] = "mean",
+    metric_name: Optional[str] = None,
     **kwargs: Any,
 ) -> Tensor:
     """Compute metric with multiple iterations over every query predictions set."""
     if indexes is None:
-        indexes = np.full_like(preds, fill_value=0, dtype=np.int64)
+        if metric_name == "ndcg" and preds.ndim == 2:
+            row_indexes = np.arange(preds.shape[0], dtype=np.int64)[:, None]
+            indexes = np.tile(row_indexes, (1, preds.shape[1]))
+        else:
+            indexes = np.zeros_like(preds, dtype=np.int64)
     if isinstance(indexes, Tensor):
         indexes = indexes.cpu().numpy()
     if isinstance(preds, Tensor):
@@ -393,6 +399,7 @@ def _concat_tests(*tests: tuple[dict]) -> dict:
     "argnames": "preds,target",
     "argvalues": [
         (_irs.preds, _irs.target),
+        (_irs_2d.preds, _irs_2d.target),
         (_irs_extra.preds, _irs_extra.target),
         (_irs_no_tgt.preds, _irs_no_tgt.target),
         (_irs_int_tgt.preds, _irs_int_tgt.target),
@@ -494,11 +501,14 @@ def run_functional_metric_test(
         metric_functional: Callable,
         reference_metric: Callable,
         metric_args: dict,
+        metric_name: Optional[str] = None,
         reverse: bool = False,
         **kwargs: Any,
     ):
         """Test functional implementation of metric."""
-        _ref_metric_adapted = partial(_compute_sklearn_metric, metric=reference_metric, reverse=reverse, **metric_args)
+        _ref_metric_adapted = partial(
+            _compute_sklearn_metric, metric=reference_metric, reverse=reverse, metric_name=metric_name, **metric_args
+        )
 
         super().run_functional_metric_test(
             preds=preds,
diff --git a/tests/unittests/retrieval/test_ndcg.py b/tests/unittests/retrieval/test_ndcg.py
index 5c8fe20a2f7..fc2049f8ab6 100644
--- a/tests/unittests/retrieval/test_ndcg.py
+++ b/tests/unittests/retrieval/test_ndcg.py
@@ -120,17 +120,20 @@ def test_class_metric_ignore_index(
         )
 
     @pytest.mark.parametrize(**_default_metric_functional_input_arguments_with_non_binary_target)
+    @pytest.mark.parametrize("empty_target_action", ["skip", "pos", "neg"])
     @pytest.mark.parametrize("k", [None, 1, 4, 10])
-    def test_functional_metric(self, preds: Tensor, target: Tensor, k: int):
+    def test_functional_metric(self, preds: Tensor, target: Tensor, empty_target_action: str, k: int):
         """Test functional implementation of metric."""
+        metric_args = {"empty_target_action": empty_target_action, "top_k": k}
+
         target = target if target.min() >= 0 else target - target.min()
         self.run_functional_metric_test(
             preds=preds,
             target=target,
             metric_functional=retrieval_normalized_dcg,
             reference_metric=_ndcg_at_k,
-            metric_args=metric_args,
-            metric_name="ndcg",
+            metric_args=metric_args,
+            metric_name="ndcg",
         )
 
     @pytest.mark.parametrize(**_default_metric_class_input_arguments_with_non_binary_target)
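Usage sketch (not part of the patch): how the new ``empty_target_action`` argument would be called once this diff is applied, including the per-query (2D) input handling it introduces. The inputs below are made up for illustration and expected outputs are omitted, since the exact values depend on the patched implementation.

    >>> import torch
    >>> from torchmetrics.functional.retrieval import retrieval_normalized_dcg
    >>> preds = torch.tensor([[0.9, 0.3, 0.2], [0.5, 0.1, 0.4]])   # two queries, three documents each
    >>> target = torch.tensor([[1, 0, 1], [0, 0, 0]])              # second query has no relevant documents
    >>> # "skip": the empty query is excluded, so the result is the nDCG of the first query only
    >>> ndcg_skip = retrieval_normalized_dcg(preds, target, top_k=2, empty_target_action="skip")
    >>> # "neg": the empty query contributes a score of 0.0 to the average
    >>> ndcg_neg = retrieval_normalized_dcg(preds, target, top_k=2, empty_target_action="neg")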