From 9d52d9f7857aebbcb05c04106ba1696366f39880 Mon Sep 17 00:00:00 2001 From: Jyc323 Date: Mon, 15 Sep 2025 22:07:38 -0700 Subject: [PATCH 01/13] add Tree-Path KL Divergence loss for hier classification + unit test --- .../losses/tree_path_KL_divergence_loss.py | 53 +++++++ .../losses/test_tree_path_kl_divergence.py | 134 ++++++++++++++++++ 2 files changed, 187 insertions(+) create mode 100644 lib/src/otx/backend/native/models/classification/losses/tree_path_KL_divergence_loss.py create mode 100644 lib/tests/unit/backend/native/models/classification/losses/test_tree_path_kl_divergence.py diff --git a/lib/src/otx/backend/native/models/classification/losses/tree_path_KL_divergence_loss.py b/lib/src/otx/backend/native/models/classification/losses/tree_path_KL_divergence_loss.py new file mode 100644 index 0000000000..f94c07d7df --- /dev/null +++ b/lib/src/otx/backend/native/models/classification/losses/tree_path_KL_divergence_loss.py @@ -0,0 +1,53 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Module for defining TreePathKLDivergenceLoss.""" + +from __future__ import annotations +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class TreePathKLDivergenceLoss(nn.Module): + """ + KL divergence between model distribution over concatenated heads and a + target distribution that allocates equal mass to the ground-truth class + at each hierarchy level. + + Inputs: + logits_list: list of tensors [B, C_l], ordered from root -> leaf + targets: LongTensor [B, L] with per-level GT indices (L == len(logits_list)) + + The target distribution places 1/L probability on the GT index for each level, + and 0 elsewhere, then uses KLDivLoss(log_softmax(logits), target_probs). + """ + + def __init__(self, reduction: str = "batchmean", loss_weight: float = 1.0): + super().__init__() + self.reduction = reduction + self.loss_weight = loss_weight + self.kl_div = nn.KLDivLoss(reduction=self.reduction) + + def forward(self, logits_list: List[torch.Tensor], targets: torch.Tensor) -> torch.Tensor: + assert isinstance(logits_list, (list, tuple)) and len(logits_list) > 0, "logits_list must be non-empty" + num_levels = len(logits_list) + + # concat logits across all levels + dims = [t.size(1) for t in logits_list] + logits_concat = torch.cat(logits_list, dim=1) # [B, sum(C_l)] + log_probs = F.log_softmax(logits_concat, dim=1) # [B, sum(C_l)] + + # build sparse target distribution with 1/L at each GT index + B = log_probs.size(0) + tgt = torch.zeros_like(log_probs) # [B, sum(C_l)] + offset = 0 + for num_c, tgt_l in zip(dims, targets.T): # level-by-level + idx_rows = torch.arange(B, device=log_probs.device) + tgt[idx_rows, offset + tgt_l] = 1.0 / num_levels + offset += num_c + + kl = self.kl_div(log_probs, tgt) + return self.loss_weight * kl diff --git a/lib/tests/unit/backend/native/models/classification/losses/test_tree_path_kl_divergence.py b/lib/tests/unit/backend/native/models/classification/losses/test_tree_path_kl_divergence.py new file mode 100644 index 0000000000..d16f18ff32 --- /dev/null +++ b/lib/tests/unit/backend/native/models/classification/losses/test_tree_path_kl_divergence.py @@ -0,0 +1,134 @@ +import torch +import pytest +import torch.nn.functional as F + +from otx.backend.native.models.classification.losses.tree_path_KL_divergence_loss import TreePathKLDivergenceLoss + + +@pytest.mark.parametrize("levels,classes_per_level", [ + (2, [3, 5]), + (3, [2, 3, 4]), +]) +def test_forward_scalar_and_finite(levels, 
classes_per_level): + torch.manual_seed(0) + B = 4 + logits_list = [torch.randn(B, c) for c in classes_per_level] + targets = torch.stack([torch.randint(0, c, (B,)) for c in classes_per_level], dim=1) + + loss_fn = TreePathKLDivergenceLoss(reduction="batchmean") + loss = loss_fn(logits_list, targets) + assert loss.ndim == 0 + assert torch.isfinite(loss) + assert loss.item() >= -1e-7 + + +def test_backward_produces_grads(): + B = 3 + C = [4, 6] + logits_list = [torch.randn(B, c, requires_grad=True) for c in C] + targets = torch.stack([torch.randint(0, c, (B,)) for c in C], dim=1) + + loss = TreePathKLDivergenceLoss()(logits_list, targets) + loss.backward() + for logit in logits_list: + assert logit.grad is not None + assert torch.isfinite(logit.grad).all() + + +def test_alignment_vs_misalignment_loss(): + B = 2 + C0, C1 = 3, 4 + targets = torch.tensor([[0, 1], [2, 3]]) + + # Aligned: boost GT logits + aligned0 = torch.zeros(B, C0) + aligned1 = torch.zeros(B, C1) + aligned0[torch.arange(B), targets[:, 0]] = 5.0 + aligned1[torch.arange(B), targets[:, 1]] = 5.0 + + # Misaligned: boost wrong logits + mis0 = torch.zeros(B, C0) + mis1 = torch.zeros(B, C1) + mis0[torch.arange(B), (targets[:, 0] + 1) % C0] = 5.0 + mis1[torch.arange(B), (targets[:, 1] + 1) % C1] = 5.0 + + loss_fn = TreePathKLDivergenceLoss() + loss_aligned = loss_fn([aligned0, aligned1], targets) + loss_misaligned = loss_fn([mis0, mis1], targets) + assert loss_aligned < loss_misaligned + + +def test_single_level_exact_value(): + """ + With a single level, KL reduces to CE between predicted softmax and one-hot target. + We check exact value against F.cross_entropy. + """ + + B, C = 2, 3 + logits = torch.tensor([[2.0, 0.0, -1.0], + [0.5, 1.0, -0.5]]) + targets = torch.tensor([[0], [2]]) # shape [B,1] + + # TreePathKL + loss_fn = TreePathKLDivergenceLoss(reduction="batchmean") + kl_loss = loss_fn([logits], targets) + + # CrossEntropy with one-hot is same as NLLLoss(log_softmax) + ce_loss = F.cross_entropy(logits, targets.view(-1), reduction="mean") + + assert torch.allclose(kl_loss, ce_loss, atol=1e-6) + +def test_multi_level_exact_value_batchmean(): + """ + Exact numerical check for L=2 levels with 'batchmean' reduction. + + Loss per sample (PyTorch KLDivLoss): + KL(p || q) = sum_j p_j * (log(p_j) - log(q_j)) + where input to KLDivLoss is log(q_j) (our model log_probs), + and the target is p_j (our constructed target distribution). + With reduction='batchmean', PyTorch divides the total sum by batch size. 
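For intuition about the quantity these tests check: with the 1/L-mass target, the per-sample KL against the concatenated softmax reduces to the mean negative log-probability of the ground-truth indices minus log(L), so for a single level it is exactly cross-entropy. A small illustrative sketch, separate from the test suite and using made-up shapes (one sample, two levels with 2 and 3 classes):

    import math
    import torch
    import torch.nn.functional as F

    logits = [torch.randn(1, 2), torch.randn(1, 3)]      # one sample, two levels (root -> leaf)
    gt = [1, 0]                                          # ground-truth index per level
    log_q = F.log_softmax(torch.cat(logits, dim=1), dim=1)

    # Target with 1/L = 0.5 mass at each ground-truth position (level offsets 0 and 2).
    p = torch.zeros_like(log_q)
    p[0, 0 + gt[0]] = 0.5
    p[0, 2 + gt[1]] = 0.5

    kl_ref = torch.nn.KLDivLoss(reduction="batchmean")(log_q, p)
    kl_closed_form = (-torch.stack([log_q[0, 0 + gt[0]], log_q[0, 2 + gt[1]]])).mean() - math.log(2)
    assert torch.allclose(kl_ref, kl_closed_form, atol=1e-6)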
+ """ + + # Use double for better numerical agreement + B = 2 + l0, l1 = 2, 3 + logits0 = torch.tensor([[2.0, -1.0], + [0.0, 1.0]], dtype=torch.float64) # [B, l0] + logits1 = torch.tensor([[ 0.5, 0.0, -0.5], + [-1.0, 2.0, 0.5]], dtype=torch.float64) # [B, l1] + logits_list = [logits0, logits1] + + # Ground-truth indices per level + # sample 0: level0->0, level1->1 + # sample 1: level0->1, level1->2 + targets = torch.tensor([[0, 1], + [1, 2]], dtype=torch.long) # [B, 2] + L = 2 # number of levels + + # Model log probs over concatenated heads + concat = torch.cat([logits0, logits1], dim=1) # [B, l0+l1] + log_q = F.log_softmax(concat, dim=1) # log(q_j) + + # Build target distribution p: 1/L at each GT index, 0 elsewhere + p = torch.zeros_like(log_q, dtype=torch.float64) + offset = 0 + for num_c, tgt_l in zip([l0, l1], targets.T): + rows = torch.arange(B) + p[rows, offset + tgt_l] = 1.0 / L + offset += num_c + + # Manual KL with 'batchmean' reduction: + # sum_i sum_j p_ij * (log p_ij - log q_ij) / B + # (avoid log(0) by masking since p is sparse) + mask = p > 0 + log_p = torch.zeros_like(p) + log_p[mask] = torch.log(p[mask]) + manual_kl = (p * (log_p - log_q)).sum() / B + + # Loss under test (must match manual) + loss_fn = TreePathKLDivergenceLoss(reduction="batchmean") + test_kl = loss_fn([logits0.float(), logits1.float()], targets) + + assert torch.allclose(test_kl.double(), manual_kl, atol=1e-8), ( + f"manual={manual_kl.item():.12f} vs loss={test_kl.item():.12f}" + ) \ No newline at end of file From 1253587cbe1a36f86286cc9b6d534ada3eb1b838 Mon Sep 17 00:00:00 2001 From: Jyc323 Date: Thu, 18 Sep 2025 23:39:39 -0700 Subject: [PATCH 02/13] fix code review comments --- ...oss.py => tree_path_kl_divergence_loss.py} | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) rename lib/src/otx/backend/native/models/classification/losses/{tree_path_KL_divergence_loss.py => tree_path_kl_divergence_loss.py} (69%) diff --git a/lib/src/otx/backend/native/models/classification/losses/tree_path_KL_divergence_loss.py b/lib/src/otx/backend/native/models/classification/losses/tree_path_kl_divergence_loss.py similarity index 69% rename from lib/src/otx/backend/native/models/classification/losses/tree_path_KL_divergence_loss.py rename to lib/src/otx/backend/native/models/classification/losses/tree_path_kl_divergence_loss.py index f94c07d7df..7603153f95 100644 --- a/lib/src/otx/backend/native/models/classification/losses/tree_path_KL_divergence_loss.py +++ b/lib/src/otx/backend/native/models/classification/losses/tree_path_kl_divergence_loss.py @@ -1,19 +1,17 @@ -# Copyright (C) 2022 Intel Corporation +# Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """Module for defining TreePathKLDivergenceLoss.""" from __future__ import annotations -from typing import List import torch -import torch.nn as nn import torch.nn.functional as F +from torch import nn class TreePathKLDivergenceLoss(nn.Module): - """ - KL divergence between model distribution over concatenated heads and a + """KL divergence between model distribution over concatenated heads and a target distribution that allocates equal mass to the ground-truth class at each hierarchy level. @@ -25,26 +23,26 @@ class TreePathKLDivergenceLoss(nn.Module): and 0 elsewhere, then uses KLDivLoss(log_softmax(logits), target_probs). 
""" - def __init__(self, reduction: str = "batchmean", loss_weight: float = 1.0): + def __init__(self, reduction: str | None = "batchmean", loss_weight: float = 1.0): super().__init__() self.reduction = reduction self.loss_weight = loss_weight self.kl_div = nn.KLDivLoss(reduction=self.reduction) - def forward(self, logits_list: List[torch.Tensor], targets: torch.Tensor) -> torch.Tensor: + def forward(self, logits_list: list[torch.Tensor], targets: torch.Tensor) -> torch.Tensor: assert isinstance(logits_list, (list, tuple)) and len(logits_list) > 0, "logits_list must be non-empty" num_levels = len(logits_list) # concat logits across all levels dims = [t.size(1) for t in logits_list] - logits_concat = torch.cat(logits_list, dim=1) # [B, sum(C_l)] - log_probs = F.log_softmax(logits_concat, dim=1) # [B, sum(C_l)] + logits_concat = torch.cat(logits_list, dim=1) # [B, sum(C_l)] + log_probs = F.log_softmax(logits_concat, dim=1) # [B, sum(C_l)] # build sparse target distribution with 1/L at each GT index B = log_probs.size(0) - tgt = torch.zeros_like(log_probs) # [B, sum(C_l)] + tgt = torch.zeros_like(log_probs) # [B, sum(C_l)] offset = 0 - for num_c, tgt_l in zip(dims, targets.T): # level-by-level + for num_c, tgt_l in zip(dims, targets.T): # level-by-level idx_rows = torch.arange(B, device=log_probs.device) tgt[idx_rows, offset + tgt_l] = 1.0 / num_levels offset += num_c From 70b03a5c434a38efc63e3feda2923889c74db175 Mon Sep 17 00:00:00 2001 From: Jyc323 Date: Fri, 19 Sep 2025 01:01:45 -0700 Subject: [PATCH 03/13] fix tox errors --- .../losses/tree_path_kl_divergence_loss.py | 17 ++-- .../losses/test_tree_path_kl_divergence.py | 97 ++++++++++--------- 2 files changed, 58 insertions(+), 56 deletions(-) diff --git a/lib/src/otx/backend/native/models/classification/losses/tree_path_kl_divergence_loss.py b/lib/src/otx/backend/native/models/classification/losses/tree_path_kl_divergence_loss.py index 7603153f95..d0ea1442bf 100644 --- a/lib/src/otx/backend/native/models/classification/losses/tree_path_kl_divergence_loss.py +++ b/lib/src/otx/backend/native/models/classification/losses/tree_path_kl_divergence_loss.py @@ -6,14 +6,12 @@ from __future__ import annotations import torch -import torch.nn.functional as F from torch import nn +from torch.nn import functional class TreePathKLDivergenceLoss(nn.Module): - """KL divergence between model distribution over concatenated heads and a - target distribution that allocates equal mass to the ground-truth class - at each hierarchy level. + """KL divergence between model distribution over concatenated heads and a target distribution. 
Inputs: logits_list: list of tensors [B, C_l], ordered from root -> leaf @@ -30,20 +28,23 @@ def __init__(self, reduction: str | None = "batchmean", loss_weight: float = 1.0 self.kl_div = nn.KLDivLoss(reduction=self.reduction) def forward(self, logits_list: list[torch.Tensor], targets: torch.Tensor) -> torch.Tensor: - assert isinstance(logits_list, (list, tuple)) and len(logits_list) > 0, "logits_list must be non-empty" + """Calculate tree_path KL Divergence loss.""" + if not (isinstance(logits_list, (list, tuple)) and len(logits_list) > 0): + msg = "logits_list must be non-empty" + raise ValueError(msg) num_levels = len(logits_list) # concat logits across all levels dims = [t.size(1) for t in logits_list] logits_concat = torch.cat(logits_list, dim=1) # [B, sum(C_l)] - log_probs = F.log_softmax(logits_concat, dim=1) # [B, sum(C_l)] + log_probs = functional.log_softmax(logits_concat, dim=1) # [B, sum(C_l)] # build sparse target distribution with 1/L at each GT index - B = log_probs.size(0) + batch = log_probs.size(0) tgt = torch.zeros_like(log_probs) # [B, sum(C_l)] offset = 0 for num_c, tgt_l in zip(dims, targets.T): # level-by-level - idx_rows = torch.arange(B, device=log_probs.device) + idx_rows = torch.arange(batch, device=log_probs.device) tgt[idx_rows, offset + tgt_l] = 1.0 / num_levels offset += num_c diff --git a/lib/tests/unit/backend/native/models/classification/losses/test_tree_path_kl_divergence.py b/lib/tests/unit/backend/native/models/classification/losses/test_tree_path_kl_divergence.py index d16f18ff32..673a22fb7c 100644 --- a/lib/tests/unit/backend/native/models/classification/losses/test_tree_path_kl_divergence.py +++ b/lib/tests/unit/backend/native/models/classification/losses/test_tree_path_kl_divergence.py @@ -1,19 +1,25 @@ -import torch +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import pytest -import torch.nn.functional as F +import torch +from torch.nn import functional -from otx.backend.native.models.classification.losses.tree_path_KL_divergence_loss import TreePathKLDivergenceLoss +from otx.backend.native.models.classification.losses.tree_path_kl_divergence_loss import TreePathKLDivergenceLoss -@pytest.mark.parametrize("levels,classes_per_level", [ - (2, [3, 5]), - (3, [2, 3, 4]), -]) +@pytest.mark.parametrize( + ("levels", "classes_per_level"), + [ + (2, [3, 5]), + (3, [2, 3, 4]), + ], +) def test_forward_scalar_and_finite(levels, classes_per_level): torch.manual_seed(0) - B = 4 - logits_list = [torch.randn(B, c) for c in classes_per_level] - targets = torch.stack([torch.randint(0, c, (B,)) for c in classes_per_level], dim=1) + batch = 4 + logits_list = [torch.randn(batch, c) for c in classes_per_level] + targets = torch.stack([torch.randint(0, c, (batch,)) for c in classes_per_level], dim=1) loss_fn = TreePathKLDivergenceLoss(reduction="batchmean") loss = loss_fn(logits_list, targets) @@ -23,10 +29,10 @@ def test_forward_scalar_and_finite(levels, classes_per_level): def test_backward_produces_grads(): - B = 3 - C = [4, 6] - logits_list = [torch.randn(B, c, requires_grad=True) for c in C] - targets = torch.stack([torch.randint(0, c, (B,)) for c in C], dim=1) + batch = 3 + channel = [4, 6] + logits_list = [torch.randn(batch, c, requires_grad=True) for c in channel] + targets = torch.stack([torch.randint(0, c, (batch,)) for c in channel], dim=1) loss = TreePathKLDivergenceLoss()(logits_list, targets) loss.backward() @@ -36,21 +42,21 @@ def test_backward_produces_grads(): def test_alignment_vs_misalignment_loss(): - B = 2 - C0, C1 
= 3, 4 + batch = 2 + channel0, channel1 = 3, 4 targets = torch.tensor([[0, 1], [2, 3]]) # Aligned: boost GT logits - aligned0 = torch.zeros(B, C0) - aligned1 = torch.zeros(B, C1) - aligned0[torch.arange(B), targets[:, 0]] = 5.0 - aligned1[torch.arange(B), targets[:, 1]] = 5.0 + aligned0 = torch.zeros(batch, channel0) + aligned1 = torch.zeros(batch, channel1) + aligned0[torch.arange(batch), targets[:, 0]] = 5.0 + aligned1[torch.arange(batch), targets[:, 1]] = 5.0 # Misaligned: boost wrong logits - mis0 = torch.zeros(B, C0) - mis1 = torch.zeros(B, C1) - mis0[torch.arange(B), (targets[:, 0] + 1) % C0] = 5.0 - mis1[torch.arange(B), (targets[:, 1] + 1) % C1] = 5.0 + mis0 = torch.zeros(batch, channel0) + mis1 = torch.zeros(batch, channel1) + mis0[torch.arange(batch), (targets[:, 0] + 1) % channel0] = 5.0 + mis1[torch.arange(batch), (targets[:, 1] + 1) % channel1] = 5.0 loss_fn = TreePathKLDivergenceLoss() loss_aligned = loss_fn([aligned0, aligned1], targets) @@ -64,20 +70,19 @@ def test_single_level_exact_value(): We check exact value against F.cross_entropy. """ - B, C = 2, 3 - logits = torch.tensor([[2.0, 0.0, -1.0], - [0.5, 1.0, -0.5]]) + logits = torch.tensor([[2.0, 0.0, -1.0], [0.5, 1.0, -0.5]]) targets = torch.tensor([[0], [2]]) # shape [B,1] - # TreePathKL + # TreePathKLP loss_fn = TreePathKLDivergenceLoss(reduction="batchmean") kl_loss = loss_fn([logits], targets) # CrossEntropy with one-hot is same as NLLLoss(log_softmax) - ce_loss = F.cross_entropy(logits, targets.view(-1), reduction="mean") + ce_loss = functional.cross_entropy(logits, targets.view(-1), reduction="mean") assert torch.allclose(kl_loss, ce_loss, atol=1e-6) + def test_multi_level_exact_value_batchmean(): """ Exact numerical check for L=2 levels with 'batchmean' reduction. @@ -90,45 +95,41 @@ def test_multi_level_exact_value_batchmean(): """ # Use double for better numerical agreement - B = 2 + batch = 2 l0, l1 = 2, 3 - logits0 = torch.tensor([[2.0, -1.0], - [0.0, 1.0]], dtype=torch.float64) # [B, l0] - logits1 = torch.tensor([[ 0.5, 0.0, -0.5], - [-1.0, 2.0, 0.5]], dtype=torch.float64) # [B, l1] - logits_list = [logits0, logits1] + logits0 = torch.tensor([[2.0, -1.0], [0.0, 1.0]], dtype=torch.float64) # [B, l0] + logits1 = torch.tensor([[0.5, 0.0, -0.5], [-1.0, 2.0, 0.5]], dtype=torch.float64) # [B, l1] # Ground-truth indices per level # sample 0: level0->0, level1->1 # sample 1: level0->1, level1->2 - targets = torch.tensor([[0, 1], - [1, 2]], dtype=torch.long) # [B, 2] - L = 2 # number of levels + targets = torch.tensor([[0, 1], [1, 2]], dtype=torch.long) # [B, 2] + level = 2 # number of levels # Model log probs over concatenated heads - concat = torch.cat([logits0, logits1], dim=1) # [B, l0+l1] - log_q = F.log_softmax(concat, dim=1) # log(q_j) + concat = torch.cat([logits0, logits1], dim=1) # [B, l0+l1] + log_q = functional.log_softmax(concat, dim=1) # log(q_j) - # Build target distribution p: 1/L at each GT index, 0 elsewhere + # Build target distribution p: 1/level at each GT index, 0 elsewhere p = torch.zeros_like(log_q, dtype=torch.float64) offset = 0 for num_c, tgt_l in zip([l0, l1], targets.T): - rows = torch.arange(B) - p[rows, offset + tgt_l] = 1.0 / L + rows = torch.arange(batch) + p[rows, offset + tgt_l] = 1.0 / level offset += num_c # Manual KL with 'batchmean' reduction: - # sum_i sum_j p_ij * (log p_ij - log q_ij) / B + # sum_i sum_j p_ij * (log p_ij - log q_ij) / batch # (avoid log(0) by masking since p is sparse) mask = p > 0 log_p = torch.zeros_like(p) log_p[mask] = torch.log(p[mask]) - manual_kl = (p * 
(log_p - log_q)).sum() / B + manual_kl = (p * (log_p - log_q)).sum() / batch # Loss under test (must match manual) loss_fn = TreePathKLDivergenceLoss(reduction="batchmean") test_kl = loss_fn([logits0.float(), logits1.float()], targets) - assert torch.allclose(test_kl.double(), manual_kl, atol=1e-8), ( - f"manual={manual_kl.item():.12f} vs loss={test_kl.item():.12f}" - ) \ No newline at end of file + assert torch.allclose( + test_kl.double(), manual_kl, atol=1e-8 + ), f"manual={manual_kl.item():.12f} vs loss={test_kl.item():.12f}" From a6c70ad6574c36426d9e9b055c25a974e6b8d815 Mon Sep 17 00:00:00 2001 From: Jyc323 Date: Sat, 20 Sep 2025 17:57:20 -0700 Subject: [PATCH 04/13] integrate KL loss via recipe YAML; add new H-label model & classifier --- .../classification/classifier/__init__.py | 4 +- .../classifier/h_label_classifier.py | 85 ++++++++++++ .../hlabel_models/timm_model.py | 65 ++++++++- .../h_label_cls/efficientnet_v2_kl.yaml | 125 ++++++++++++++++++ 4 files changed, 276 insertions(+), 3 deletions(-) create mode 100644 lib/src/otx/recipe/classification/h_label_cls/efficientnet_v2_kl.yaml diff --git a/lib/src/otx/backend/native/models/classification/classifier/__init__.py b/lib/src/otx/backend/native/models/classification/classifier/__init__.py index 40b9ad9ae7..d38f562265 100644 --- a/lib/src/otx/backend/native/models/classification/classifier/__init__.py +++ b/lib/src/otx/backend/native/models/classification/classifier/__init__.py @@ -4,6 +4,6 @@ """Head modules for OTX custom model.""" from .base_classifier import ImageClassifier -from .h_label_classifier import HLabelClassifier +from .h_label_classifier import HLabelClassifier, KLHLabelClassifier -__all__ = ["ImageClassifier", "HLabelClassifier"] +__all__ = ["ImageClassifier", "HLabelClassifier", "KLHLabelClassifier"] diff --git a/lib/src/otx/backend/native/models/classification/classifier/h_label_classifier.py b/lib/src/otx/backend/native/models/classification/classifier/h_label_classifier.py index 1649785815..1a6cf0ee81 100644 --- a/lib/src/otx/backend/native/models/classification/classifier/h_label_classifier.py +++ b/lib/src/otx/backend/native/models/classification/classifier/h_label_classifier.py @@ -11,6 +11,7 @@ import torch from otx.backend.native.models.classification.heads.hlabel_cls_head import HierarchicalClsHead +from otx.backend.native.models.classification.losses.tree_path_kl_divergence_loss import TreePathKLDivergenceLoss from otx.backend.native.models.classification.utils.ignored_labels import get_valid_label_mask from .base_classifier import ImageClassifier @@ -143,3 +144,87 @@ def _forward_explain(self, images: torch.Tensor) -> dict[str, torch.Tensor | lis outputs["preds"] = preds return outputs + + +class KLHLabelClassifier(HLabelClassifier): + """Hierarchical label classifier with tree path KL divergence loss. + + Args: + backbone (nn.Module): Backbone network. + neck (nn.Module | None): Neck network. + head (nn.Module): Head network. + multiclass_loss (nn.Module): Multiclass loss function. + multilabel_loss (nn.Module | None, optional): Multilabel loss function. + init_cfg (dict | list[dict] | None, optional): Initialization configuration. + kl_weight (float): Loss weight for tree path KL divergence loss + + Attributes: + multiclass_loss (nn.Module): Multiclass loss function. + multilabel_loss (nn.Module | None): Multilabel loss function. + is_ignored_label_loss (bool): Flag indicating if ignored label loss is used. 
+ + Methods: + loss(inputs, labels, **kwargs): Calculate losses from a batch of inputs and data samples. + """ + + def __init__(self, *args, kl_weight: float = 1.0, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.kl_weight = kl_weight + self.kl_loss = TreePathKLDivergenceLoss(reduction="batchmean", loss_weight=1.0) + + def loss(self, inputs: torch.Tensor, labels: torch.Tensor, **kwargs) -> torch.Tensor: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + labels (torch.Tensor): The annotation data of + every samples. + + Returns: + torch.Tensor: loss components + """ + cls_scores = self.extract_feat(inputs, stage="head") + loss_score = torch.tensor(0.0, device=cls_scores.device) + logits_list = [] + target_list = [] + num_effective_heads_in_batch = 0 + for i in range(self.head.num_multiclass_heads): + if i not in self.head.empty_multiclass_head_indices: + head_gt = labels[:, i] + logit_range = self.head._get_head_idx_to_logits_range(i) # noqa: SLF001 + head_logits = cls_scores[:, logit_range[0] : logit_range[1]] + valid_mask = head_gt >= 0 + head_gt = head_gt[valid_mask] + if len(head_gt) > 0: + head_logits = head_logits[valid_mask] + logits_list.append(head_logits) + target_list.append(head_gt) + ce = self.multiclass_loss(head_logits, head_gt) + loss_score += ce + num_effective_heads_in_batch += 1 + + if num_effective_heads_in_batch > 0: + loss_score /= num_effective_heads_in_batch + + if len(logits_list) > 1: + kl_loss = self.kl_loss(logits_list, torch.stack(target_list, dim=1)) + loss_score += self.kl_weight * kl_loss + + # Multilabel logic (preserved as-is) + if self.head.num_multilabel_classes > 0: + head_gt = labels[:, self.head.num_multiclass_heads :] + head_logits = cls_scores[:, self.head.num_single_label_classes :] + valid_mask = head_gt > 0 + head_gt = head_gt[valid_mask] + if len(head_gt) > 0 and self.multilabel_loss is not None: + head_logits = head_logits[valid_mask] + imgs_info = kwargs.pop("imgs_info", None) + if imgs_info is not None and self.is_ignored_label_loss: + valid_label_mask = get_valid_label_mask(imgs_info, self.head.num_classes).to(head_logits.device) + valid_label_mask = valid_label_mask[:, self.head.num_single_label_classes :] + valid_label_mask = valid_label_mask[valid_mask] + kwargs["valid_label_mask"] = valid_label_mask + loss_score += self.multilabel_loss(head_logits, head_gt, **kwargs) + + return loss_score diff --git a/lib/src/otx/backend/native/models/classification/hlabel_models/timm_model.py b/lib/src/otx/backend/native/models/classification/hlabel_models/timm_model.py index ba6f2c67fe..379231e8d8 100644 --- a/lib/src/otx/backend/native/models/classification/hlabel_models/timm_model.py +++ b/lib/src/otx/backend/native/models/classification/hlabel_models/timm_model.py @@ -13,7 +13,7 @@ from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable from otx.backend.native.models.classification.backbones.timm import TimmBackbone -from otx.backend.native.models.classification.classifier import HLabelClassifier +from otx.backend.native.models.classification.classifier import HLabelClassifier, KLHLabelClassifier from otx.backend.native.models.classification.heads import HierarchicalLinearClsHead from otx.backend.native.models.classification.hlabel_models.base import OTXHlabelClsModel from otx.backend.native.models.classification.losses.asymmetric_angular_loss_with_ignore import ( @@ -89,3 +89,66 
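To summarize the objective implemented by KLHLabelClassifier above: the multiclass part is the per-head cross-entropy averaged over effective heads, plus kl_weight times the tree-path KL term, while the multilabel branch is unchanged. A schematic, self-contained sketch of that combination, assuming every head has valid labels for every sample (hypothetical shapes, batch of 8):

    import torch
    import torch.nn.functional as F
    from otx.backend.native.models.classification.losses.tree_path_kl_divergence_loss import (
        TreePathKLDivergenceLoss,
    )

    kl_weight = 1.0
    logits_list = [torch.randn(8, 3), torch.randn(8, 5)]    # logits of the two effective heads
    targets = torch.stack(
        [torch.randint(0, 3, (8,)), torch.randint(0, 5, (8,))], dim=1
    )

    ce = sum(F.cross_entropy(h, targets[:, i]) for i, h in enumerate(logits_list)) / len(logits_list)
    kl = TreePathKLDivergenceLoss(reduction="batchmean")(logits_list, targets)
    total = ce + kl_weight * kl   # what loss() returns for the multiclass heads; kl_weight=0 recovers HLabelClassifier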
@@ def _create_model(self, head_config: dict | None = None) -> nn.Module: # type: def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: """Load the previous OTX ckpt according to OTX2.0.""" return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "hlabel", add_prefix) + + +class TimmModelHLabelClsWithKL(OTXHlabelClsModel): + """Timm Model for hierarchical label classification task. + + Args: + label_info (HLabelInfo): The label information for the classification task. + model_name (str): The name of the model. + You can find available models at timm.list_models() or timm.list_pretrained(). + input_size (tuple[int, int], optional): Model input size in the order of height and width. + Defaults to (224, 224). + pretrained (bool, optional): Whether to load pretrained weights. Defaults to True. + optimizer (OptimizerCallable, optional): The optimizer callable for training the model. + scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable. + metric (MetricCallable, optional): The metric callable for evaluating the model. + Defaults to HLabelClsMetricCallable. + torch_compile (bool, optional): Whether to compile the model using TorchScript. Defaults to False. + """ + + def __init__( + self, + label_info: HLabelInfo, + data_input_params: DataInputParams, + model_name: str = "tf_efficientnetv2_s.in21k", + freeze_backbone: bool = False, + optimizer: OptimizerCallable = DefaultOptimizerCallable, + scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, + metric: MetricCallable = HLabelClsMetricCallable, + torch_compile: bool = False, + kl_weight: float = 1.0, + ) -> None: + self.kl_weight = float(kl_weight) + super().__init__( + label_info=label_info, + data_input_params=data_input_params, + model_name=model_name, + freeze_backbone=freeze_backbone, + optimizer=optimizer, + scheduler=scheduler, + metric=metric, + torch_compile=torch_compile, + ) + + def _create_model(self, head_config: dict | None = None) -> nn.Module: # type: ignore[override] + head_config = head_config if head_config is not None else self.label_info.as_head_config_dict() + backbone = TimmBackbone(model_name=self.model_name) + copied_head_config = copy(head_config) + copied_head_config["step_size"] = ( + ceil(self.data_input_params.input_size[0] / 32), + ceil(self.data_input_params.input_size[1] / 32), + ) + return KLHLabelClassifier( + backbone=backbone, + neck=GlobalAveragePooling(dim=2), + head=HierarchicalLinearClsHead(**copied_head_config, in_channels=backbone.num_features), + multiclass_loss=nn.CrossEntropyLoss(), + multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), + kl_weight=self.kl_weight, + ) + + def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: + """Load the previous OTX ckpt according to OTX2.0.""" + return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "hlabel", add_prefix) diff --git a/lib/src/otx/recipe/classification/h_label_cls/efficientnet_v2_kl.yaml b/lib/src/otx/recipe/classification/h_label_cls/efficientnet_v2_kl.yaml new file mode 100644 index 0000000000..08be431a4e --- /dev/null +++ b/lib/src/otx/recipe/classification/h_label_cls/efficientnet_v2_kl.yaml @@ -0,0 +1,125 @@ +task: H_LABEL_CLS +model: + class_path: otx.backend.native.models.classification.hlabel_models.timm_model.TimmModelHLabelClsWithKL + init_args: + kl_weight: 1.0 + model_name: tf_efficientnetv2_s.in21k + + optimizer: + class_path: torch.optim.SGD + 
init_args: + lr: 0.0071 + momentum: 0.9 + weight_decay: 0.0001 + + scheduler: + class_path: otx.backend.native.schedulers.LinearWarmupSchedulerCallable + init_args: + num_warmup_steps: 0 + main_scheduler_callable: + class_path: lightning.pytorch.cli.ReduceLROnPlateau + init_args: + mode: max + factor: 0.5 + patience: 3 + monitor: val/accuracy + +engine: + device: auto + +callback_monitor: val/accuracy + +data: ../../_base_/data/classification.yaml + +callbacks: + - class_path: otx.backend.native.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup + init_args: + warmup_iters: 750 + patience: 5 + mode: max + monitor: val/accuracy + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + dirpath: "" # use engine.work_dir + monitor: val/accuracy + mode: max + save_top_k: 1 + save_last: true + auto_insert_metric_name: false + filename: "checkpoints/epoch_{epoch:03d}" + +overrides: + reset: + - data.train_subset.transforms + + max_epochs: 90 + + data: + task: H_LABEL_CLS + data_format: datumaro + train_subset: + transforms: + - class_path: otx.data.transform_libs.torchvision.EfficientNetRandomCrop + init_args: + scale: $(input_size) + crop_ratio_range: + - 0.08 + - 1.0 + aspect_ratio_range: + - 0.75 + - 1.34 + - class_path: torchvision.transforms.v2.RandomPhotometricDistort + enable: false + init_args: + brightness: + - 0.875 + - 1.125 + contrast: + - 0.5 + - 1.5 + saturation: + - 0.5 + - 1.5 + hue: + - -0.05 + - 0.05 + p: 0.5 + - class_path: otx.data.transform_libs.torchvision.RandomAffine + enable: false + init_args: + max_rotate_degree: 10.0 + max_translate_ratio: 0.1 + scaling_ratio_range: + - 0.5 + - 1.5 + max_shear_degree: 2.0 + - class_path: otx.data.transform_libs.torchvision.RandomFlip + enable: true + init_args: + probability: 0.5 + - class_path: torchvision.transforms.v2.RandomVerticalFlip + enable: false + init_args: + p: 0.5 + - class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur + enable: false + init_args: + kernel_size: 5 + sigma: + - 0.1 + - 2.0 + probability: 0.5 + - class_path: torchvision.transforms.v2.ToDtype + init_args: + dtype: ${as_torch_dtype:torch.float32} + scale: false + - class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise + enable: false + init_args: + mean: 0.0 + sigma: 0.1 + probability: 0.5 + - class_path: torchvision.transforms.v2.Normalize + init_args: + mean: [123.675, 116.28, 103.53] + std: [58.395, 57.12, 57.375] From 3ce8d3b06cf20aedeb2a4d22d06fe8e84221f59d Mon Sep 17 00:00:00 2001 From: Jyc323 Date: Mon, 6 Oct 2025 12:10:10 -0700 Subject: [PATCH 05/13] move files since renaming --- .../models/classification/losses/test_tree_path_kl_divergence.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {lib => library}/tests/unit/backend/native/models/classification/losses/test_tree_path_kl_divergence.py (100%) diff --git a/lib/tests/unit/backend/native/models/classification/losses/test_tree_path_kl_divergence.py b/library/tests/unit/backend/native/models/classification/losses/test_tree_path_kl_divergence.py similarity index 100% rename from lib/tests/unit/backend/native/models/classification/losses/test_tree_path_kl_divergence.py rename to library/tests/unit/backend/native/models/classification/losses/test_tree_path_kl_divergence.py From 86f4017974c609da46bbc91fa9d4af0b1924f2a8 Mon Sep 17 00:00:00 2001 From: Jyc323 Date: Thu, 9 Oct 2025 00:44:48 -0700 Subject: [PATCH 06/13] refactor to base.py, add unit test --- .../classifier/h_label_classifier.py | 1 - .../classification/hlabel_models/base.py | 39 
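Assuming the standard OTX CLI entry point, the efficientnet_v2_kl.yaml recipe added in this series can then be exercised like any other h_label_cls recipe, for example: otx train --config .../h_label_cls/efficientnet_v2_kl.yaml --data_root <path/to/datumaro/dataset> (exact flags may differ between OTX versions). The strength of the KL term is controlled through model.init_args.kl_weight in the YAML shown above.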
++++- .../hlabel_models/timm_model.py | 2 + .../h_label_cls/efficientnet_v2_kl.yaml | 4 +- .../classifier/test_kl_hlabel_classifier.py | 144 ++++++++++++++++++ .../native/models/classification/test_base.py | 70 +++++++++ 6 files changed, 255 insertions(+), 5 deletions(-) create mode 100644 library/tests/unit/backend/native/models/classification/classifier/test_kl_hlabel_classifier.py diff --git a/library/src/otx/backend/native/models/classification/classifier/h_label_classifier.py b/library/src/otx/backend/native/models/classification/classifier/h_label_classifier.py index 1a6cf0ee81..bddcb68544 100644 --- a/library/src/otx/backend/native/models/classification/classifier/h_label_classifier.py +++ b/library/src/otx/backend/native/models/classification/classifier/h_label_classifier.py @@ -171,7 +171,6 @@ def __init__(self, *args, kl_weight: float = 1.0, **kwargs) -> None: super().__init__(*args, **kwargs) self.kl_weight = kl_weight self.kl_loss = TreePathKLDivergenceLoss(reduction="batchmean", loss_weight=1.0) - def loss(self, inputs: torch.Tensor, labels: torch.Tensor, **kwargs) -> torch.Tensor: """Calculate losses from a batch of inputs and data samples. diff --git a/library/src/otx/backend/native/models/classification/hlabel_models/base.py b/library/src/otx/backend/native/models/classification/hlabel_models/base.py index 4697ff5fa8..49832bc3b5 100644 --- a/library/src/otx/backend/native/models/classification/hlabel_models/base.py +++ b/library/src/otx/backend/native/models/classification/hlabel_models/base.py @@ -7,14 +7,16 @@ from abc import abstractmethod from copy import deepcopy +from functools import wraps from typing import TYPE_CHECKING, Any import torch from torch import Tensor - + from otx.backend.native.exporter.base import OTXModelExporter from otx.backend.native.exporter.native import OTXNativeModelExporter from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel +from otx.backend.native.models.classification.classifier import HLabelClassifier, KLHLabelClassifier from otx.backend.native.schedulers import LRSchedulerListCallable from otx.data.entity.base import OTXBatchLossEntity from otx.data.entity.torch import OTXDataBatch, OTXPredBatch @@ -59,7 +61,9 @@ def __init__( scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = HLabelClsMetricCallable, torch_compile: bool = False, + **kwargs, ) -> None: + self.kl_weight = kwargs.get("kl_weight", 0.0) super().__init__( label_info=label_info, data_input_params=data_input_params, @@ -70,16 +74,47 @@ def __init__( metric=metric, torch_compile=torch_compile, ) - if freeze_backbone: classification_layers = self._identify_classification_layers() for name, param in self.named_parameters(): param.requires_grad = name in classification_layers + def __getattribute__(self, name: str): + attr = super().__getattribute__(name) + if name == "_create_model" and callable(attr): + cache_name = "__cm_cached__" + cache = super().__getattribute__("__dict__").get(cache_name) + if cache: + return cache + @wraps(attr) + def wrapped(*a, **kw) -> nn.Module: + model = attr(*a, **kw) + return self._finalize_model(model) + self.__dict__[cache_name] = wrapped + return wrapped + return attr + @abstractmethod def _create_model(self, head_config: dict | None = None) -> nn.Module: # type: ignore[override] """Create a PyTorch model for this class.""" + def _finalize_model(self, model: nn.Module) -> nn.Module: + """Run after child _create_model(); 
upgrade to KL if enabled.""" + if self.kl_weight > 0: + kl_model = KLHLabelClassifier( + backbone=model.backbone, + neck=model.neck, + head=model.head, + multiclass_loss=model.multiclass_loss, + multilabel_loss=model.multilabel_loss, + init_cfg=getattr(model, "init_cfg", None), + kl_weight=self.kl_weight, + ) + return kl_model + else: + return model + + def _identify_classification_layers(self, prefix: str = "model.") -> list[str]: """Simple identification of the classification layers. Used for incremental learning.""" # identify classification layers diff --git a/library/src/otx/backend/native/models/classification/hlabel_models/timm_model.py b/library/src/otx/backend/native/models/classification/hlabel_models/timm_model.py index 379231e8d8..2580452fea 100644 --- a/library/src/otx/backend/native/models/classification/hlabel_models/timm_model.py +++ b/library/src/otx/backend/native/models/classification/hlabel_models/timm_model.py @@ -58,6 +58,7 @@ def __init__( scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = HLabelClsMetricCallable, torch_compile: bool = False, + **kwargs, ) -> None: super().__init__( label_info=label_info, @@ -68,6 +69,7 @@ def __init__( scheduler=scheduler, metric=metric, torch_compile=torch_compile, + **kwargs, ) def _create_model(self, head_config: dict | None = None) -> nn.Module: # type: ignore[override] diff --git a/library/src/otx/recipe/classification/h_label_cls/efficientnet_v2_kl.yaml b/library/src/otx/recipe/classification/h_label_cls/efficientnet_v2_kl.yaml index 08be431a4e..a460fac282 100644 --- a/library/src/otx/recipe/classification/h_label_cls/efficientnet_v2_kl.yaml +++ b/library/src/otx/recipe/classification/h_label_cls/efficientnet_v2_kl.yaml @@ -1,8 +1,8 @@ task: H_LABEL_CLS model: - class_path: otx.backend.native.models.classification.hlabel_models.timm_model.TimmModelHLabelClsWithKL + class_path: otx.backend.native.models.classification.hlabel_models.timm_model.TimmModelHLabelCls init_args: - kl_weight: 1.0 + kl_weight: 2.0 model_name: tf_efficientnetv2_s.in21k optimizer: diff --git a/library/tests/unit/backend/native/models/classification/classifier/test_kl_hlabel_classifier.py b/library/tests/unit/backend/native/models/classification/classifier/test_kl_hlabel_classifier.py new file mode 100644 index 0000000000..64ad7c12e5 --- /dev/null +++ b/library/tests/unit/backend/native/models/classification/classifier/test_kl_hlabel_classifier.py @@ -0,0 +1,144 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +from torch import nn + +from otx.backend.native.models.classification.backbones import EfficientNetBackbone +from otx.backend.native.models.classification.classifier import HLabelClassifier, KLHLabelClassifier +from otx.backend.native.models.classification.heads import LinearClsHead, MultiLabelLinearClsHead +from otx.backend.native.models.classification.losses import AsymmetricAngularLossWithIgnore +from otx.backend.native.models.classification.necks.gap import GlobalAveragePooling +from otx.backend.native.models.classification.heads.hlabel_cls_head import HierarchicalClsHead + + +class TestHierHead(HierarchicalClsHead): + """Lightweight hierarchical head for tests, compatible with H/KLH classifiers.""" + + def __init__(self, in_channels: int, head_class_sizes=(3, 3)): + # e.g., two heads with 3 classes each -> total classes = 6 + self.head_class_sizes = list(head_class_sizes) + num_multiclass_heads = len(self.head_class_sizes) 
+ num_multilabel_classes = 0 + num_single_label_classes = sum(self.head_class_sizes) + num_classes = num_single_label_classes + + # Build per-head logit ranges, e.g. [(0,3), (3,6)] + start = 0 + ranges = [] + for k in self.head_class_sizes: + ranges.append((start, start + k)) + start += k + + empty_multiclass_head_indices = [] + + # Call the real base class with all required args + super().__init__( + num_classes=num_classes, + in_channels=in_channels, + num_multiclass_heads=num_multiclass_heads, + num_multilabel_classes=num_multilabel_classes, + head_idx_to_logits_range=ranges, + num_single_label_classes=num_single_label_classes, + empty_multiclass_head_indices=empty_multiclass_head_indices, + ) + + # Simple linear head over pooled features -> logits + self.classifier = nn.Linear(in_channels, num_classes) + self._head_ranges = ranges + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if isinstance(x, (tuple, list)): + x = x[0] + return self.classifier(x) + + # H/KLH classifiers call this to slice per-head logits + def _get_head_idx_to_logits_range(self, i: int): + return self._head_ranges[i] + + +class TestKLHLabelClassifier: + @pytest.fixture( + params=[ + (LinearClsHead, nn.CrossEntropyLoss, "fxt_multiclass_cls_batch_data_entity"), + (MultiLabelLinearClsHead, AsymmetricAngularLossWithIgnore, "fxt_multilabel_cls_batch_data_entity"), + ], + ids=["multiclass", "multilabel"], + ) + def fxt_model_and_inputs(self, request): + head_class_sizes = (3, 3) + input_fxt_name = "fxt_multiclass_cls_batch_data_entity" + backbone = EfficientNetBackbone(model_name="efficientnet_b0") + neck = GlobalAveragePooling(dim=2) + head = TestHierHead(in_channels=backbone.num_features, head_class_sizes=head_class_sizes) + loss = nn.CrossEntropyLoss() + fxt_input = request.getfixturevalue(input_fxt_name) + l = len(head_class_sizes) + fxt_labels = torch.stack(fxt_input.labels) + fxt_labels = fxt_labels.repeat(1, l) + return (backbone, neck, head, loss, fxt_input.images, fxt_labels) + + + def test_forward(self, fxt_model_and_inputs): + backbone, neck, head, loss, images, labels = fxt_model_and_inputs + + model = KLHLabelClassifier( + backbone=backbone, + neck=neck, + head=head, + multiclass_loss=loss, + kl_weight=1, + ) + + output = model(images, labels, mode='explain') + assert isinstance(output, dict) + assert "logits" in output + assert "scores" in output + assert "preds" in output + + def test_klh_loss_greater_than_hlabel(self, fxt_model_and_inputs): + """KLHLabelClassifier should have strictly larger loss than HLabelClassifier + when kl_weight > 0 and there are >= 2 multiclass heads.""" + backbone, neck, head, loss, images, labels = fxt_model_and_inputs + h_model = HLabelClassifier( + backbone=backbone, + neck=neck, + head=head, + multiclass_loss=loss, + ) + kl_h_model = KLHLabelClassifier( + backbone=backbone, + neck=neck, + head=head, + multiclass_loss=loss, + kl_weight=2.0, + ) + + h_loss = h_model.loss(images, labels) + klh_loss = kl_h_model.loss(images, labels) + + print(f"HLabel loss: {h_loss.item():.6f} | KLH loss: {klh_loss.item():.6f}") + assert klh_loss > h_loss, "Expected KLH loss to be greater due to added KL term" + + def test_klh_weight_zero_match_hlabel(self, fxt_model_and_inputs): + """With kl_weight == 0, KLH loss should match H label loss (within tolerance).""" + backbone, neck, head, loss, images, labels = fxt_model_and_inputs + h_model = HLabelClassifier( + backbone=backbone, + neck=neck, + head=head, + multiclass_loss=loss, + ) + kl_h_model = KLHLabelClassifier( + backbone=backbone, + 
neck=neck, + head=head, + multiclass_loss=loss, + kl_weight=0, + ) + h_loss = h_model.loss(images, labels) + klh_loss = kl_h_model.loss(images, labels) + + print(f"[kl=0] HLabel loss: {h_loss.item():.6f} | KLH loss: {klh_loss.item():.6f}") + assert torch.allclose(klh_loss, h_loss, atol=1e-6), "With kl_weight=0, losses should match" diff --git a/library/tests/unit/backend/native/models/classification/test_base.py b/library/tests/unit/backend/native/models/classification/test_base.py index 4868c34f9a..5a520f3757 100644 --- a/library/tests/unit/backend/native/models/classification/test_base.py +++ b/library/tests/unit/backend/native/models/classification/test_base.py @@ -5,6 +5,8 @@ from __future__ import annotations +from types import MethodType +from unittest.mock import MagicMock from unittest.mock import create_autospec import pytest @@ -13,6 +15,7 @@ from torch.optim import Optimizer from otx.backend.native.models.base import DataInputParams +from otx.backend.native.models.classification.classifier import HLabelClassifier, KLHLabelClassifier from otx.backend.native.models.classification.hlabel_models.base import OTXHlabelClsModel from otx.backend.native.models.classification.multiclass_models.base import OTXMulticlassClsModel from otx.backend.native.models.classification.multilabel_models.base import OTXMultilabelClsModel @@ -216,3 +219,70 @@ def test_set_label_info(self, fxt_hlabel_multilabel_info): fxt_hlabel_multilabel_info.num_multilabel_classes = 0 model.label_info = fxt_hlabel_multilabel_info assert model.label_info.num_multilabel_classes == 0 + +class TestOTXHlabelClsModelwithKL: + @pytest.fixture(autouse=True) + def mock_model(self, mocker): + OTXHlabelClsModel._build_model = mocker.MagicMock(return_value=MockClsModel()) + + @pytest.fixture() + def mock_optimizer(self): + return lambda _: create_autospec(Optimizer) + + @pytest.fixture() + def mock_scheduler(self): + return lambda _: create_autospec([ReduceLROnPlateau]) + + @pytest.fixture() + def model_instance(self, mock_optimizer, mock_scheduler, fxt_hlabel_multilabel_info): + """Create a minimal instance of OTXHlabelClsModel for testing.""" + return OTXHlabelClsModel( + label_info=fxt_hlabel_multilabel_info, + data_input_params=DataInputParams((224, 224), (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)), + torch_compile=False, + optimizer=mock_optimizer, + scheduler=mock_scheduler, + ) + + def _prepare_instance_with_fake_create(self, model_instance): + """Replace _create_model() with a dummy version that returns a sentinel object.""" + sentinel = object() + + def fake_create(self, head_config=None): + return sentinel + model_instance._create_model = MethodType(fake_create, model_instance) + return sentinel + + def test_create_model_triggers_finalize_when_kl_positive(self, model_instance): + """ + When kl_weight > 0, calling _create_model() should trigger _finalize_model(). 
+ """ + model_instance.kl_weight = 1.0 + + # Spy on _finalize_model before the first _create_model() call + finalize_spy = MagicMock(side_effect=lambda m: m) + model_instance._finalize_model = finalize_spy + + # Replace _create_model with a dummy implementation + sentinel = self._prepare_instance_with_fake_create(model_instance) + + # Call _create_model(); wrapper logic should invoke _finalize_model internally + out = model_instance._create_model() + + finalize_spy.assert_called_once_with(sentinel) + assert out is sentinel + + def test_create_model_does_not_trigger_finalize_when_kl_zero(self, model_instance): + """ + When kl_weight == 0, calling _create_model() should NOT trigger _finalize_model(). + """ + model_instance.kl_weight = 0.0 + + finalize_spy = MagicMock(side_effect=lambda m: m) + model_instance._finalize_model = finalize_spy + + sentinel = self._prepare_instance_with_fake_create(model_instance) + out = model_instance._create_model() + + assert finalize_spy.call_count == 0, "_finalize_model() should not be triggered when kl_weight == 0" + assert out is sentinel From 79f861294a41bdba99e96e4d1e865d10f076f809 Mon Sep 17 00:00:00 2001 From: Jyc323 Date: Thu, 9 Oct 2025 01:23:51 -0700 Subject: [PATCH 07/13] fix errors from tox --- .../classifier/h_label_classifier.py | 1 + .../classification/hlabel_models/base.py | 13 ++-- .../classifier/test_kl_hlabel_classifier.py | 23 +++--- .../native/models/classification/test_base.py | 70 ------------------- 4 files changed, 16 insertions(+), 91 deletions(-) diff --git a/library/src/otx/backend/native/models/classification/classifier/h_label_classifier.py b/library/src/otx/backend/native/models/classification/classifier/h_label_classifier.py index bddcb68544..1a6cf0ee81 100644 --- a/library/src/otx/backend/native/models/classification/classifier/h_label_classifier.py +++ b/library/src/otx/backend/native/models/classification/classifier/h_label_classifier.py @@ -171,6 +171,7 @@ def __init__(self, *args, kl_weight: float = 1.0, **kwargs) -> None: super().__init__(*args, **kwargs) self.kl_weight = kl_weight self.kl_loss = TreePathKLDivergenceLoss(reduction="batchmean", loss_weight=1.0) + def loss(self, inputs: torch.Tensor, labels: torch.Tensor, **kwargs) -> torch.Tensor: """Calculate losses from a batch of inputs and data samples. 
diff --git a/library/src/otx/backend/native/models/classification/hlabel_models/base.py b/library/src/otx/backend/native/models/classification/hlabel_models/base.py index 49832bc3b5..c209abf8df 100644 --- a/library/src/otx/backend/native/models/classification/hlabel_models/base.py +++ b/library/src/otx/backend/native/models/classification/hlabel_models/base.py @@ -12,11 +12,11 @@ import torch from torch import Tensor - + from otx.backend.native.exporter.base import OTXModelExporter from otx.backend.native.exporter.native import OTXNativeModelExporter from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel -from otx.backend.native.models.classification.classifier import HLabelClassifier, KLHLabelClassifier +from otx.backend.native.models.classification.classifier import KLHLabelClassifier from otx.backend.native.schedulers import LRSchedulerListCallable from otx.data.entity.base import OTXBatchLossEntity from otx.data.entity.torch import OTXDataBatch, OTXPredBatch @@ -86,10 +86,12 @@ def __getattribute__(self, name: str): cache = super().__getattribute__("__dict__").get(cache_name) if cache: return cache + @wraps(attr) def wrapped(*a, **kw) -> nn.Module: model = attr(*a, **kw) return self._finalize_model(model) + self.__dict__[cache_name] = wrapped return wrapped return attr @@ -101,7 +103,7 @@ def _create_model(self, head_config: dict | None = None) -> nn.Module: # type: def _finalize_model(self, model: nn.Module) -> nn.Module: """Run after child _create_model(); upgrade to KL if enabled.""" if self.kl_weight > 0: - kl_model = KLHLabelClassifier( + return KLHLabelClassifier( backbone=model.backbone, neck=model.neck, head=model.head, @@ -110,10 +112,7 @@ def _finalize_model(self, model: nn.Module) -> nn.Module: init_cfg=getattr(model, "init_cfg", None), kl_weight=self.kl_weight, ) - return kl_model - else: - return model - + return model def _identify_classification_layers(self, prefix: str = "model.") -> list[str]: """Simple identification of the classification layers. 
Used for incremental learning."""
diff --git a/library/tests/unit/backend/native/models/classification/classifier/test_kl_hlabel_classifier.py b/library/tests/unit/backend/native/models/classification/classifier/test_kl_hlabel_classifier.py
index 64ad7c12e5..ae44e5a9ab 100644
--- a/library/tests/unit/backend/native/models/classification/classifier/test_kl_hlabel_classifier.py
+++ b/library/tests/unit/backend/native/models/classification/classifier/test_kl_hlabel_classifier.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2024 Intel Corporation
+# Copyright (C) 2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
 import pytest
@@ -8,9 +8,9 @@
 from otx.backend.native.models.classification.backbones import EfficientNetBackbone
 from otx.backend.native.models.classification.classifier import HLabelClassifier, KLHLabelClassifier
 from otx.backend.native.models.classification.heads import LinearClsHead, MultiLabelLinearClsHead
+from otx.backend.native.models.classification.heads.hlabel_cls_head import HierarchicalClsHead
 from otx.backend.native.models.classification.losses import AsymmetricAngularLossWithIgnore
 from otx.backend.native.models.classification.necks.gap import GlobalAveragePooling
-from otx.backend.native.models.classification.heads.hlabel_cls_head import HierarchicalClsHead
 
 
 class TestHierHead(HierarchicalClsHead):
@@ -53,10 +53,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x[0]
         return self.classifier(x)
 
-    # H/KLH classifiers call this to slice per-head logits
-    def _get_head_idx_to_logits_range(self, i: int):
-        return self._head_ranges[i]
-
 
 class TestKLHLabelClassifier:
     @pytest.fixture(
@@ -74,15 +70,14 @@ def fxt_model_and_inputs(self, request):
         head = TestHierHead(in_channels=backbone.num_features, head_class_sizes=head_class_sizes)
         loss = nn.CrossEntropyLoss()
         fxt_input = request.getfixturevalue(input_fxt_name)
-        l = len(head_class_sizes)
+        level = len(head_class_sizes)
         fxt_labels = torch.stack(fxt_input.labels)
-        fxt_labels = fxt_labels.repeat(1, l)
+        fxt_labels = fxt_labels.repeat(1, level)
         return (backbone, neck, head, loss, fxt_input.images, fxt_labels)
-
     def test_forward(self, fxt_model_and_inputs):
         backbone, neck, head, loss, images, labels = fxt_model_and_inputs
-
+
         model = KLHLabelClassifier(
             backbone=backbone,
             neck=neck,
@@ -90,13 +85,13 @@ def test_forward(self, fxt_model_and_inputs):
             multiclass_loss=loss,
             kl_weight=1,
         )
-
-        output = model(images, labels, mode='explain')
+
+        output = model(images, labels, mode="explain")
         assert isinstance(output, dict)
         assert "logits" in output
         assert "scores" in output
         assert "preds" in output
-
+
     def test_klh_loss_greater_than_hlabel(self, fxt_model_and_inputs):
         """KLHLabelClassifier should have strictly larger loss than HLabelClassifier
         when kl_weight > 0 and there are >= 2 multiclass heads."""
@@ -120,7 +115,7 @@ def test_klh_loss_greater_than_hlabel(self, fxt_model_and_inputs):
         print(f"HLabel loss: {h_loss.item():.6f} | KLH loss: {klh_loss.item():.6f}")
         assert klh_loss > h_loss, "Expected KLH loss to be greater due to added KL term"
-
+
     def test_klh_weight_zero_match_hlabel(self, fxt_model_and_inputs):
         """With kl_weight == 0, KLH loss should match H label loss (within tolerance)."""
         backbone, neck, head, loss, images, labels = fxt_model_and_inputs
diff --git a/library/tests/unit/backend/native/models/classification/test_base.py b/library/tests/unit/backend/native/models/classification/test_base.py
index 5a520f3757..4868c34f9a 100644
--- a/library/tests/unit/backend/native/models/classification/test_base.py
+++ b/library/tests/unit/backend/native/models/classification/test_base.py
@@ -5,8 +5,6 @@
 
 from __future__ import annotations
 
-from types import MethodType
-from unittest.mock import MagicMock
 from unittest.mock import create_autospec
 
 import pytest
@@ -15,7 +13,6 @@
 from torch.optim import Optimizer
 
 from otx.backend.native.models.base import DataInputParams
-from otx.backend.native.models.classification.classifier import HLabelClassifier, KLHLabelClassifier
 from otx.backend.native.models.classification.hlabel_models.base import OTXHlabelClsModel
 from otx.backend.native.models.classification.multiclass_models.base import OTXMulticlassClsModel
 from otx.backend.native.models.classification.multilabel_models.base import OTXMultilabelClsModel
@@ -219,70 +216,3 @@ def test_set_label_info(self, fxt_hlabel_multilabel_info):
         fxt_hlabel_multilabel_info.num_multilabel_classes = 0
         model.label_info = fxt_hlabel_multilabel_info
         assert model.label_info.num_multilabel_classes == 0
-
-class TestOTXHlabelClsModelwithKL:
-    @pytest.fixture(autouse=True)
-    def mock_model(self, mocker):
-        OTXHlabelClsModel._build_model = mocker.MagicMock(return_value=MockClsModel())
-
-    @pytest.fixture()
-    def mock_optimizer(self):
-        return lambda _: create_autospec(Optimizer)
-
-    @pytest.fixture()
-    def mock_scheduler(self):
-        return lambda _: create_autospec([ReduceLROnPlateau])
-
-    @pytest.fixture()
-    def model_instance(self, mock_optimizer, mock_scheduler, fxt_hlabel_multilabel_info):
-        """Create a minimal instance of OTXHlabelClsModel for testing."""
-        return OTXHlabelClsModel(
-            label_info=fxt_hlabel_multilabel_info,
-            data_input_params=DataInputParams((224, 224), (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
-            torch_compile=False,
-            optimizer=mock_optimizer,
-            scheduler=mock_scheduler,
-        )
-
-    def _prepare_instance_with_fake_create(self, model_instance):
-        """Replace _create_model() with a dummy version that returns a sentinel object."""
-        sentinel = object()
-
-        def fake_create(self, head_config=None):
-            return sentinel
-        model_instance._create_model = MethodType(fake_create, model_instance)
-        return sentinel
-
-    def test_create_model_triggers_finalize_when_kl_positive(self, model_instance):
-        """
-        When kl_weight > 0, calling _create_model() should trigger _finalize_model().
-        """
-        model_instance.kl_weight = 1.0
-
-        # Spy on _finalize_model before the first _create_model() call
-        finalize_spy = MagicMock(side_effect=lambda m: m)
-        model_instance._finalize_model = finalize_spy
-
-        # Replace _create_model with a dummy implementation
-        sentinel = self._prepare_instance_with_fake_create(model_instance)
-
-        # Call _create_model(); wrapper logic should invoke _finalize_model internally
-        out = model_instance._create_model()
-
-        finalize_spy.assert_called_once_with(sentinel)
-        assert out is sentinel
-
-    def test_create_model_does_not_trigger_finalize_when_kl_zero(self, model_instance):
-        """
-        When kl_weight == 0, calling _create_model() should NOT trigger _finalize_model().
-        """
-        model_instance.kl_weight = 0.0
-
-        finalize_spy = MagicMock(side_effect=lambda m: m)
-        model_instance._finalize_model = finalize_spy
-
-        sentinel = self._prepare_instance_with_fake_create(model_instance)
-        out = model_instance._create_model()
-
-        assert finalize_spy.call_count == 0, "_finalize_model() should not be triggered when kl_weight == 0"
-        assert out is sentinel
""" def __init__( @@ -58,7 +59,7 @@ def __init__( scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = HLabelClsMetricCallable, torch_compile: bool = False, - **kwargs, + kl_weight: float = 0.0, ) -> None: super().__init__( label_info=label_info, @@ -69,7 +70,7 @@ def __init__( scheduler=scheduler, metric=metric, torch_compile=torch_compile, - **kwargs, + kl_weight=kl_weight, ) def _create_model(self, head_config: dict | None = None) -> nn.Module: # type: ignore[override] @@ -92,65 +93,3 @@ def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> """Load the previous OTX ckpt according to OTX2.0.""" return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "hlabel", add_prefix) - -class TimmModelHLabelClsWithKL(OTXHlabelClsModel): - """Timm Model for hierarchical label classification task. - - Args: - label_info (HLabelInfo): The label information for the classification task. - model_name (str): The name of the model. - You can find available models at timm.list_models() or timm.list_pretrained(). - input_size (tuple[int, int], optional): Model input size in the order of height and width. - Defaults to (224, 224). - pretrained (bool, optional): Whether to load pretrained weights. Defaults to True. - optimizer (OptimizerCallable, optional): The optimizer callable for training the model. - scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable. - metric (MetricCallable, optional): The metric callable for evaluating the model. - Defaults to HLabelClsMetricCallable. - torch_compile (bool, optional): Whether to compile the model using TorchScript. Defaults to False. - """ - - def __init__( - self, - label_info: HLabelInfo, - data_input_params: DataInputParams, - model_name: str = "tf_efficientnetv2_s.in21k", - freeze_backbone: bool = False, - optimizer: OptimizerCallable = DefaultOptimizerCallable, - scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, - metric: MetricCallable = HLabelClsMetricCallable, - torch_compile: bool = False, - kl_weight: float = 1.0, - ) -> None: - self.kl_weight = float(kl_weight) - super().__init__( - label_info=label_info, - data_input_params=data_input_params, - model_name=model_name, - freeze_backbone=freeze_backbone, - optimizer=optimizer, - scheduler=scheduler, - metric=metric, - torch_compile=torch_compile, - ) - - def _create_model(self, head_config: dict | None = None) -> nn.Module: # type: ignore[override] - head_config = head_config if head_config is not None else self.label_info.as_head_config_dict() - backbone = TimmBackbone(model_name=self.model_name) - copied_head_config = copy(head_config) - copied_head_config["step_size"] = ( - ceil(self.data_input_params.input_size[0] / 32), - ceil(self.data_input_params.input_size[1] / 32), - ) - return KLHLabelClassifier( - backbone=backbone, - neck=GlobalAveragePooling(dim=2), - head=HierarchicalLinearClsHead(**copied_head_config, in_channels=backbone.num_features), - multiclass_loss=nn.CrossEntropyLoss(), - multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"), - kl_weight=self.kl_weight, - ) - - def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict: - """Load the previous OTX ckpt according to OTX2.0.""" - return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "hlabel", add_prefix) From c465bdc3580af151766cf413d21a0692560f8dc3 Mon Sep 17 00:00:00 2001 From: Jyc323 Date: Mon, 13 Oct 
From c465bdc3580af151766cf413d21a0692560f8dc3 Mon Sep 17 00:00:00 2001
From: Jyc323
Date: Mon, 13 Oct 2025 12:31:49 -0700
Subject: [PATCH 09/13] modify the mock of head_idx_to_logits_range

---
 .../classifier/test_kl_hlabel_classifier.py | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/library/tests/unit/backend/native/models/classification/classifier/test_kl_hlabel_classifier.py b/library/tests/unit/backend/native/models/classification/classifier/test_kl_hlabel_classifier.py
index ae44e5a9ab..db750aa349 100644
--- a/library/tests/unit/backend/native/models/classification/classifier/test_kl_hlabel_classifier.py
+++ b/library/tests/unit/backend/native/models/classification/classifier/test_kl_hlabel_classifier.py
@@ -26,9 +26,9 @@ def __init__(self, in_channels: int, head_class_sizes=(3, 3)):
 
         # Build per-head logit ranges, e.g. [(0,3), (3,6)]
         start = 0
-        ranges = []
-        for k in self.head_class_sizes:
-            ranges.append((start, start + k))
+        ranges = {}
+        for idx, k in enumerate(self.head_class_sizes):
+            ranges[str(idx)] = (start, start + k)
             start += k
 
         empty_multiclass_head_indices = []
@@ -86,12 +86,9 @@ def test_forward(self, fxt_model_and_inputs):
             kl_weight=1,
         )
 
-        output = model(images, labels, mode="explain")
-        assert isinstance(output, dict)
-        assert "logits" in output
-        assert "scores" in output
-        assert "preds" in output
-
+        output = model(images, labels, mode="loss")
+        assert isinstance(output, torch.Tensor)
+
     def test_klh_loss_greater_than_hlabel(self, fxt_model_and_inputs):
         """KLHLabelClassifier should have strictly larger loss than HLabelClassifier
         when kl_weight > 0 and there are >= 2 multiclass heads."""

From 395bd13ddeef79427014712efbdcc29100613fae Mon Sep 17 00:00:00 2001
From: Jyc323
Date: Mon, 13 Oct 2025 12:57:55 -0700
Subject: [PATCH 10/13] update list models pattern

---
 library/tests/unit/backend/native/utils/test_api.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/library/tests/unit/backend/native/utils/test_api.py b/library/tests/unit/backend/native/utils/test_api.py
index 2779a8b43a..68f2d5c996 100644
--- a/library/tests/unit/backend/native/utils/test_api.py
+++ b/library/tests/unit/backend/native/utils/test_api.py
@@ -30,11 +30,15 @@ def test_list_models_pattern() -> None:
     target = [
         "efficientnet_b0",
         "efficientnet_v2",
+        "efficientnet_v2_kl",
         "maskrcnn_efficientnetb2b",
         "maskrcnn_efficientnetb2b_tile",
         "tv_efficientnet_b3",
         "tv_efficientnet_v2_l",
     ]
+    print(sorted(models))
+    print('--------------')
+    print(sorted(target))
 
     assert sorted(models) == sorted(target)

From bea688c03f4ae6252bad425c59737c878bd78d0f Mon Sep 17 00:00:00 2001
From: Jyc323
Date: Mon, 13 Oct 2025 13:00:09 -0700
Subject: [PATCH 11/13] update list models pattern

---
 library/tests/unit/backend/native/utils/test_api.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/library/tests/unit/backend/native/utils/test_api.py b/library/tests/unit/backend/native/utils/test_api.py
index 68f2d5c996..c185710e6f 100644
--- a/library/tests/unit/backend/native/utils/test_api.py
+++ b/library/tests/unit/backend/native/utils/test_api.py
@@ -36,9 +36,6 @@ def test_list_models_pattern() -> None:
         "tv_efficientnet_b3",
         "tv_efficientnet_v2_l",
     ]
-    print(sorted(models))
-    print('--------------')
-    print(sorted(target))
 
     assert sorted(models) == sorted(target)
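The mock change in [PATCH 09/13] switches the per-head logits ranges from a list of tuples to a dict keyed by the stringified head index, which is how the classifier looks up each head's slice of the concatenated logits. A small self-contained sketch of that mapping for the (3, 3) head sizes used in the test fixture; variable names are illustrative:

head_class_sizes = (3, 3)

start = 0
ranges: dict[str, tuple[int, int]] = {}
for idx, k in enumerate(head_class_sizes):
    ranges[str(idx)] = (start, start + k)  # half-open span of this head in the concatenated logits
    start += k

assert ranges == {"0": (0, 3), "1": (3, 6)}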
From 4741836bba999c9813b7213495031dc1656fb14b Mon Sep 17 00:00:00 2001
From: Jyc323
Date: Thu, 16 Oct 2025 14:34:48 -0700
Subject: [PATCH 12/13] ruff fix

---
 .../native/models/classification/hlabel_models/timm_model.py | 3 +--
 .../classification/classifier/test_kl_hlabel_classifier.py   | 2 +-
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/library/src/otx/backend/native/models/classification/hlabel_models/timm_model.py b/library/src/otx/backend/native/models/classification/hlabel_models/timm_model.py
index a791765012..9d07611be6 100644
--- a/library/src/otx/backend/native/models/classification/hlabel_models/timm_model.py
+++ b/library/src/otx/backend/native/models/classification/hlabel_models/timm_model.py
@@ -13,7 +13,7 @@
 from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable
 from otx.backend.native.models.classification.backbones.timm import TimmBackbone
-from otx.backend.native.models.classification.classifier import HLabelClassifier, KLHLabelClassifier
+from otx.backend.native.models.classification.classifier import HLabelClassifier
 from otx.backend.native.models.classification.heads import HierarchicalLinearClsHead
 from otx.backend.native.models.classification.hlabel_models.base import OTXHlabelClsModel
 from otx.backend.native.models.classification.losses.asymmetric_angular_loss_with_ignore import (
@@ -92,4 +92,3 @@ def _create_model(self, head_config: dict | None = None) -> nn.Module:  # type:
     def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
         """Load the previous OTX ckpt according to OTX2.0."""
         return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "hlabel", add_prefix)
-
diff --git a/library/tests/unit/backend/native/models/classification/classifier/test_kl_hlabel_classifier.py b/library/tests/unit/backend/native/models/classification/classifier/test_kl_hlabel_classifier.py
index db750aa349..31801bc513 100644
--- a/library/tests/unit/backend/native/models/classification/classifier/test_kl_hlabel_classifier.py
+++ b/library/tests/unit/backend/native/models/classification/classifier/test_kl_hlabel_classifier.py
@@ -88,7 +88,7 @@ def test_forward(self, fxt_model_and_inputs):
 
         output = model(images, labels, mode="loss")
         assert isinstance(output, torch.Tensor)
-        
+
     def test_klh_loss_greater_than_hlabel(self, fxt_model_and_inputs):
         """KLHLabelClassifier should have strictly larger loss than HLabelClassifier
         when kl_weight > 0 and there are >= 2 multiclass heads."""

From eae07d713e0a4112213ef146c3f768999ca712e9 Mon Sep 17 00:00:00 2001
From: Jyc323
Date: Thu, 16 Oct 2025 14:50:11 -0700
Subject: [PATCH 13/13] fix ruff errors

---
 library/src/otx/backend/native/utils/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/library/src/otx/backend/native/utils/utils.py b/library/src/otx/backend/native/utils/utils.py
index 593d1f261f..9468723d2e 100644
--- a/library/src/otx/backend/native/utils/utils.py
+++ b/library/src/otx/backend/native/utils/utils.py
@@ -86,7 +86,7 @@ def mock_modules_for_chkpt() -> Iterator[None]:
         sys.modules["otx.core.types.task"] = otx.types.task
         sys.modules["otx.core.types.label"] = otx.types.label
         sys.modules["otx.core.model"] = otx.backend.native.models  # type: ignore[attr-defined]
-        sys.modules["otx.core.metrics"] = otx.metrics
+        # sys.modules["otx.core.metrics"] = otx.metrics
         yield
     finally:
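The mock_modules_for_chkpt change in [PATCH 13/13] touches a sys.modules aliasing block that lets old checkpoints resolve legacy module paths against the current packages. A generic sketch of that pattern, using placeholder module names rather than the exact OTX alias set:

import sys
from collections.abc import Iterator
from contextlib import contextmanager

@contextmanager
def alias_legacy_modules(aliases: dict[str, object]) -> Iterator[None]:
    """Temporarily point legacy module names at current modules so pickled checkpoints load."""
    saved = {name: sys.modules.get(name) for name in aliases}
    try:
        sys.modules.update(aliases)  # e.g. {"legacy.pkg.module": current_module} (placeholder names)
        yield
    finally:
        for name, module in saved.items():
            if module is None:
                sys.modules.pop(name, None)  # alias did not exist before: remove it
            else:
                sys.modules[name] = module   # restore the original module object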