
Commit 68eb32f

Jyc323leoll2 and Leonardo Lai authored
add Tree-Path KL Divergence loss for hier classification + unit test (#4706)
Co-authored-by: Leonardo Lai <[email protected]>
1 parent 7c9705a commit 68eb32f

File tree

11 files changed: +746 −3 lines changed

library/docs/source/guide/tutorials/advanced/index.rst

Lines changed: 1 addition & 0 deletions
@@ -10,5 +10,6 @@ Advanced Tutorials
    peft
    torch_compile
    hier_metric_collection
+   tree_path_kl_loss_hcls

 .. Once we have enough material, we might need to categorize these into `data`, `model learning` sections.
library/docs/source/guide/tutorials/advanced/tree_path_kl_loss_hcls.rst

Lines changed: 170 additions & 0 deletions

@@ -0,0 +1,170 @@
Using Tree-Path KL Divergence for Hierarchical Classification
=============================================================

This tutorial explains how to train hierarchical classification models in
OpenVINO™ Training Extensions with **Tree-Path KL Divergence Loss**, a training-time
regularizer that encourages consistent predictions along the taxonomy path
from root to leaf. The method is implemented in:

- :class:`otx.backend.native.models.classification.losses.tree_path_kl_divergence_loss.TreePathKLDivergenceLoss`
- :class:`otx.backend.native.models.classification.classifier.h_label_classifier.KLHLabelClassifier`

The feature is currently exposed by default in
:class:`otx.backend.native.models.classification.hlabel_models.timm_model.TimmModelHLabelCls`.
Users may adapt other architectures with minimal modifications by adding the
same wrapper (``KLHLabelClassifier``) in their model's ``_finalize_model()``.

Overview
--------

Hierarchical classification models predict multiple levels of labels
(e.g., manufacturer → family → variant). Standard cross-entropy treats each
level independently, which means models may output **inconsistent**
combinations such as:

- predicting a correct fine-grained leaf but an incompatible ancestor, or
- predicting parents and children belonging to different branches.

Tree-Path KL Divergence introduces a path-consistency objective by comparing:

- the model's *combined* probability distribution across all levels, and
- a **tree-consistent target distribution** that places probability mass on
  each ground-truth category along the path.

For a three-level path (e.g., Boeing → 747 → 747-400), the target distribution
assigns probability 1/3 to each ground-truth class along the path.

This encourages smooth transitions between hierarchy levels and reduces
structurally invalid predictions.

How It Works
------------

Tree-Path KL Divergence operates on:

- a **list of logits** from each hierarchy level (root → ... → leaf), and
- a **target index** for each corresponding level.

The algorithm implemented in
:class:`TreePathKLDivergenceLoss` performs the following (sketched in code
after this list):

1. Concatenates all level logits and applies log-softmax.
2. Constructs a sparse target distribution that allocates equal probability to
   the correct class at each level.
3. Computes KL divergence between the model's distribution and the path-aware
   target distribution.
4. Scales the result by ``loss_weight`` (typically ``1.0``).

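A condensed, self-contained sketch of these four steps (illustrative only: the
function below is not the OTX implementation, and its name and shapes are
chosen for exposition):

.. code-block:: python

   import torch
   import torch.nn.functional as F

   def tree_path_kl(logits_list, targets, loss_weight=1.0):
       """logits_list: one (N, C_l) tensor per level; targets: (N, L) class indices."""
       num_levels = len(logits_list)
       # 1. Concatenate all level logits and apply log-softmax jointly.
       log_probs = F.log_softmax(torch.cat(logits_list, dim=1), dim=1)
       # 2. Sparse target: equal mass 1/L on the correct class of each level.
       target = torch.zeros_like(log_probs)
       offset = 0
       for lvl, level_logits in enumerate(logits_list):
           rows = torch.arange(target.size(0))
           target[rows, offset + targets[:, lvl]] = 1.0 / num_levels
           offset += level_logits.size(1)
       # 3. KL divergence between the path-aware target and the model distribution.
       kl = F.kl_div(log_probs, target, reduction="batchmean")
       # 4. Scale by loss_weight.
       return loss_weight * kl
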
In :class:`KLHLabelClassifier`, this KL term is added to the hierarchical
cross-entropy loss (written out after this list):

- cross-entropy is averaged across all hierarchy levels,
- KL divergence is multiplied by ``kl_weight``,
- ``kl_weight = 0`` disables the KL term completely.

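In symbols (notation ours, not taken from the code): with :math:`H` hierarchy
levels, per-level logits :math:`z_h`, per-level targets :math:`y_h`, and the
tree-consistent target distribution :math:`T` described above,

.. math::

   \mathcal{L} = \frac{1}{H}\sum_{h=1}^{H} \mathrm{CE}(z_h, y_h)
   + w_{\mathrm{KL}} \cdot \mathrm{KL}\!\left(T \,\|\, \mathrm{softmax}([z_1; \dots; z_H])\right),

where :math:`w_{\mathrm{KL}}` is ``kl_weight``.
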
Enabling Tree-Path KL Divergence
--------------------------------

The recommended entry point is the provided recipe:

.. code-block:: text

   recipe/classification/h_label_cls/efficientnet_v2_kl.yaml

This recipe uses :class:`TimmModelHLabelCls` and exposes the argument
``kl_weight`` directly in ``init_args``:

.. code-block:: yaml

   task: H_LABEL_CLS
   model:
     class_path: otx.backend.native.models.classification.hlabel_models.timm_model.TimmModelHLabelCls
     init_args:
       label_info: <LABEL-TREE-INFO>
       model_name: tf_efficientnetv2_s.in21k
       kl_weight: 1.0

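The same argument can also be passed through the Python API. A minimal sketch,
assuming ``label_info`` and ``data_input_params`` are prepared as for any other
OTX h-label model (the constructor arguments mirror the recipe above):

.. code-block:: python

   from otx.backend.native.models.classification.hlabel_models.timm_model import TimmModelHLabelCls

   model = TimmModelHLabelCls(
       label_info=label_info,                # HLabelInfo describing the taxonomy
       data_input_params=data_input_params,  # input/preprocessing configuration
       model_name="tf_efficientnetv2_s.in21k",
       kl_weight=1.0,                        # > 0 enables the KL term
   )
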
Using the CLI
-------------

To train a hierarchical model with Tree-Path KL Divergence, the CLI requires:

- ``--data_root``: a path to a directory containing an ``annotations/`` folder
  whose JSON annotation files follow **Datumaro format**.
  See the format specification here:

  https://open-edge-platform.github.io/datumaro/stable/docs/data-formats/datumaro_format.html

- ``--config``: the **path to a recipe YAML file**, such as
  ``recipe/classification/h_label_cls/efficientnet_v2_kl.yaml``.

A full training command example:

.. code-block:: bash

   (otx) $ otx train \
       --config recipe/classification/h_label_cls/efficientnet_v2_kl.yaml \
       --data_root /path/to/dataset_with_annotations \
       --model.kl_weight 1.0

To disable Tree-Path KL Divergence and train a standard hierarchical model:

.. code-block:: bash

   (otx) $ otx train \
       --config recipe/classification/h_label_cls/efficientnet_v2_kl.yaml \
       --data_root /path/to/dataset_with_annotations \
       --model.kl_weight 0.0

Extending Other Architectures
-----------------------------

Currently, Tree-Path KL Divergence is automatically supported only by
``TimmModelHLabelCls``. To integrate the feature into other architectures, add
the following logic to the model's ``_finalize_model`` method:

1. Accept a new ``kl_weight`` argument in the model init.
2. After constructing the underlying model, wrap it as:

   .. code-block:: python

      if self.kl_weight > 0:
          model = KLHLabelClassifier(model, kl_weight=self.kl_weight)

3. Ensure that the model returns a list of logits aligned with the hierarchy.

Only a few lines are required, and this enables the same training procedure
for any backbone (ResNet, ViT, ConvNeXt, etc.); the base class's own
``_finalize_model``, shown below, follows the same pattern.

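For reference, this commit's base ``OTXHlabelClsModel._finalize_model`` (see the
``base.py`` diff further down) rebuilds the classifier from its parts rather
than wrapping the whole model:

.. code-block:: python

   def _finalize_model(self, model: nn.Module) -> nn.Module:
       """Run after the child's _create_model(); upgrade to KL if enabled."""
       if self.kl_weight > 0:
           return KLHLabelClassifier(
               backbone=model.backbone,
               neck=model.neck,
               head=model.head,
               multiclass_loss=model.multiclass_loss,
               multilabel_loss=model.multilabel_loss,
               init_cfg=getattr(model, "init_cfg", None),
               kl_weight=self.kl_weight,
           )
       return model
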
When to Use Tree-Path KL Divergence
-----------------------------------

Tree-Path KL Divergence is most helpful when:

- the label space forms a strict taxonomy,
- incorrect parent/child combinations are undesirable,
- fine-grained classes are scarce and benefit from structural priors,
- you want improved consistency across hierarchy levels.

Practically:

- start with ``kl_weight = 1.0`` or ``2.0`` for most datasets,
- monitor both fine-grained and coarse-level accuracy,
- adjust ``kl_weight`` based on the trade-off between accuracy and
  hierarchical consistency.

Practical Tips
--------------

- Ensure that ``label_info`` correctly describes the hierarchy.
- Excessively large ``kl_weight`` values may over-regularize the model.
- For benchmarking, compare:

  - ``kl_weight = 0`` (baseline),
  - ``kl_weight = 1–4`` (KL-enabled variants).

- Tree-Path KL acts as a *training-time* consistency constraint; it does not
  modify architecture or inference cost.

Limitations
-----------

- Supported out-of-the-box only for :class:`TimmModelHLabelCls`.
- Requires the model to output logits for **each level** of the hierarchy.
- Not applicable to flat classification tasks.

library/src/otx/backend/native/models/classification/classifier/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -4,6 +4,6 @@
 """Head modules for OTX custom model."""

 from .base_classifier import ImageClassifier
-from .h_label_classifier import HLabelClassifier
+from .h_label_classifier import HLabelClassifier, KLHLabelClassifier

-__all__ = ["HLabelClassifier", "ImageClassifier"]
+__all__ = ["HLabelClassifier", "ImageClassifier", "KLHLabelClassifier"]

library/src/otx/backend/native/models/classification/classifier/h_label_classifier.py

Lines changed: 85 additions & 0 deletions
@@ -11,6 +11,7 @@
 import torch

 from otx.backend.native.models.classification.heads.hlabel_cls_head import HierarchicalClsHead
+from otx.backend.native.models.classification.losses.tree_path_kl_divergence_loss import TreePathKLDivergenceLoss
 from otx.backend.native.models.classification.utils.ignored_labels import get_valid_label_mask

 from .base_classifier import ImageClassifier
@@ -143,3 +144,87 @@ def _forward_explain(self, images: torch.Tensor) -> dict[str, torch.Tensor | lis
         outputs["preds"] = preds

         return outputs
+
+
+class KLHLabelClassifier(HLabelClassifier):
+    """Hierarchical label classifier with tree path KL divergence loss.
+
+    Args:
+        backbone (nn.Module): Backbone network.
+        neck (nn.Module | None): Neck network.
+        head (nn.Module): Head network.
+        multiclass_loss (nn.Module): Multiclass loss function.
+        multilabel_loss (nn.Module | None, optional): Multilabel loss function.
+        init_cfg (dict | list[dict] | None, optional): Initialization configuration.
+        kl_weight (float): Loss weight for the tree path KL divergence loss.
+
+    Attributes:
+        multiclass_loss (nn.Module): Multiclass loss function.
+        multilabel_loss (nn.Module | None): Multilabel loss function.
+        is_ignored_label_loss (bool): Flag indicating if ignored label loss is used.
+
+    Methods:
+        loss(inputs, labels, **kwargs): Calculate losses from a batch of inputs and data samples.
+    """
+
+    def __init__(self, *args, kl_weight: float = 1.0, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.kl_weight = kl_weight
+        self.kl_loss = TreePathKLDivergenceLoss(reduction="batchmean", loss_weight=1.0)
+
+    def loss(self, inputs: torch.Tensor, labels: torch.Tensor, **kwargs) -> torch.Tensor:
+        """Calculate losses from a batch of inputs and data samples.
+
+        Args:
+            inputs (torch.Tensor): The input tensor with shape
+                (N, C, ...) in general.
+            labels (torch.Tensor): The annotation data of
+                every sample.
+
+        Returns:
+            torch.Tensor: loss components
+        """
+        cls_scores = self.extract_feat(inputs, stage="head")
+        loss_score = torch.tensor(0.0, device=cls_scores.device)
+        logits_list = []
+        target_list = []
+        num_effective_heads_in_batch = 0
+        for i in range(self.head.num_multiclass_heads):
+            if i not in self.head.empty_multiclass_head_indices:
+                head_gt = labels[:, i]
+                logit_range = self.head._get_head_idx_to_logits_range(i)  # noqa: SLF001
+                head_logits = cls_scores[:, logit_range[0] : logit_range[1]]
+                valid_mask = head_gt >= 0
+                head_gt = head_gt[valid_mask]
+                if len(head_gt) > 0:
+                    head_logits = head_logits[valid_mask]
+                    logits_list.append(head_logits)
+                    target_list.append(head_gt)
+                    ce = self.multiclass_loss(head_logits, head_gt)
+                    loss_score += ce
+                    num_effective_heads_in_batch += 1
+
+        if num_effective_heads_in_batch > 0:
+            loss_score /= num_effective_heads_in_batch
+
+        if len(logits_list) > 1:
+            kl_loss = self.kl_loss(logits_list, torch.stack(target_list, dim=1))
+            loss_score += self.kl_weight * kl_loss
+
+        # Multilabel logic (preserved as-is)
+        if self.head.num_multilabel_classes > 0:
+            head_gt = labels[:, self.head.num_multiclass_heads :]
+            head_logits = cls_scores[:, self.head.num_single_label_classes :]
+            valid_mask = head_gt > 0
+            head_gt = head_gt[valid_mask]
+            if len(head_gt) > 0 and self.multilabel_loss is not None:
+                head_logits = head_logits[valid_mask]
+                imgs_info = kwargs.pop("imgs_info", None)
+                if imgs_info is not None and self.is_ignored_label_loss:
+                    valid_label_mask = get_valid_label_mask(imgs_info, self.head.num_classes).to(head_logits.device)
+                    valid_label_mask = valid_label_mask[:, self.head.num_single_label_classes :]
+                    valid_label_mask = valid_label_mask[valid_mask]
+                    kwargs["valid_label_mask"] = valid_label_mask
+                loss_score += self.multilabel_loss(head_logits, head_gt, **kwargs)
+
+        return loss_score
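
A quick way to exercise the new loss in isolation (a sketch; it assumes only the
import above and the call signature used in KLHLabelClassifier.loss(), i.e. a
list of per-level logits plus a stacked (N, L) target tensor; shapes are
illustrative):

    import torch
    from otx.backend.native.models.classification.losses.tree_path_kl_divergence_loss import TreePathKLDivergenceLoss

    loss_fn = TreePathKLDivergenceLoss(reduction="batchmean", loss_weight=1.0)
    # Two hierarchy levels with 4 and 10 classes, batch of 8.
    logits_list = [torch.randn(8, 4), torch.randn(8, 10)]
    targets = torch.stack([torch.randint(0, 4, (8,)), torch.randint(0, 10, (8,))], dim=1)
    kl = loss_fn(logits_list, targets)  # expected: non-negative scalar tensor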

library/src/otx/backend/native/models/classification/hlabel_models/base.py

Lines changed: 36 additions & 1 deletion
@@ -7,6 +7,7 @@
 from abc import abstractmethod
 from copy import deepcopy
+from functools import wraps
 from typing import TYPE_CHECKING, Any

 import torch
@@ -15,6 +16,7 @@
 from otx.backend.native.exporter.base import OTXModelExporter
 from otx.backend.native.exporter.native import OTXNativeModelExporter
 from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel
+from otx.backend.native.models.classification.classifier import KLHLabelClassifier
 from otx.backend.native.schedulers import LRSchedulerListCallable
 from otx.data.entity.base import OTXBatchLossEntity
 from otx.data.entity.torch import OTXDataBatch, OTXPredBatch
@@ -46,6 +48,7 @@ class OTXHlabelClsModel(OTXModel):
             Defaults to DefaultSchedulerCallable.
         metric (MetricCallable, optional): Callable for the metric. Defaults to HLabelClsMetricCallable.
         torch_compile (bool, optional): Flag to indicate whether to use torch.compile. Defaults to False.
+        kl_weight (float, optional): The weight of the tree-path KL divergence loss. Defaults to zero (cross-entropy only).
     """

     label_info: HLabelInfo
@@ -60,7 +63,9 @@ def __init__(
         scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
         metric: MetricCallable = HLabelClsMetricCallable,
         torch_compile: bool = False,
+        kl_weight: float = 0.0,
     ) -> None:
+        self.kl_weight = kl_weight
         super().__init__(
             label_info=label_info,
             data_input_params=data_input_params,
@@ -71,16 +76,46 @@
             metric=metric,
             torch_compile=torch_compile,
         )
-
         if freeze_backbone:
             classification_layers = self._identify_classification_layers()
             for name, param in self.named_parameters():
                 param.requires_grad = name in classification_layers

+    def __getattribute__(self, name: str):
+        attr = super().__getattribute__(name)
+        if name == "_create_model" and callable(attr):
+            cache_name = "__cm_cached__"
+            cache = super().__getattribute__("__dict__").get(cache_name)
+            if cache:
+                return cache
+
+            @wraps(attr)
+            def wrapped(*a, **kw) -> nn.Module:
+                model = attr(*a, **kw)
+                return self._finalize_model(model)
+
+            self.__dict__[cache_name] = wrapped
+            return wrapped
+        return attr
+
     @abstractmethod
     def _create_model(self, head_config: dict | None = None) -> nn.Module:  # type: ignore[override]
         """Create a PyTorch model for this class."""

+    def _finalize_model(self, model: nn.Module) -> nn.Module:
+        """Run after child _create_model(); upgrade to KL if enabled."""
+        if self.kl_weight > 0:
+            return KLHLabelClassifier(
+                backbone=model.backbone,
+                neck=model.neck,
+                head=model.head,
+                multiclass_loss=model.multiclass_loss,
+                multilabel_loss=model.multilabel_loss,
+                init_cfg=getattr(model, "init_cfg", None),
+                kl_weight=self.kl_weight,
+            )
+        return model
+
     def _identify_classification_layers(self, prefix: str = "model.") -> list[str]:
         """Simple identification of the classification layers. Used for incremental learning."""
         # identify classification layers

library/src/otx/backend/native/models/classification/hlabel_models/timm_model.py

Lines changed: 3 additions & 0 deletions
@@ -45,6 +45,7 @@ class TimmModelHLabelCls(OTXHlabelClsModel):
         metric (MetricCallable, optional): The metric callable for evaluating the model.
             Defaults to HLabelClsMetricCallable.
         torch_compile (bool, optional): Whether to compile the model using TorchScript. Defaults to False.
+        kl_weight (float, optional): The weight of the tree-path KL divergence loss. Defaults to zero (cross-entropy only).
     """

     def __init__(
@@ -57,6 +58,7 @@ def __init__(
         scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
         metric: MetricCallable = HLabelClsMetricCallable,
         torch_compile: bool = False,
+        kl_weight: float = 0.0,
     ) -> None:
         super().__init__(
             label_info=label_info,
@@ -67,6 +69,7 @@ def __init__(
             scheduler=scheduler,
             metric=metric,
             torch_compile=torch_compile,
+            kl_weight=kl_weight,
         )

     def _create_model(self, head_config: dict | None = None) -> nn.Module:  # type: ignore[override]
