add Tree-Path KL Divergence loss for hier classification + unit test #4706
base: develop
Conversation
PS: This and #4689 are PRs from our GSoC contributor.
Thanks @Jyc323 for your contribution. I have a question:
- This PR adds a new loss implementation but does not show how OTX users can utilize that loss when solving their tasks. Could you also add an integration into the existing h-cls training pipeline and mention that improvement in the OTX documentation? I believe the loss choice should be configurable via the OTX model recipe.
Code quality checks can be launched locally with
Hi @kprokofi, as suggested, I created a new example recipe YAML file. While running it I hit an issue; I'm not sure if it's caused by something on my end. Also, could you let me know the appropriate location for the documentation? If possible, please provide a template or example I can refer to. Thanks!
Hi @sovrasov, I merged this PR with the current develop branch. Please let me know the next step. Thanks!
Thanks @Jyc323 for the recent updates. Once the new classifier is available for each of the classification models and the extra unit tests are added, we're almost there: only the docs remain to be added.
Hi @sovrasov, thanks for your suggestion; I thought about it carefully and refactored accordingly.
Thanks @Jyc323, I've left a couple of comments on the latest changes.
Hi @sovrasov, I resolved the comments and the other errors from tox. Please have a look, thanks!
Hi @sovrasov, I have fixed the errors. Please have a look, thanks!
Summary
Introduced a new hierarchical regularization loss, TreePathKLDivergenceLoss, that enforces consistency across hierarchical label predictions by minimizing the KL divergence between the model's probabilities and a target distribution constructed along the true label path.
Key Changes
- Added tree_path_kl_divergence.py under lib/src/otx/backend/native/models/common/losses/, implementing TreePathKLDivergenceLoss with configurable reduction (batchmean, mean, sum) and weighting.
- Added unit tests in lib/tests/unit/backend/native/models/common/losses/test_tree_path_kl_divergence.py covering the loss behavior (including a comparison against CrossEntropyLoss).
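As a rough, framework-agnostic sketch of the idea (not the actual OTX implementation): with a one-hot target distribution at each level of the hierarchy, KL(target || pred) collapses to the negative log-probability of the true class, so the loss is the weighted sum of those terms along the true label path. The function name and signature below are hypothetical.

```python
import math

def _softmax(logits):
    # Numerically stable softmax over a list of floats.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def tree_path_kl(logits_per_level, target_path, weight=1.0):
    """Hypothetical sketch of a tree-path KL loss.

    logits_per_level: one list of logits per hierarchy level.
    target_path: index of the true class at each level.
    With one-hot targets, KL(target || pred) = -log p(true class),
    so the loss is the weighted sum of those terms along the path.
    """
    total = 0.0
    for logits, true_idx in zip(logits_per_level, target_path):
        probs = _softmax(logits)
        total += -math.log(probs[true_idx])
    return weight * total

# Uniform predictions over 2 and 3 classes: loss = log 2 + log 3 = log 6.
loss = tree_path_kl([[0.0, 0.0], [0.0, 0.0, 0.0]], [0, 1])
```

The OTX version additionally exposes a reduction argument (batchmean, mean, sum), which this single-sample sketch omits.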
).Motivation
Existing metrics capture evaluation aspects of hierarchical classification, but training lacked a regularizer that explicitly
encourages consistency across levels. The Tree-Path KL loss:
This makes the training loop more robust for hierarchical models, reducing path-inconsistent predictions and improving downstream metrics.
Preliminary Results
We evaluated the effect of the Tree-Path KL loss on multiple hierarchical datasets using EfficientNet_v2 (a timm model) from OTX. The KL weight (λ) was varied over 0 (baseline), 1, and 5. Results show that incorporating the KL term improves performance on most datasets.
Overall, these results indicate that Tree-Path KL is a promising regularizer for hierarchical classification tasks.
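A hypothetical sketch of how the λ weighting described above might enter the total training objective (the exact wiring in OTX may differ): the KL term is added to the standard cross-entropy loss, and λ = 0 recovers the baseline.

```python
def total_objective(ce_loss, kl_loss, kl_weight):
    # Hypothetical combination: total = CE + λ * KL.
    # kl_weight = 0.0 recovers the cross-entropy-only baseline.
    return ce_loss + kl_weight * kl_loss

baseline = total_objective(0.7, 0.2, 0.0)     # λ = 0: CE only
regularized = total_objective(0.7, 0.2, 5.0)  # 0.7 + 5 * 0.2 = 1.7
```

Sweeping kl_weight over {0, 1, 5}, as in the experiments above, then amounts to calling this with each λ value during training.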