Conversation

Jyc323
Contributor

@Jyc323 Jyc323 commented Sep 16, 2025

Summary

Introduced a new hierarchical regularization loss, TreePathKLDivergenceLoss, that enforces consistency across hierarchical label predictions by minimizing the KL divergence between the model's probabilities and a constructed target distribution along the true label path.

Key Changes

  • Added tree_path_kl_divergence.py under lib/src/otx/backend/native/models/common/losses/
    • Implements TreePathKLDivergenceLoss with configurable reduction (batchmean, mean, sum) and weighting.
    • Target distribution assigns equal probability (1/L) to the GT class at each hierarchy level.
  • Added comprehensive unit tests in lib/tests/unit/backend/native/models/common/losses/test_tree_path_kl_divergence.py covering:
    • Shape and finiteness checks across multiple levels.
    • Gradient flow validation to ensure backprop works correctly.
    • Alignment vs misalignment sanity check: aligned logits give lower loss.
    • Exact-value correctness for:
      • Single-level case (loss equals CrossEntropyLoss).
      • Multi-level case (manual KL computation matches PyTorch loss).
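Conceptually, the target construction and KL computation described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the actual OTX implementation: the function name and the flat-index encoding of the path (ground-truth classes of all levels indexed into one concatenated label space) are assumptions made for the sketch.

```python
import torch
import torch.nn.functional as F


def tree_path_kl_loss(logits: torch.Tensor,
                      path_indices: torch.Tensor,
                      reduction: str = "batchmean") -> torch.Tensor:
    """Hypothetical sketch of the Tree-Path KL idea.

    logits:       (B, C) logits over the flattened label space of all levels.
    path_indices: (B, L) flat index of the ground-truth class at each of the
                  L hierarchy levels.
    """
    num_levels = path_indices.shape[1]
    # Target distribution: 1/L probability mass on the GT class of each
    # hierarchy level, zero everywhere off the true label path.
    target = torch.zeros_like(logits)
    target.scatter_(1, path_indices, 1.0 / num_levels)
    # KL(target || softmax(logits)); F.kl_div expects log-probabilities.
    return F.kl_div(F.log_softmax(logits, dim=1), target, reduction=reduction)
```

With a single level (L=1) the target is one-hot and, as the unit tests in this PR check for the real loss, the value reduces to the standard cross-entropy; mass predicted off the path only increases the loss.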

Motivation

Existing metrics capture evaluation aspects of hierarchical classification, but training lacked a regularizer that explicitly
encourages consistency across levels. The Tree-Path KL loss:

  • Penalizes predictions that scatter probability mass off the ground-truth path.
  • Provides a smooth differentiable signal to improve hierarchical consistency beyond hard CE losses.

This makes the training loop more robust for hierarchical models, reducing path-inconsistent predictions and improving downstream metrics.

Preliminary Results

We evaluated the effect of Tree-Path KL loss on multiple hierarchical datasets using EfficientNet_v2 (timm model) from OTX. The KL weight (λ) was varied between 0 (baseline), 1, and 5. Results show that incorporating KL improves performance on most datasets.

| Dataset | Metric | λ=0 | λ=1 | λ=5 |
|---|---|---|---|---|
| CUB_200_2011 | Accuracy | 0.853 | 0.850 | 0.858 |
| CUB_200_2011 | Full Path Accuracy | 0.735 | 0.724 | 0.764 |
| CUB_200_2011 | Leaf Accuracy | 0.825 | 0.821 | 0.812 |
| CUB_200_2011 | Weighted Precision | 0.836 | 0.832 | 0.829 |
| AmazonParrot | Accuracy | 0.718 | 0.725 | 0.742 |
| AmazonParrot | Full Path Accuracy | 0.611 | 0.614 | 0.659 |
| AmazonParrot | Leaf Accuracy | 0.333 | 0.334 | 0.346 |
| AmazonParrot | Weighted Precision | 0.490 | 0.504 | 0.504 |
| FGVC-Aircraft | Accuracy | 0.709 | 0.742 | 0.764 |
| FGVC-Aircraft | Full Path Accuracy | 0.551 | 0.593 | 0.629 |
| FGVC-Aircraft | Leaf Accuracy | 0.614 | 0.639 | 0.660 |
| FGVC-Aircraft | Weighted Precision | 0.705 | 0.730 | 0.746 |
| Fish-Vista | Accuracy | 0.460 | 0.497 | 0.507 |
| Fish-Vista | Full Path Accuracy | 0.217 | 0.264 | 0.280 |
| Fish-Vista | Leaf Accuracy | 0.416 | 0.452 | 0.458 |
| Fish-Vista | Weighted Precision | 0.399 | 0.435 | 0.438 |

Observations:

  • Increasing KL weight to 5 often improves Full Path Accuracy, showing stronger hierarchical consistency.
  • Gains are consistent across datasets of varying granularity (birds, parrots, aircraft, fish).
  • Some trade-off exists in Leaf Accuracy (CUB shows slight drop), but overall performance trends upward in hierarchy-sensitive metrics.

These results indicate that Tree-Path KL is a promising regularizer for hierarchical classification tasks.

How to test

 pytest -q lib/tests/unit/backend/native/models/common/losses/test_tree_path_kl_divergence.py

Checklist

  • I have added unit tests to cover my changes.
  • I have added integration tests to cover my changes.
  • I have run e2e tests and there are no issues.
  • I have added the description of my changes into CHANGELOG in my target branch (e.g., CHANGELOG in develop).
  • I have updated the documentation in my target branch accordingly (e.g., documentation in develop).
  • I have linked related issues.

License

  • I submit my code changes under the same Apache License that covers the project.
    Feel free to contact the maintainers if that's a concern.
  • I have updated the license header for each file (see an example below).
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

@rajeshgangireddy
Contributor

PS : This and #4689 are PRs from our GSOC contributor.
Part of the project was 'H-label classification for VLMs' but the metrics and loss functions are useful and would be a nice addition to OTX.

@sovrasov sovrasov added the GSoC label Sep 17, 2025
Member

@sovrasov sovrasov left a comment


Thanks @Jyc323 for your contribution. I have a question:

  • This PR adds a new loss implementation but does not show how OTX users can utilize that loss when solving their tasks. Could you also add an integration into the existing h-cls training pipeline and mention that improvement in the OTX documentation? I believe the loss choice should be configurable via an OTX model recipe.

@sovrasov
Member

sovrasov commented Sep 17, 2025

Code quality checks can be launched locally with cd lib && tox r -vv -e pre-commit

@Jyc323
Contributor Author

Jyc323 commented Sep 21, 2025

Hi @kprokofi, as suggested, I created a new example recipe yaml file efficientnet_v2_kl.yaml for integrating the tree path KL divergence loss. To run the recipe file successfully with the engine, I added a new model class TimmModelHLabelClsWithKL under lib/src/otx/backend/native/models/classification/hlabel_models/timm_model.py, and a new classifier class KLHLabelClassifier under lib/src/otx/backend/native/models/classification/classifier/h_label_classifier.py. Could you please take a look and share any suggestions?
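For readers of this thread, a recipe wiring in the KL-enabled model might look roughly like the sketch below. This is a hypothetical layout, not the actual contents of efficientnet_v2_kl.yaml: the field names and the kl_weight hyperparameter value are assumptions; only the class name and module path come from this PR.

```yaml
# Hypothetical sketch of an OTX recipe selecting the KL-enabled model.
# Field names other than the class path are assumptions.
model:
  class_path: otx.backend.native.models.classification.hlabel_models.timm_model.TimmModelHLabelClsWithKL
  init_args:
    kl_weight: 1.0  # λ from the Preliminary Results table; 0 would disable the KL term
```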

While running tox r -vv -e pre-commit, I encountered an error:

lib/src/otx/backend/native/utils/utils.py:89: error: Module has no attribute "metrics" [attr-defined]
Found 1 error in 1 file (checked 363 source files)

I'm not sure if it's caused by something on my end.

Also, could you let me know the appropriate location for the documentation? If possible, please provide a template or example I can refer to.

Thanks!

@Jyc323 Jyc323 requested review from a team and Radwan-Ibrahim as code owners October 2, 2025 23:59
@Jyc323
Contributor Author

Jyc323 commented Oct 3, 2025

Hi @sovrasov, I merged this PR with the current develop branch. Please let me know the next step. Thanks!

@sovrasov
Member

sovrasov commented Oct 6, 2025

> Hi @kprokofi, as suggested, I created a new example recipe yaml file efficientnet_v2_kl.yaml for integrating the tree path KL divergence loss. To run the recipe file successfully with the engine, I added a new model class TimmModelHLabelClsWithKL under lib/src/otx/backend/native/models/classification/hlabel_models/timm_model.py, and a new classifier class KLHLabelClassifier under lib/src/otx/backend/native/models/classification/classifier/h_label_classifier.py. Could you please take a look and share any suggestions?
>
> While running tox r -vv -e pre-commit, I encountered an error:
>
> lib/src/otx/backend/native/utils/utils.py:89: error: Module has no attribute "metrics" [attr-defined]
> Found 1 error in 1 file (checked 363 source files)
>
> I'm not sure if it's caused by something on my end.
>
> Also, could you let me know the appropriate location for the documentation? If possible, please provide a template or example I can refer to.
>
> Thanks!

Thanks @Jyc323 for the recent updates. Once the new classifier is available for each classification model and extra unit tests are added, we're almost there: only docs remain to be added.
Since we have an advanced tutorial for h-cls metrics, we can extend it into an advanced h-cls tutorial by introducing examples of how to use the new loss with the new metrics.

@Jyc323
Contributor Author

Jyc323 commented Oct 9, 2025

Hi @sovrasov, thanks for your suggestion; I thought about it carefully. I refactored OTXHlabelClsModel with a dynamic wrapping approach to avoid modifying all child implementations or external calls.
In base.py, I added a custom __getattribute__ hook that automatically wraps each child's _create_model() at runtime. The wrapper injects a common post-processing step that executes immediately after _create_model() finishes.
The parent class now transparently runs an additional step, _finalize_model(), after model creation. _finalize_model() replaces the classifier with KLHLabelClassifier when kl_weight > 0, enabling hybrid training that combines the Tree-Path KL loss with standard cross-entropy.
Since kl_weight is a new hyperparameter, I added a **kwargs parameter to OTXHlabelClsModel so the class can accept additional optional arguments for future extensibility. However, I'm not sure whether this reduces type safety or clarity.
I also added more unit tests based on the suggestions.
Please take a look at the changed files and let me know if they align with your expectations. If they do, I'll proceed with adding the related documentation.
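The wrapping pattern described in that comment can be illustrated with a minimal self-contained sketch. The class name and both method bodies below are stand-ins invented for illustration, not the actual OTX base class; only the hook mechanism (wrap _create_model via __getattribute__, then run _finalize_model) mirrors the described approach.

```python
class HLabelModelSketch:
    """Illustrative stand-in for a base class that post-processes
    every child's _create_model() result via a __getattribute__ hook."""

    def __init__(self, kl_weight: float = 0.0, **kwargs):
        # **kwargs absorbs extra recipe arguments, as described above.
        self.kl_weight = kl_weight

    def __getattribute__(self, name):
        attr = object.__getattribute__(self, name)
        if name == "_create_model" and callable(attr):
            # Wrap the (possibly overridden) _create_model so that
            # _finalize_model always runs right after it returns.
            def wrapped(*args, **kwargs):
                model = attr(*args, **kwargs)
                return object.__getattribute__(self, "_finalize_model")(model)
            return wrapped
        return attr

    def _create_model(self):
        # Stand-in for a child's real model construction.
        return {"classifier": "HLabelClassifier"}

    def _finalize_model(self, model):
        # Swap in the KL-aware classifier only when the KL term is enabled.
        if self.kl_weight > 0:
            model["classifier"] = "KLHLabelClassifier"
        return model
```

Subclasses override _create_model as usual; the hook means none of them (nor any external caller) has to know about the finalization step.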

@sovrasov
Member

> Hi @sovrasov, thanks for your suggestion; I thought about it carefully. I refactored OTXHlabelClsModel with a dynamic wrapping approach to avoid modifying all child implementations or external calls. In base.py, I added a custom __getattribute__ hook that automatically wraps each child's _create_model() at runtime. The wrapper injects a common post-processing step that executes immediately after _create_model() finishes. The parent class now transparently runs an additional step, _finalize_model(), after model creation. _finalize_model() replaces the classifier with KLHLabelClassifier when kl_weight > 0, enabling hybrid training that combines the Tree-Path KL loss with standard cross-entropy. Since kl_weight is a new hyperparameter, I added a **kwargs parameter to OTXHlabelClsModel so the class can accept additional optional arguments for future extensibility. However, I'm not sure whether this reduces type safety or clarity. I also added more unit tests based on the suggestions. Please take a look at the changed files and let me know if they align with your expectations. If they do, I'll proceed with adding the related documentation.

Thanks @Jyc323, I've left a couple of comments to the latest changes

@Jyc323
Contributor Author

Jyc323 commented Oct 13, 2025

Hi @sovrasov, I resolved the comments and the other errors from tox. Please take a look, thanks!

@sovrasov
Member

> Hi @sovrasov, I resolved the comments and the other errors from tox. Please take a look, thanks!

@Jyc323 it looks like ruff check is still failing:

ruff.....................................................................Failed
- hook id: ruff
- files were modified by this hook

@Jyc323
Contributor Author

Jyc323 commented Oct 16, 2025

Hi @sovrasov, I have fixed the errors. Please take a look, thanks!


Labels

GSoC, TEST (Any changes in tests)
