72 changes: 63 additions & 9 deletions src/torchmetrics/functional/retrieval/ndcg.py
@@ -68,7 +68,39 @@ def _dcg_sample_scores(target: Tensor, preds: Tensor, top_k: int, ignore_ties: b
return cumulative_gain


def retrieval_normalized_dcg(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
def _handle_empty_target(action: str, device: torch.device) -> Optional[Tensor]:
"""Return a default nDCG score when the target contains no positive labels.

Args:
action: policy for handling empty targets:
- "skip": return None (exclude from batch average)
- "pos": return a score of 1.0
- "neg": return a score of 0.0
device: the torch device on which to create the output tensor.

Returns:
A scalar tensor with the default score if action is "pos" or "neg".
None if action is "skip".

Raises:
ValueError: if ``action`` is not one of {"skip", "pos", "neg"}.

"""
if action == "skip":
return None
if action == "pos":
return torch.tensor(1.0, device=device)
if action == "neg":
return torch.tensor(0.0, device=device)
raise ValueError(f"Invalid empty_target_action: {action}")


def retrieval_normalized_dcg(
preds: Tensor,
target: Tensor,
top_k: Optional[int] = None,
empty_target_action: str = "skip",
) -> Tensor:
"""Compute `Normalized Discounted Cumulative Gain`_ (for information retrieval).

``preds`` and ``target`` should be of the same shape and live on the same device.
@@ -79,6 +111,10 @@ def retrieval_normalized_dcg(preds: Tensor, target: Tensor, top_k: Optional[int]
preds: estimated probabilities of each document to be relevant.
target: ground truth about each document relevance.
top_k: consider only the top k elements (default: ``None``, which considers them all)
empty_target_action: what to do when the target has no positives:
- "skip": exclude from average
- "pos": assign score 1.0
- "neg": assign score 0.0

Return:
A single-value tensor with the nDCG of the predictions ``preds`` w.r.t. the labels ``target``.
@@ -95,19 +131,37 @@ def retrieval_normalized_dcg(preds: Tensor, target: Tensor, top_k: Optional[int]
tensor(0.6957)

"""
original_shape = preds.shape
preds, target = _check_retrieval_functional_inputs(preds, target, allow_non_binary_target=True)

top_k = preds.shape[-1] if top_k is None else top_k
# restore the original shape for 2D inputs; 1D inputs become a single-row batch
if len(original_shape) == 2:
preds = preds.view(original_shape)
target = target.view(original_shape)
else:
preds = preds.unsqueeze(0)
Comment on lines +134 to +142

Copilot AI (Sep 19, 2025):

Line 143 should use target.unsqueeze(0) instead of target.view(original_shape). The else branch handles 1D inputs which need to be unsqueezed to match the preds tensor on line 142.

target = target.unsqueeze(0)

n_samples, n_labels = preds.shape
top_k = n_labels if top_k is None else top_k
top_k = min(top_k, n_labels)

if not (isinstance(top_k, int) and top_k > 0):
raise ValueError("`top_k` has to be a positive integer or None")

gain = _dcg_sample_scores(target, preds, top_k, ignore_ties=False)
normalized_gain = _dcg_sample_scores(target, target, top_k, ignore_ties=True)
scores = []
for p, t in zip(preds, target):
gain = _dcg_sample_scores(t, p, top_k, ignore_ties=False)
ideal_gain = _dcg_sample_scores(t, t, top_k, ignore_ties=True)

if ideal_gain == 0:
score = _handle_empty_target(empty_target_action, preds.device)
if score is not None:
scores.append(score)
else:
scores.append(gain / ideal_gain)

# filter undefined scores
all_irrelevant = normalized_gain == 0
gain[all_irrelevant] = 0
gain[~all_irrelevant] /= normalized_gain[~all_irrelevant]
if not scores:
return torch.tensor(0.0, device=preds.device)

return gain.mean()
return torch.stack(scores).mean()
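For context, a minimal usage sketch of the new `empty_target_action` parameter; the expected outputs follow the code path above, and the all-zero target is illustrative:

```python
import torch
from torchmetrics.functional.retrieval import retrieval_normalized_dcg

preds = torch.tensor([0.9, 0.3, 0.5])
target = torch.tensor([0, 0, 0])  # query with no relevant documents

# "skip" drops the query; with nothing left to average, the function falls back to 0.0
retrieval_normalized_dcg(preds, target, empty_target_action="skip")  # tensor(0.)
# "pos" scores the empty-target query as a perfect ranking
retrieval_normalized_dcg(preds, target, empty_target_action="pos")   # tensor(1.)
# "neg" scores it as a complete miss
retrieval_normalized_dcg(preds, target, empty_target_action="neg")   # tensor(0.)
```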
6 changes: 6 additions & 0 deletions tests/unittests/retrieval/_inputs.py
@@ -64,6 +64,12 @@ class _Input(NamedTuple):
),
)

_input_retrieval_scores_2d = _Input(
indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE, 2)),
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 2),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, 2)),
)

# with errors
_input_retrieval_scores_no_target = _Input(
indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)),
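A small sketch (values illustrative, assuming the 2D branch added in ndcg.py above) of how a fixture like `_input_retrieval_scores_2d` is consumed: each row of the 2D tensor is scored as its own query, and the per-row nDCG values are averaged:

```python
import torch
from torchmetrics.functional.retrieval import retrieval_normalized_dcg

preds = torch.tensor([[0.8, 0.2], [0.1, 0.9]])  # two queries, two documents each
target = torch.tensor([[1, 0], [0, 1]])         # one relevant document per query

# both rows rank their relevant document first, so the mean nDCG is 1.0
retrieval_normalized_dcg(preds, target)  # tensor(1.)
```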
16 changes: 13 additions & 3 deletions tests/unittests/retrieval/helpers.py
@@ -25,6 +25,7 @@
from unittests._helpers import seed_all
from unittests._helpers.testers import Metric, MetricTester
from unittests.retrieval._inputs import _input_retrieval_scores as _irs
from unittests.retrieval._inputs import _input_retrieval_scores_2d as _irs_2d
from unittests.retrieval._inputs import _input_retrieval_scores_all_target as _irs_all
from unittests.retrieval._inputs import _input_retrieval_scores_empty as _irs_empty
from unittests.retrieval._inputs import _input_retrieval_scores_extra as _irs_extra
@@ -99,15 +100,20 @@ def _compute_sklearn_metric(
target: Union[Tensor, array],
indexes: Optional[np.ndarray] = None,
metric: Optional[Callable] = None,
empty_target_action: str = "skip",
empty_target_action: Optional[str] = None,
ignore_index: Optional[int] = None,
reverse: bool = False,
aggregation: Union[Literal["mean", "median", "min", "max"], Callable] = "mean",
metric_name: Optional[str] = None,
**kwargs: Any,
) -> Tensor:
"""Compute metric with multiple iterations over every query predictions set."""
if indexes is None:
indexes = np.full_like(preds, fill_value=0, dtype=np.int64)
if metric_name == "ndcg" and preds.ndim == 2:
row_indexes = np.arange(preds.shape[0], dtype=np.int64)[:, None]
indexes = np.tile(row_indexes, (1, preds.shape[1]))
else:
indexes = np.zeros_like(preds, dtype=np.int64)
if isinstance(indexes, Tensor):
indexes = indexes.cpu().numpy()
if isinstance(preds, Tensor):
@@ -393,6 +399,7 @@ def _concat_tests(*tests: tuple[dict]) -> dict:
"argnames": "preds,target",
"argvalues": [
(_irs.preds, _irs.target),
(_irs_2d.preds, _irs_2d.target),
(_irs_extra.preds, _irs_extra.target),
(_irs_no_tgt.preds, _irs_no_tgt.target),
(_irs_int_tgt.preds, _irs_int_tgt.target),
@@ -494,11 +501,14 @@ def run_functional_metric_test(
metric_functional: Callable,
reference_metric: Callable,
metric_args: dict,
metric_name: Optional[str] = None,
reverse: bool = False,
**kwargs: Any,
):
"""Test functional implementation of metric."""
_ref_metric_adapted = partial(_compute_sklearn_metric, metric=reference_metric, reverse=reverse, **metric_args)
_ref_metric_adapted = partial(
_compute_sklearn_metric, metric=reference_metric, reverse=reverse, metric_name=metric_name, **metric_args
)

super().run_functional_metric_test(
preds=preds,
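For reference, this is what the new indexing branch in `_compute_sklearn_metric` produces for a 2D `preds` array (the shape here is chosen to mirror the 2D fixture):

```python
import numpy as np

preds = np.random.rand(4, 2)  # (BATCH_SIZE, 2), as in _input_retrieval_scores_2d
row_indexes = np.arange(preds.shape[0], dtype=np.int64)[:, None]  # shape (4, 1)
indexes = np.tile(row_indexes, (1, preds.shape[1]))               # shape (4, 2)
# indexes == [[0, 0], [1, 1], [2, 2], [3, 3]]: each row forms its own query group
```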
9 changes: 6 additions & 3 deletions tests/unittests/retrieval/test_ndcg.py
@@ -120,17 +120,20 @@ def test_class_metric_ignore_index(
)

@pytest.mark.parametrize(**_default_metric_functional_input_arguments_with_non_binary_target)
@pytest.mark.parametrize("empty_target_action", ["skip", "pos", "neg"])
@pytest.mark.parametrize("k", [None, 1, 4, 10])
def test_functional_metric(self, preds: Tensor, target: Tensor, k: int):
def test_functional_metric(self, preds: Tensor, target: Tensor, empty_target_action: str, k: int):
"""Test functional implementation of metric."""
metric_args = {"empty_target_action": empty_target_action, "top_k": k}

target = target if target.min() >= 0 else target - target.min()
self.run_functional_metric_test(
preds=preds,
target=target,
metric_functional=retrieval_normalized_dcg,
reference_metric=_ndcg_at_k,
metric_args={},
top_k=k,
metric_args=metric_args,
metric_name="ndcg",
)

@pytest.mark.parametrize(**_default_metric_class_input_arguments_with_non_binary_target)