diff --git a/library/docs/source/guide/tutorials/advanced/hier_metric_collection.rst b/library/docs/source/guide/tutorials/advanced/hier_metric_collection.rst
new file mode 100644
index 0000000000..1f3fa4c8be
--- /dev/null
+++ b/library/docs/source/guide/tutorials/advanced/hier_metric_collection.rst
@@ -0,0 +1,104 @@
+Hierarchical Classification Metric Collection
+=============================================
+
+.. note::
+
+    The hierarchical classification metrics are designed for structured label spaces (e.g., taxonomies in biology or medicine).
+    See ``otx.metrics.hier_metric_collection.hier_metric_collection_callable``.
+
+Overview
+--------
+
+OpenVINO™ Training Extensions provides a unified **hierarchical metric collection** for classification tasks
+that involve taxonomic or multi-level labels. It extends flat classification metrics (accuracy, mAP) with
+hierarchy-aware evaluation.
+
+Benefits
+--------
+
+- **Structure-Aware**: Evaluates not only flat accuracy but also taxonomy-aware metrics.
+- **Robustness**: Partial credit is given when higher-level predictions are correct, even if fine-grained labels are wrong.
+- **Flexibility**: Works seamlessly across multiclass and hierarchical-label tasks.
+
+Supported
+---------
+
+- **Label Types**:
+
+  - Hierarchical-label classification
+
+- **Tasks**:
+
+  - Taxonomy-aware hierarchical classification
+
+How to Use Hierarchical Metric Collection
+-----------------------------------------
+
+.. tab-set::
+
+    .. tab-item:: API
+
+        .. code-block:: python
+
+            from otx.metrics.hier_metric_collection import hier_metric_collection_callable
+            from otx.core.types.label import HLabelInfo
+
+            # Suppose label_info is loaded from a Datumaro dataset
+            metric = hier_metric_collection_callable(label_info)
+
+            # During training / validation
+            metric.update(preds, targets)
+            results = metric.compute()
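+
+        For reference, ``hier_metric_collection_callable`` reads the hierarchy from the
+        ``label_groups`` (one list of label names per level) and ``label_tree_edges``
+        (child/parent pairs) attributes of ``HLabelInfo``; ``head_idx_to_logits_range``
+        is also used to size the per-level precision metrics. A minimal sketch of that
+        structure, using the toy aircraft taxonomy from the unit tests (in practice
+        ``label_info`` comes from your dataset rather than being written by hand):
+
+        .. code-block:: python
+
+            label_groups = [
+                ["Boeing", "Airbus"],                            # level 0
+                ["737", "A320"],                                 # level 1
+                ["737-800", "737-900", "A320-200", "A320-neo"],  # level 2 (leaves)
+            ]
+            label_tree_edges = [
+                ("737", "Boeing"), ("A320", "Airbus"),
+                ("737-800", "737"), ("737-900", "737"),
+                ("A320-200", "A320"), ("A320-neo", "A320"),
+            ]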
+
+    .. tab-item:: CLI
+
+        .. code-block:: bash
+
+            (otx) $ otx train ... --metric otx.metrics.hier_metric_collection.hier_metric_collection_callable
+
+    .. tab-item:: YAML
+
+        .. code-block:: yaml
+
+            task: H_LABEL_CLS
+            model:
+              class_path:
+              init_args:
+                label_info:
+
+            metric:
+              class_path: otx.metrics.hier_metric_collection.hier_metric_collection_callable
+
+How to Use the Metric Collection with the Engine
+------------------------------------------------
+
+.. tab-set::
+
+    .. tab-item:: API
+
+        .. code-block:: python
+
+            from otx.engine import create_engine
+            from otx.metrics.hier_metric_collection import hier_metric_collection_callable
+            from otx.core.types.label import HLabelInfo
+
+            # 1) Build or load your label_info (e.g., from a Datumaro dataset)
+            # label_info: HLabelInfo = ...
+
+            # 2) Create your model and data objects (specific to your project)
+            model = ...
+            data = ...
+
+            # 3) Create an Engine and pass the metric callable into train/test
+            engine = create_engine(model, data)
+            engine.train(metric=hier_metric_collection_callable)  # the Engine will construct the MetricCollection
+            engine.test(metric=hier_metric_collection_callable)
+
+    .. tab-item:: What gets computed?
+
+        The callable returns a ``torchmetrics.MetricCollection`` with keys:
+
+        - ``"accuracy"`` — hierarchical head accuracy
+        - ``"leaf_accuracy"`` — macro accuracy on the leaf level
+        - ``"full_path_accuracy"`` — exact match across all hierarchy levels
+        - ``"inconsistent_path_ratio"`` — ratio of parent→child violations in predicted paths
+        - ``"weighted_precision"`` — label-count–weighted macro precision across levels
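+
+        For example, once the collection has been updated (as in the API example in the
+        previous section), individual values can be read from the dictionary returned by
+        ``compute()`` (a minimal sketch; the key names are those listed above):
+
+        .. code-block:: python
+
+            results = metric.compute()
+
+            full_path_acc = float(results["full_path_accuracy"])       # exact match over all levels
+            inconsistency = float(results["inconsistent_path_ratio"])  # lower is better
+
+        ``weighted_precision`` weights each level by its label count: for levels of
+        size 2, 2 and 4, the per-level macro precisions are combined with weights
+        0.25, 0.25 and 0.5.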
+ """ + out: dict[tuple[int, int], str] = {} + for lvl, labels in enumerate(label_groups): + for idx, name in enumerate(labels): + out[(lvl, idx)] = name + return out + + +def _make_child_to_parent(edges: list[list[str]]) -> dict[str, str]: + """Create a mapping ``child -> parent`` from edges.""" + c2p = {} + for child, parent in edges: + if child in c2p: # defensive programming in case of duplicates + error_msg = f"duplicate child: {child}" + raise ValueError(error_msg) + c2p[child] = parent + return c2p + + +# --------------------------------------------------------------------------- +# Metrics +# --------------------------------------------------------------------------- + + +class LeafAccuracy(Metric): + """Macro-averaged accuracy at the leaf (last) group. + + Assumes targets/preds are class indices shaped ``(N, L)``. + """ + + full_state_update: bool = False + + def __init__(self, label_info: HLabelInfo) -> None: + super().__init__() + self.label_info = label_info + + leaf_labels = label_info.label_groups[-1] + self.num_leaf_classes = len(leaf_labels) + + self.add_state( + "correct_per_class", + default=torch.zeros(self.num_leaf_classes, dtype=torch.long), + dist_reduce_fx="sum", + ) + self.add_state( + "total_per_class", + default=torch.zeros(self.num_leaf_classes, dtype=torch.long), + dist_reduce_fx="sum", + ) + + def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: # type: ignore[override] + """Update state with predictions and targets.""" + pred_leaf = preds[:, -1] + target_leaf = target[:, -1] + for cls in range(self.num_leaf_classes): + mask = target_leaf == cls + self.total_per_class[cls] += mask.sum() + self.correct_per_class[cls] += (pred_leaf[mask] == cls).sum() + + def compute(self) -> torch.Tensor: # type: ignore[override] + """Compute the leaf accuracy metric.""" + total = self.total_per_class.clamp_min_(1) + per_class_acc = self.correct_per_class.float() / total.float() + return per_class_acc.mean() + + +class FullPathAccuracy(Metric): + """Exact-match accuracy across all hierarchy levels.""" + + full_state_update: bool = False + + def __init__(self) -> None: + super().__init__() + self.add_state("correct", default=torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: # type: ignore[override] + """Update state with predictions and targets.""" + if preds.shape != target.shape: + raise ValueError(_INVALID_SHAPE_MSG) + matches = (preds == target).all(dim=1) + self.correct += matches.sum() + self.total += preds.size(0) + + def compute(self) -> torch.Tensor: # type: ignore[override] + """Compute the full path accuracy metric.""" + return self.correct.float() / self.total.clamp_min(1).float() + + +class InconsistentPathRatio(Metric): + """Ratio of *predicted* paths violating the parent→child constraints.""" + + full_state_update: bool = False + + def __init__(self, label_info: HLabelInfo) -> None: + super().__init__() + self.level_idx_to_name = _build_level_idx_to_name(label_info.label_groups) + self.child_to_parent = _make_child_to_parent(label_info.label_tree_edges) + self.add_state("invalid", default=torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: # type: ignore[override] + """Update state with predictions.""" + if 
+        if preds.ndim != 2:
+            raise ValueError(_INVALID_2D_SHAPE)
+        n, level = preds.shape
+        for i in range(n):
+            ok = True
+            for lvl in range(1, level):
+                child = self.level_idx_to_name[(lvl, int(preds[i, lvl]))]
+                parent = self.level_idx_to_name[(lvl - 1, int(preds[i, lvl - 1]))]
+                if self.child_to_parent.get(child) != parent:
+                    ok = False
+                    break
+            if not ok:
+                self.invalid += 1
+        self.total += n
+
+    def compute(self) -> torch.Tensor:  # type: ignore[override]
+        """Compute the inconsistent path ratio."""
+        return self.invalid.float() / self.total.clamp_min(1).float()
+
+
+class WeightedHierarchicalPrecision(Metric):
+    """Label-count-weighted macro precision across hierarchy levels.
+
+    At each level ``l``, computes macro precision and aggregates with weight
+    ``|labels_l| / sum_k |labels_k|``. Inputs are class indices ``(N, L)``.
+    """
+
+    full_state_update: bool = False
+
+    def __init__(self, label_info: HLabelInfo) -> None:
+        super().__init__()
+        self.level_sizes: list[int] = []
+        self.level_metrics = nn.ModuleList()
+        for lvl in sorted(label_info.head_idx_to_logits_range):
+            lo, hi = label_info.head_idx_to_logits_range[lvl]
+            num_classes = int(hi - lo)
+            self.level_sizes.append(num_classes)
+            self.level_metrics.append(
+                TorchPrecision(task="multiclass", num_classes=num_classes, average="macro"),
+            )
+
+    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:  # type: ignore[override]
+        """Update state with predictions and targets."""
+        # Each column corresponds to a level.
+        for lvl, metric in enumerate(self.level_metrics):
+            metric.update(preds[:, lvl], target[:, lvl])
+
+    def compute(self) -> torch.Tensor:  # type: ignore[override]
+        """Compute the label-count-weighted macro precision."""
+        total = float(sum(self.level_sizes))
+        weights = [s / total for s in self.level_sizes]
+        per_level = [metric.compute() for metric in self.level_metrics]
+        return torch.stack([w * v for w, v in zip(weights, per_level)]).sum()
+
+    def reset(self) -> None:  # type: ignore[override]
+        """Reset the metric calculation."""
+        for metric in self.level_metrics:
+            metric.reset()
+
+
+def hier_metric_collection_callable(label_info: HLabelInfo) -> MetricCollection:
+    """Create a ``MetricCollection`` with all hierarchical metrics.
+
+    Returns:
+        torchmetrics.MetricCollection: Collection with keys ``accuracy``,
+        ``leaf_accuracy``, ``full_path_accuracy``, ``inconsistent_path_ratio``
+        and ``weighted_precision``.
+ """ + return MetricCollection( + { + "accuracy": HlabelAccuracy(label_info=label_info), + "leaf_accuracy": LeafAccuracy(label_info=label_info), + "full_path_accuracy": FullPathAccuracy(), + "inconsistent_path_ratio": InconsistentPathRatio(label_info=label_info), + "weighted_precision": WeightedHierarchicalPrecision(label_info=label_info), + }, + ) + + +HMetricCallable = Callable[[HLabelInfo], Metric | MetricCollection] + +HierMetricCollection: HMetricCallable = hier_metric_collection_callable diff --git a/library/tests/unit/metrics/test_hier_metric_collection.py b/library/tests/unit/metrics/test_hier_metric_collection.py new file mode 100644 index 0000000000..0f4be94588 --- /dev/null +++ b/library/tests/unit/metrics/test_hier_metric_collection.py @@ -0,0 +1,189 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +import re +import types + +import pytest +import torch +from torchmetrics.classification import Precision as TorchPrecision + +from otx.metrics.hier_metric_collection import ( + FullPathAccuracy, + InconsistentPathRatio, + LeafAccuracy, + WeightedHierarchicalPrecision, + hier_metric_collection_callable, +) + + +@pytest.fixture() +def label_info_stub(): + """Minimal stub that mimics the LabelInfo attributes used by our metrics. + + - label_groups: list[list[str]] per hierarchy level + - label_tree_edges: list[tuple[str, str]] as (child, parent) + - head_idx_to_logits_range: dict[level, (lo, hi)] for per-level class counts + """ + li = types.SimpleNamespace() + li.label_groups = [ + ["Boeing", "Airbus"], + ["737", "A320"], + ["737-800", "737-900", "A320-200", "A320-neo"], + ] + # edges are (child, parent) + li.label_tree_edges = [ + ("737", "Boeing"), + ("A320", "Airbus"), + ("737-800", "737"), + ("737-900", "737"), + ("A320-200", "A320"), + ("A320-neo", "A320"), + ] + # per-level class ranges (concatenated logits indices convention) + li.head_idx_to_logits_range = {0: (0, 2), 1: (2, 4), 2: (4, 8)} + li.num_multiclass_heads = len(li.label_groups) + li.num_multilabel_classes = sum(len(g) for g in li.label_groups) + return li + + +@pytest.fixture() +def sample_tensors(): + """Return (target, preds) shaped (N, L) with class indices. 
+
+    N=4, L=3 (3 hierarchy levels)
+    """
+    # targets (true) indices per level
+    target = torch.tensor(
+        [
+            [0, 0, 0],  # Boeing, 737, 737-800
+            [1, 1, 2],  # Airbus, A320, A320-200
+            [0, 0, 1],  # Boeing, 737, 737-900
+            [1, 1, 3],  # Airbus, A320, A320-neo
+        ]
+    )
+    # preds: 2 exact matches (rows 1 and 2); two leaf errors (rows 0 and 3)
+    preds = torch.tensor(
+        [
+            [0, 0, 1],  # leaf wrong
+            [1, 1, 2],  # exact
+            [0, 0, 1],  # exact
+            [1, 1, 0],  # leaf wrong
+        ]
+    )
+    return target, preds
+
+
+# ------------------------------ LeafAccuracy --------------------------------
+
+
+def test_leaf_accuracy_macro_mean(label_info_stub, sample_tensors):
+    target, preds = sample_tensors
+    metric = LeafAccuracy(label_info_stub)
+    metric.update(preds, target)
+    val = metric.compute().item()
+
+    # Compute the expected macro mean at the leaf level by hand
+    y_true_leaf = target[:, -1]
+    y_pred_leaf = preds[:, -1]
+    per_class = []
+    for cls in range(4):
+        mask = y_true_leaf == cls
+        tot = int(mask.sum())
+        if tot == 0:
+            per_class.append(0.0)
+        else:
+            correct = int((y_pred_leaf[mask] == cls).sum())
+            per_class.append(correct / tot)
+    expected = sum(per_class) / 4.0
+    assert val == pytest.approx(expected)
+
+
+# --------------------------- FullPathAccuracy -------------------------------
+def test_full_path_accuracy(sample_tensors):
+    target, preds = sample_tensors
+    metric = FullPathAccuracy()
+    metric.update(preds, target)
+    val = metric.compute().item()
+    # exact rows: 1 and 2 -> 2/4
+    assert val == pytest.approx(0.5)
+
+
+# ------------------------ Inconsistent Path Ratio ---------------------------
+def test_inconsistent_path_ratio_inconsistent(label_info_stub):
+    # Make structurally invalid predictions (wrong parent chain)
+    preds_bad = torch.tensor(
+        [
+            [0, 1, 0],  # A320 predicted under Boeing (its parent is Airbus)
+            [1, 0, 3],  # 737 predicted under Airbus (its parent is Boeing)
+        ]
+    )
+    target_dummy = torch.zeros_like(preds_bad)
+
+    metric = InconsistentPathRatio(label_info_stub)
+    metric.update(preds_bad, target_dummy)
+    val = metric.compute().item()
+    assert val == pytest.approx(1.0)
+
+
+# --------------------- Weighted Hierarchical Precision ----------------------
+def test_weighted_hierarchical_precision_matches_reference(label_info_stub, sample_tensors):
+    target, preds = sample_tensors
+    metric = WeightedHierarchicalPrecision(label_info_stub)
+    metric.update(preds, target)
+    got = metric.compute().item()
+
+    # Reference computation using TorchPrecision per level, then label-count weights
+    level_sizes = [len(g) for g in label_info_stub.label_groups]
+    total = float(sum(level_sizes))
+    weights = [s / total for s in level_sizes]
+
+    ref_vals = []
+    for lvl, ncls in enumerate(level_sizes):
+        ref = TorchPrecision(task="multiclass", num_classes=ncls, average="macro")
+        ref.update(preds[:, lvl], target[:, lvl])
+        ref_vals.append(ref.compute().item())
+
+    expected = sum(w * v for w, v in zip(weights, ref_vals))
+    assert got == pytest.approx(expected, rel=1e-5, abs=1e-6)
+
+
+# -------------------------- MetricCollection callable -----------------------
+def test_hier_metric_collection_callable(label_info_stub, sample_tensors):
+    target, preds = sample_tensors
+    mc = hier_metric_collection_callable(label_info_stub)
+
+    # update/compute over the whole collection
+    mc.update(preds, target)
+    out = mc.compute()
+
+    assert set(out.keys()) == {
+        "leaf_accuracy",
+        "full_path_accuracy",
+        "inconsistent_path_ratio",
+        "weighted_precision",
+        "accuracy",
+        "conf_matrix",
+    }
+
+    # spot-check a couple of values
+    assert out["full_path_accuracy"].item() == pytest.approx(0.5)
+    assert out["inconsistent_path_ratio"].item() == pytest.approx(0.25)
+
+
+# ------------------------------ Error handling ------------------------------
+
+
+def test_full_path_accuracy_shape_mismatch_raises():
+    metric = FullPathAccuracy()
+    preds = torch.tensor([[0, 0, 0]])
+    target = torch.tensor([[0, 0]])  # wrong shape
+    with pytest.raises(ValueError, match=re.escape("preds and target must have the same shape")):
+        metric.update(preds, target)
+
+
+def test_inconsistent_path_ratio_requires_2d(label_info_stub):
+    metric = InconsistentPathRatio(label_info_stub)
+    preds = torch.tensor([0, 1, 2])  # 1D
+    target = torch.tensor([0, 1, 2])
+    with pytest.raises(ValueError, match=re.escape("preds must be 2D (N, L)")):
+        metric.update(preds, target)
diff --git a/library/tests/unit/metrics/test_hier_metric_collection_from_engine.py b/library/tests/unit/metrics/test_hier_metric_collection_from_engine.py
new file mode 100644
index 0000000000..14085c8ee6
--- /dev/null
+++ b/library/tests/unit/metrics/test_hier_metric_collection_from_engine.py
@@ -0,0 +1,30 @@
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from otx.engine import Engine, create_engine
+from otx.metrics.hier_metric_collection import hier_metric_collection_callable
+
+
+class TestCreateEngine:
+    @pytest.fixture()
+    def mock_engine_subclass(self):
+        """Fixture to create a mock Engine subclass."""
+        mock_engine_cls = MagicMock(spec=Engine)
+        mock_engine_cls.is_supported.return_value = True
+        return mock_engine_cls
+
+    @patch("otx.engine.Engine.__subclasses__", autospec=True)
+    def test_hier_metric_collection_by_engine(self, mock___subclasses__, mock_engine_subclass):
+        """Test that an engine from create_engine accepts the hierarchical metric callable."""
+        mock___subclasses__.return_value = [mock_engine_subclass]
+        mock_model = MagicMock()
+        mock_data = MagicMock()
+
+        engine_instance = create_engine(mock_model, mock_data)
+        engine_instance.train(metric=hier_metric_collection_callable)
+        engine_instance.test(metric=hier_metric_collection_callable)