@@ -0,0 +1,135 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
from torch.nn import functional

from otx.backend.native.models.classification.losses.tree_path_kl_divergence_loss import TreePathKLDivergenceLoss


@pytest.mark.parametrize(
("levels", "classes_per_level"),
[
(2, [3, 5]),
(3, [2, 3, 4]),
],
)
def test_forward_scalar_and_finite(levels, classes_per_level):
torch.manual_seed(0)
batch = 4
logits_list = [torch.randn(batch, c) for c in classes_per_level]
targets = torch.stack([torch.randint(0, c, (batch,)) for c in classes_per_level], dim=1)

loss_fn = TreePathKLDivergenceLoss(reduction="batchmean")
loss = loss_fn(logits_list, targets)
assert loss.ndim == 0
assert torch.isfinite(loss)
assert loss.item() >= -1e-7


def test_backward_produces_grads():
batch = 3
channel = [4, 6]
logits_list = [torch.randn(batch, c, requires_grad=True) for c in channel]
targets = torch.stack([torch.randint(0, c, (batch,)) for c in channel], dim=1)

loss = TreePathKLDivergenceLoss()(logits_list, targets)
loss.backward()
for logit in logits_list:
assert logit.grad is not None
assert torch.isfinite(logit.grad).all()


def test_alignment_vs_misalignment_loss():
batch = 2
channel0, channel1 = 3, 4
targets = torch.tensor([[0, 1], [2, 3]])

# Aligned: boost GT logits
aligned0 = torch.zeros(batch, channel0)
aligned1 = torch.zeros(batch, channel1)
aligned0[torch.arange(batch), targets[:, 0]] = 5.0
aligned1[torch.arange(batch), targets[:, 1]] = 5.0

# Misaligned: boost wrong logits
mis0 = torch.zeros(batch, channel0)
mis1 = torch.zeros(batch, channel1)
mis0[torch.arange(batch), (targets[:, 0] + 1) % channel0] = 5.0
mis1[torch.arange(batch), (targets[:, 1] + 1) % channel1] = 5.0

loss_fn = TreePathKLDivergenceLoss()
loss_aligned = loss_fn([aligned0, aligned1], targets)
loss_misaligned = loss_fn([mis0, mis1], targets)
assert loss_aligned < loss_misaligned


def test_single_level_exact_value():
"""
With a single level, the KL loss reduces to cross-entropy between the predicted softmax and the one-hot target.
We check the exact value against functional.cross_entropy.
"""

logits = torch.tensor([[2.0, 0.0, -1.0], [0.5, 1.0, -0.5]])
targets = torch.tensor([[0], [2]]) # shape [B,1]

# TreePathKLDivergenceLoss
loss_fn = TreePathKLDivergenceLoss(reduction="batchmean")
kl_loss = loss_fn([logits], targets)

# Cross-entropy with a one-hot target is the same as NLLLoss(log_softmax)
ce_loss = functional.cross_entropy(logits, targets.view(-1), reduction="mean")

assert torch.allclose(kl_loss, ce_loss, atol=1e-6)


def test_multi_level_exact_value_batchmean():
"""
Exact numerical check for L=2 levels with 'batchmean' reduction.

Loss per sample (PyTorch KLDivLoss):
KL(p || q) = sum_j p_j * (log(p_j) - log(q_j))
where input to KLDivLoss is log(q_j) (our model log_probs),
and the target is p_j (our constructed target distribution).
With reduction='batchmean', PyTorch divides the total sum by batch size.
"""

# Use double for better numerical agreement
batch = 2
l0, l1 = 2, 3
logits0 = torch.tensor([[2.0, -1.0], [0.0, 1.0]], dtype=torch.float64) # [B, l0]
logits1 = torch.tensor([[0.5, 0.0, -0.5], [-1.0, 2.0, 0.5]], dtype=torch.float64) # [B, l1]

# Ground-truth indices per level
# sample 0: level0->0, level1->1
# sample 1: level0->1, level1->2
targets = torch.tensor([[0, 1], [1, 2]], dtype=torch.long) # [B, 2]
level = 2 # number of levels

# Model log probs over concatenated heads
concat = torch.cat([logits0, logits1], dim=1) # [B, l0+l1]
log_q = functional.log_softmax(concat, dim=1) # log(q_j)

# Build target distribution p: 1/level at each GT index, 0 elsewhere
p = torch.zeros_like(log_q, dtype=torch.float64)
offset = 0
for num_c, tgt_l in zip([l0, l1], targets.T):
rows = torch.arange(batch)
p[rows, offset + tgt_l] = 1.0 / level
offset += num_c

# Manual KL with 'batchmean' reduction:
# sum_i sum_j p_ij * (log p_ij - log q_ij) / batch
# (avoid log(0) by masking since p is sparse)
mask = p > 0
log_p = torch.zeros_like(p)
log_p[mask] = torch.log(p[mask])
manual_kl = (p * (log_p - log_q)).sum() / batch

# Loss under test (must match the manual value); keep double inputs so atol=1e-8 is meaningful
loss_fn = TreePathKLDivergenceLoss(reduction="batchmean")
test_kl = loss_fn([logits0, logits1], targets)

assert torch.allclose(
test_kl.double(), manual_kl, atol=1e-8
), f"manual={manual_kl.item():.12f} vs loss={test_kl.item():.12f}"
@@ -4,6 +4,6 @@
"""Head modules for OTX custom model."""

from .base_classifier import ImageClassifier
from .h_label_classifier import HLabelClassifier
from .h_label_classifier import HLabelClassifier, KLHLabelClassifier

__all__ = ["ImageClassifier", "HLabelClassifier"]
__all__ = ["ImageClassifier", "HLabelClassifier", "KLHLabelClassifier"]
@@ -11,6 +11,7 @@
import torch

from otx.backend.native.models.classification.heads.hlabel_cls_head import HierarchicalClsHead
from otx.backend.native.models.classification.losses.tree_path_kl_divergence_loss import TreePathKLDivergenceLoss
from otx.backend.native.models.classification.utils.ignored_labels import get_valid_label_mask

from .base_classifier import ImageClassifier
@@ -143,3 +144,87 @@ def _forward_explain(self, images: torch.Tensor) -> dict[str, torch.Tensor | lis
outputs["preds"] = preds

return outputs


class KLHLabelClassifier(HLabelClassifier):
"""Hierarchical label classifier with tree path KL divergence loss.

Args:
backbone (nn.Module): Backbone network.
neck (nn.Module | None): Neck network.
head (nn.Module): Head network.
multiclass_loss (nn.Module): Multiclass loss function.
multilabel_loss (nn.Module | None, optional): Multilabel loss function.
init_cfg (dict | list[dict] | None, optional): Initialization configuration.
kl_weight (float): Loss weight for the tree path KL divergence loss.

Attributes:
multiclass_loss (nn.Module): Multiclass loss function.
multilabel_loss (nn.Module | None): Multilabel loss function.
is_ignored_label_loss (bool): Flag indicating if ignored label loss is used.
kl_weight (float): Weight applied to the tree path KL divergence term.
kl_loss (TreePathKLDivergenceLoss): Tree path KL divergence loss module.

Methods:
loss(inputs, labels, **kwargs): Calculate losses from a batch of inputs and data samples.
"""

def __init__(self, *args, kl_weight: float = 1.0, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.kl_weight = kl_weight
self.kl_loss = TreePathKLDivergenceLoss(reduction="batchmean", loss_weight=1.0)

def loss(self, inputs: torch.Tensor, labels: torch.Tensor, **kwargs) -> torch.Tensor:
"""Calculate losses from a batch of inputs and data samples.

Args:
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
labels (torch.Tensor): The annotation data of
every sample.

Returns:
torch.Tensor: The combined loss value.
"""
cls_scores = self.extract_feat(inputs, stage="head")
loss_score = torch.tensor(0.0, device=cls_scores.device)
logits_list = []
target_list = []
num_effective_heads_in_batch = 0
for i in range(self.head.num_multiclass_heads):
if i not in self.head.empty_multiclass_head_indices:
head_gt = labels[:, i]
logit_range = self.head._get_head_idx_to_logits_range(i) # noqa: SLF001
head_logits = cls_scores[:, logit_range[0] : logit_range[1]]
valid_mask = head_gt >= 0
head_gt = head_gt[valid_mask]
if len(head_gt) > 0:
head_logits = head_logits[valid_mask]
logits_list.append(head_logits)
target_list.append(head_gt)
ce = self.multiclass_loss(head_logits, head_gt)
loss_score += ce
num_effective_heads_in_batch += 1

if num_effective_heads_in_batch > 0:
loss_score /= num_effective_heads_in_batch

if len(logits_list) > 1:
kl_loss = self.kl_loss(logits_list, torch.stack(target_list, dim=1))
loss_score += self.kl_weight * kl_loss

# Multilabel logic (preserved as-is)
if self.head.num_multilabel_classes > 0:
head_gt = labels[:, self.head.num_multiclass_heads :]
head_logits = cls_scores[:, self.head.num_single_label_classes :]
valid_mask = head_gt > 0
head_gt = head_gt[valid_mask]
if len(head_gt) > 0 and self.multilabel_loss is not None:
head_logits = head_logits[valid_mask]
imgs_info = kwargs.pop("imgs_info", None)
if imgs_info is not None and self.is_ignored_label_loss:
valid_label_mask = get_valid_label_mask(imgs_info, self.head.num_classes).to(head_logits.device)
valid_label_mask = valid_label_mask[:, self.head.num_single_label_classes :]
valid_label_mask = valid_label_mask[valid_mask]
kwargs["valid_label_mask"] = valid_label_mask
loss_score += self.multilabel_loss(head_logits, head_gt, **kwargs)

return loss_score
@@ -13,7 +13,7 @@

from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.backend.native.models.classification.backbones.timm import TimmBackbone
from otx.backend.native.models.classification.classifier import HLabelClassifier
from otx.backend.native.models.classification.classifier import HLabelClassifier, KLHLabelClassifier
from otx.backend.native.models.classification.heads import HierarchicalLinearClsHead
from otx.backend.native.models.classification.hlabel_models.base import OTXHlabelClsModel
from otx.backend.native.models.classification.losses.asymmetric_angular_loss_with_ignore import (
@@ -89,3 +89,66 @@ def _create_model(self, head_config: dict | None = None) -> nn.Module: # type:
def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "hlabel", add_prefix)


class TimmModelHLabelClsWithKL(OTXHlabelClsModel):
"""Timm Model for hierarchical label classification task.

Args:
label_info (HLabelInfo): The label information for the classification task.
model_name (str): The name of the model.
You can find available models at timm.list_models() or timm.list_pretrained().
input_size (tuple[int, int], optional): Model input size in the order of height and width.
Defaults to (224, 224).
pretrained (bool, optional): Whether to load pretrained weights. Defaults to True.
optimizer (OptimizerCallable, optional): The optimizer callable for training the model.
scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable.
metric (MetricCallable, optional): The metric callable for evaluating the model.
Defaults to HLabelClsMetricCallable.
torch_compile (bool, optional): Whether to compile the model using TorchScript. Defaults to False.
"""

def __init__(
self,
label_info: HLabelInfo,
data_input_params: DataInputParams,
model_name: str = "tf_efficientnetv2_s.in21k",
freeze_backbone: bool = False,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = HLabelClsMetricCallable,
torch_compile: bool = False,
kl_weight: float = 1.0,
) -> None:
self.kl_weight = float(kl_weight)
super().__init__(
label_info=label_info,
data_input_params=data_input_params,
model_name=model_name,
freeze_backbone=freeze_backbone,
optimizer=optimizer,
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
)

def _create_model(self, head_config: dict | None = None) -> nn.Module: # type: ignore[override]
head_config = head_config if head_config is not None else self.label_info.as_head_config_dict()
backbone = TimmBackbone(model_name=self.model_name)
copied_head_config = copy(head_config)
copied_head_config["step_size"] = (
ceil(self.data_input_params.input_size[0] / 32),
ceil(self.data_input_params.input_size[1] / 32),
)
return KLHLabelClassifier(
backbone=backbone,
neck=GlobalAveragePooling(dim=2),
head=HierarchicalLinearClsHead(**copied_head_config, in_channels=backbone.num_features),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
kl_weight=self.kl_weight,
)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "hlabel", add_prefix)
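
For orientation, a minimal instantiation sketch (not part of this diff): only the keyword names shown in __init__ above come from the source; hlabel_info and data_params are hypothetical objects assumed to be built elsewhere, e.g. from the dataset.

# Hypothetical usage sketch; hlabel_info and data_params are assumed to exist already.
model = TimmModelHLabelClsWithKL(
    label_info=hlabel_info,          # HLabelInfo describing the label hierarchy
    data_input_params=data_params,   # DataInputParams (provides input_size)
    model_name="tf_efficientnetv2_s.in21k",
    kl_weight=0.5,                   # scales the tree path KL term added to the CE loss
)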
@@ -0,0 +1,52 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Module for defining TreePathKLDivergenceLoss."""

from __future__ import annotations

import torch
from torch import nn
from torch.nn import functional


class TreePathKLDivergenceLoss(nn.Module):
"""KL divergence between model distribution over concatenated heads and a target distribution.

Inputs:
logits_list: list of tensors [B, C_l], ordered from root -> leaf
targets: LongTensor [B, L] with per-level GT indices (L == len(logits_list))

The target distribution places 1/L probability on the GT index for each level,
and 0 elsewhere, then uses KLDivLoss(log_softmax(logits), target_probs).
"""

def __init__(self, reduction: str = "batchmean", loss_weight: float = 1.0) -> None:
super().__init__()
self.reduction = reduction
self.loss_weight = loss_weight
self.kl_div = nn.KLDivLoss(reduction=self.reduction)

def forward(self, logits_list: list[torch.Tensor], targets: torch.Tensor) -> torch.Tensor:
"""Calculate tree_path KL Divergence loss."""
if not (isinstance(logits_list, (list, tuple)) and len(logits_list) > 0):
msg = "logits_list must be non-empty"
raise ValueError(msg)
num_levels = len(logits_list)

# concat logits across all levels
dims = [t.size(1) for t in logits_list]
logits_concat = torch.cat(logits_list, dim=1) # [B, sum(C_l)]
log_probs = functional.log_softmax(logits_concat, dim=1) # [B, sum(C_l)]

# build sparse target distribution with 1/L at each GT index
batch = log_probs.size(0)
tgt = torch.zeros_like(log_probs) # [B, sum(C_l)]
offset = 0
for num_c, tgt_l in zip(dims, targets.T): # level-by-level
idx_rows = torch.arange(batch, device=log_probs.device)
tgt[idx_rows, offset + tgt_l] = 1.0 / num_levels
offset += num_c

kl = self.kl_div(log_probs, tgt)
return self.loss_weight * kl
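
A minimal usage sketch (illustrative, not part of the diff), following the shapes in the class docstring: two levels with 3 and 5 classes, with per-level ground-truth indices stacked into a [B, L] tensor.

# Usage sketch: batch of 4, two levels with 3 and 5 classes.
loss_fn = TreePathKLDivergenceLoss(reduction="batchmean", loss_weight=1.0)
logits = [torch.randn(4, 3), torch.randn(4, 5)]  # root -> leaf head logits
targets = torch.stack(
    [torch.randint(0, 3, (4,)), torch.randint(0, 5, (4,))],
    dim=1,
)  # [B, L] ground-truth index per level
loss = loss_fn(logits, targets)  # non-negative scalar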