Using Tree-Path KL Divergence for Hierarchical Classification
=============================================================

This tutorial explains how to train hierarchical classification models in
OpenVINO™ Training Extensions with **Tree-Path KL Divergence Loss**, a training-time
regularizer that encourages consistent predictions along the taxonomy path
from root to leaf. The method is implemented in:

- :class:`otx.backend.native.models.classification.losses.tree_path_kl_divergence_loss.TreePathKLDivergenceLoss`
- :class:`otx.backend.native.models.classification.classifier.h_label_classifier.KLHLabelClassifier`

The feature is currently exposed out of the box only by
:class:`otx.backend.native.models.classification.hlabel_models.timm_model.TimmModelHLabelCls`.
Other architectures can be adapted with minimal modifications by adding the
same wrapper (``KLHLabelClassifier``) in their model's ``_finalize_model()``.

Overview
--------

Hierarchical classification models predict multiple levels of labels
(e.g., manufacturer → family → variant). Standard cross-entropy treats each
level independently, which means models may output **inconsistent**
combinations such as:

- predicting a correct fine-grained leaf but an incompatible ancestor, or
- predicting parents and children belonging to different branches.

Tree-Path KL Divergence introduces a path-consistency objective by comparing:

- the model's *combined* probability distribution across all levels, and
- a **tree-consistent target distribution** that places probability mass on
  each ground-truth category along the path.

This encourages smooth transitions between hierarchy levels and reduces
structurally invalid predictions.

How It Works
------------

Tree-Path KL Divergence operates on:

- a **list of logits** from each hierarchy level (root → ... → leaf), and
- a **target index** for each corresponding level.
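
For concreteness, the inputs might look like the following for a
hypothetical three-level hierarchy (batch size, class counts, and values
are illustrative only, not taken from a real dataset):

.. code-block:: python

    import torch

    # Hypothetical 3-level hierarchy (root -> family -> variant), batch size 2.
    level_logits = [
        torch.randn(2, 4),   # root level logits: 4 classes
        torch.randn(2, 10),  # middle level logits: 10 classes
        torch.randn(2, 30),  # leaf level logits: 30 classes
    ]
    level_targets = [
        torch.tensor([0, 3]),   # ground-truth root class per sample
        torch.tensor([2, 7]),   # ground-truth family class per sample
        torch.tensor([5, 21]),  # ground-truth variant class per sample
    ]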

The algorithm implemented in
:class:`TreePathKLDivergenceLoss` performs the following:

1. Concatenates all level logits and applies log-softmax.
2. Constructs a sparse target distribution that allocates equal probability to
   the correct class at each level.
3. Computes KL divergence between the model's distribution and the path-aware
   target distribution.
4. Scales the result by ``loss_weight`` (typically ``1.0``).
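
A minimal PyTorch sketch of these four steps is shown below. It is
illustrative only, not the exact OTX implementation; consult
:class:`TreePathKLDivergenceLoss` for the authoritative code.

.. code-block:: python

    import torch
    import torch.nn.functional as F

    def tree_path_kl(level_logits: list[torch.Tensor],
                     level_targets: list[torch.Tensor],
                     loss_weight: float = 1.0) -> torch.Tensor:
        # 1. Concatenate all level logits and apply log-softmax jointly.
        log_probs = F.log_softmax(torch.cat(level_logits, dim=1), dim=1)

        # 2. Sparse target: equal probability mass on the correct class
        #    at every level along the root-to-leaf path.
        target = torch.zeros_like(log_probs)
        offset, mass = 0, 1.0 / len(level_logits)
        for logits, targets in zip(level_logits, level_targets):
            target.scatter_(1, (targets + offset).unsqueeze(1), mass)
            offset += logits.shape[1]

        # 3. KL divergence between the model's distribution and the
        #    path-aware target, averaged over the batch.
        kl = F.kl_div(log_probs, target, reduction="batchmean")

        # 4. Scale by the configured loss weight.
        return loss_weight * kl

With the example inputs above, ``tree_path_kl(level_logits, level_targets)``
returns a scalar loss tensor.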

In :class:`KLHLabelClassifier`, this KL term is added to the hierarchical
cross-entropy loss:

- cross-entropy is averaged across all hierarchy levels,
- KL divergence is multiplied by ``kl_weight``,
- ``kl_weight = 0`` disables the KL term completely.
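
Put together, the combination can be sketched as follows (again
illustrative; ``tree_path_kl`` is the helper from the previous sketch, and
``hlabel_total_loss`` is a hypothetical name, not the OTX API):

.. code-block:: python

    import torch.nn.functional as F

    def hlabel_total_loss(level_logits, level_targets, kl_weight: float = 1.0):
        # Cross-entropy averaged across all hierarchy levels.
        ce = sum(
            F.cross_entropy(logits, targets)
            for logits, targets in zip(level_logits, level_targets)
        ) / len(level_logits)

        # kl_weight = 0 disables the KL term completely.
        if kl_weight == 0:
            return ce

        # Add the KL term weighted by kl_weight.
        return ce + kl_weight * tree_path_kl(level_logits, level_targets)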

Enabling Tree-Path KL Divergence
--------------------------------

The recommended entry point is the provided recipe:

.. code-block:: text

    recipe/classification/h_label_cls/efficientnet_v2_kl.yaml

This recipe uses :class:`TimmModelHLabelCls` and exposes the argument
``kl_weight`` directly in ``init_args``:

.. code-block:: yaml

    task: H_LABEL_CLS
    model:
      class_path: otx.backend.native.models.classification.hlabel_models.timm_model.TimmModelHLabelCls
      init_args:
        label_info: <LABEL-TREE-INFO>
        model_name: tf_efficientnetv2_s.in21k
        kl_weight: 1.0

Using the CLI
-------------

To train a hierarchical model with Tree-Path KL Divergence, the CLI requires:

- ``--data_root``: a path to a directory containing an ``annotations/`` folder
  whose JSON annotation files follow **Datumaro format** (a typical layout is
  shown after this list). See the format specification here:

  https://open-edge-platform.github.io/datumaro/stable/docs/data-formats/datumaro_format.html

- ``--config``: the **path to a recipe YAML file**, such as
  ``recipe/classification/h_label_cls/efficientnet_v2_kl.yaml``.
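
Per the Datumaro format specification linked above, a typical dataset layout
looks like this (subset names are illustrative):

.. code-block:: text

    dataset_with_annotations/
    ├── annotations/
    │   ├── train.json    # Datumaro-format annotations, one file per subset
    │   └── val.json
    └── images/
        ├── train/
        └── val/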

A full training command example:

.. code-block:: bash

    (otx) $ otx train \
        --config recipe/classification/h_label_cls/efficientnet_v2_kl.yaml \
        --data_root /path/to/dataset_with_annotations \
        --model.kl_weight 1.0

To disable Tree-Path KL Divergence and train a standard hierarchical model:

.. code-block:: bash

    (otx) $ otx train \
        --config recipe/classification/h_label_cls/efficientnet_v2_kl.yaml \
        --data_root /path/to/dataset_with_annotations \
        --model.kl_weight 0.0

Extending Other Architectures
-----------------------------

Currently, Tree-Path KL Divergence is supported out of the box only by
``TimmModelHLabelCls``. To integrate the feature into another architecture,
make the following changes (a fuller sketch follows this list):

1. Accept a new ``kl_weight`` argument in the model init.
2. After constructing the underlying model in ``_finalize_model``, wrap it as:

   .. code-block:: python

      if self.kl_weight > 0:
          model = KLHLabelClassifier(model, kl_weight=self.kl_weight)

3. Ensure that the model returns a list of logits aligned with the hierarchy.

Only a few lines are required, and this enables the same training procedure
for any backbone (ResNet, ViT, ConvNeXt, etc.).
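
As a concrete illustration, a custom model might be wired up as follows.
Everything except the ``KLHLabelClassifier`` import is a hypothetical
placeholder; inherit from your model's real base class and keep its actual
constructor and method signatures.

.. code-block:: python

    from otx.backend.native.models.classification.classifier.h_label_classifier import (
        KLHLabelClassifier,
    )

    class MyBackboneHLabelCls:  # hypothetical; derive from your real base class
        def __init__(self, label_info, kl_weight: float = 1.0):
            self.label_info = label_info
            self.kl_weight = kl_weight  # step 1: accept kl_weight in init

        def _finalize_model(self, model):  # hypothetical signature
            # Step 2: wrap the base classifier only when the KL term is enabled.
            if self.kl_weight > 0:
                model = KLHLabelClassifier(model, kl_weight=self.kl_weight)
            return model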

When to Use Tree-Path KL Divergence
-----------------------------------

Tree-Path KL Divergence is most helpful when:

- the label space forms a strict taxonomy,
- incorrect parent/child combinations are undesirable,
- fine-grained classes are scarce and benefit from structural priors,
- you want improved consistency across hierarchy levels.

In practice:

- start with ``kl_weight = 1.0`` or ``2.0`` for most datasets,
- monitor both fine-grained and coarse-level accuracy,
- adjust ``kl_weight`` based on the trade-off between accuracy and
  hierarchical consistency.
| 151 | + |
| 152 | +Practical Tips |
| 153 | +-------------- |
| 154 | + |
| 155 | +- Ensure that ``label_info`` correctly describes the hierarchy. |
| 156 | +- Excessively large ``kl_weight`` values may over-regularize the model. |
| 157 | +- For benchmarking, compare: |
| 158 | + - ``kl_weight = 0`` (baseline), |
| 159 | + - ``kl_weight = 1–4`` (KL-enabled variants). |
| 160 | +- Tree-Path KL acts as a *training-time* consistency constraint; it does not |
| 161 | + modify architecture or inference cost. |
| 162 | + |
| 163 | +Limitations |
| 164 | +----------- |
| 165 | + |
| 166 | +- Supported out-of-the-box only for :class:`TimmModelHLabelCls`. |
| 167 | +- Requires the model to output logits for **each level** of the hierarchy. |
| 168 | +- Not applicable to flat classification tasks. |
| 169 | + |
| 170 | + |