@@ -4,6 +4,6 @@
"""Head modules for OTX custom model."""

from .base_classifier import ImageClassifier
-from .h_label_classifier import HLabelClassifier
+from .h_label_classifier import HLabelClassifier, KLHLabelClassifier

-__all__ = ["ImageClassifier", "HLabelClassifier"]
+__all__ = ["ImageClassifier", "HLabelClassifier", "KLHLabelClassifier"]
@@ -11,6 +11,7 @@
import torch

from otx.backend.native.models.classification.heads.hlabel_cls_head import HierarchicalClsHead
from otx.backend.native.models.classification.losses.tree_path_kl_divergence_loss import TreePathKLDivergenceLoss
from otx.backend.native.models.classification.utils.ignored_labels import get_valid_label_mask

from .base_classifier import ImageClassifier
@@ -143,3 +144,87 @@ def _forward_explain(self, images: torch.Tensor) -> dict[str, torch.Tensor | lis
outputs["preds"] = preds

return outputs


class KLHLabelClassifier(HLabelClassifier):
    """Hierarchical label classifier with tree path KL divergence loss.

    Args:
        backbone (nn.Module): Backbone network.
        neck (nn.Module | None): Neck network.
        head (nn.Module): Head network.
        multiclass_loss (nn.Module): Multiclass loss function.
        multilabel_loss (nn.Module | None, optional): Multilabel loss function.
        init_cfg (dict | list[dict] | None, optional): Initialization configuration.
        kl_weight (float): Loss weight for the tree path KL divergence loss.

    Attributes:
        multiclass_loss (nn.Module): Multiclass loss function.
        multilabel_loss (nn.Module | None): Multilabel loss function.
        is_ignored_label_loss (bool): Flag indicating if ignored label loss is used.
        kl_weight (float): Loss weight for the tree path KL divergence loss.
        kl_loss (TreePathKLDivergenceLoss): Tree path KL divergence loss module.

    Methods:
        loss(inputs, labels, **kwargs): Calculate losses from a batch of inputs and data samples.
    """

    def __init__(self, *args, kl_weight: float = 1.0, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.kl_weight = kl_weight
        self.kl_loss = TreePathKLDivergenceLoss(reduction="batchmean", loss_weight=1.0)

    def loss(self, inputs: torch.Tensor, labels: torch.Tensor, **kwargs) -> torch.Tensor:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            labels (torch.Tensor): The annotation data of
                every sample.

        Returns:
            torch.Tensor: loss components
        """
        cls_scores = self.extract_feat(inputs, stage="head")
        loss_score = torch.tensor(0.0, device=cls_scores.device)
        logits_list = []
        target_list = []
        num_effective_heads_in_batch = 0
        for i in range(self.head.num_multiclass_heads):
            if i not in self.head.empty_multiclass_head_indices:
                head_gt = labels[:, i]
                logit_range = self.head._get_head_idx_to_logits_range(i)  # noqa: SLF001
                head_logits = cls_scores[:, logit_range[0] : logit_range[1]]
                valid_mask = head_gt >= 0
                head_gt = head_gt[valid_mask]
                if len(head_gt) > 0:
                    head_logits = head_logits[valid_mask]
                    logits_list.append(head_logits)
                    target_list.append(head_gt)
                    ce = self.multiclass_loss(head_logits, head_gt)
                    loss_score += ce
                    num_effective_heads_in_batch += 1

        if num_effective_heads_in_batch > 0:
            loss_score /= num_effective_heads_in_batch

        # KL term assumes every retained head kept the same valid samples,
        # so the per-level targets can be stacked into a [B, L] tensor.
        if len(logits_list) > 1:
            kl_loss = self.kl_loss(logits_list, torch.stack(target_list, dim=1))
            loss_score += self.kl_weight * kl_loss

        # Multilabel logic (preserved as-is)
        if self.head.num_multilabel_classes > 0:
            head_gt = labels[:, self.head.num_multiclass_heads :]
            head_logits = cls_scores[:, self.head.num_single_label_classes :]
            valid_mask = head_gt > 0
            head_gt = head_gt[valid_mask]
            if len(head_gt) > 0 and self.multilabel_loss is not None:
                head_logits = head_logits[valid_mask]
                imgs_info = kwargs.pop("imgs_info", None)
                if imgs_info is not None and self.is_ignored_label_loss:
                    valid_label_mask = get_valid_label_mask(imgs_info, self.head.num_classes).to(head_logits.device)
                    valid_label_mask = valid_label_mask[:, self.head.num_single_label_classes :]
                    valid_label_mask = valid_label_mask[valid_mask]
                    kwargs["valid_label_mask"] = valid_label_mask
                loss_score += self.multilabel_loss(head_logits, head_gt, **kwargs)

        return loss_score
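For context, a minimal self-contained sketch of the objective this method computes (illustrative shapes and values, not part of the change): the averaged per-head cross-entropy plus the weighted tree-path KL term, assuming kl_weight = 1.0.

# Illustrative sketch only: a 2-level hierarchy with 3 and 4 classes per level.
import torch
from torch.nn import functional

from otx.backend.native.models.classification.losses.tree_path_kl_divergence_loss import TreePathKLDivergenceLoss

batch, dims = 2, [3, 4]
logits_list = [torch.randn(batch, c) for c in dims]  # per-level head logits
targets = torch.tensor([[0, 2], [1, 3]])  # [B, L] ground-truth index per level

# averaged per-head cross-entropy, as in the multiclass loop above
ce = sum(functional.cross_entropy(lg, targets[:, i]) for i, lg in enumerate(logits_list)) / len(dims)

# weighted tree-path KL term, as in the `len(logits_list) > 1` branch above
kl = TreePathKLDivergenceLoss(reduction="batchmean")(logits_list, targets)
total = ce + 1.0 * kl  # kl_weight assumed to be 1.0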
@@ -7,6 +7,7 @@

from abc import abstractmethod
from copy import deepcopy
from functools import wraps
from typing import TYPE_CHECKING, Any

import torch
@@ -15,6 +16,7 @@
from otx.backend.native.exporter.base import OTXModelExporter
from otx.backend.native.exporter.native import OTXNativeModelExporter
from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel
from otx.backend.native.models.classification.classifier import KLHLabelClassifier
from otx.backend.native.schedulers import LRSchedulerListCallable
from otx.data.entity.base import OTXBatchLossEntity
from otx.data.entity.torch import OTXDataBatch, OTXPredBatch
@@ -59,7 +61,9 @@ def __init__(
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = HLabelClsMetricCallable,
        torch_compile: bool = False,
        **kwargs,
    ) -> None:
        # Do not clobber a kl_weight already set by a subclass before it calls super().__init__().
        self.kl_weight = kwargs.get("kl_weight", getattr(self, "kl_weight", 0.0))
        super().__init__(
            label_info=label_info,
            data_input_params=data_input_params,
@@ -70,16 +74,46 @@
            metric=metric,
            torch_compile=torch_compile,
        )

        if freeze_backbone:
            classification_layers = self._identify_classification_layers()
            for name, param in self.named_parameters():
                param.requires_grad = name in classification_layers

    def __getattribute__(self, name: str):
        # Route any subclass's _create_model() through _finalize_model(); the wrapper is cached per instance.
        attr = super().__getattribute__(name)
        if name == "_create_model" and callable(attr):
            cache_name = "__cm_cached__"
            cache = super().__getattribute__("__dict__").get(cache_name)
            if cache:
                return cache

            @wraps(attr)
            def wrapped(*a, **kw) -> nn.Module:
                model = attr(*a, **kw)
                return self._finalize_model(model)

            self.__dict__[cache_name] = wrapped
            return wrapped
        return attr

    @abstractmethod
    def _create_model(self, head_config: dict | None = None) -> nn.Module:  # type: ignore[override]
        """Create a PyTorch model for this class."""

    def _finalize_model(self, model: nn.Module) -> nn.Module:
        """Run after a child's _create_model(); upgrade to KLHLabelClassifier when kl_weight > 0."""
        if self.kl_weight > 0:
            return KLHLabelClassifier(
                backbone=model.backbone,
                neck=model.neck,
                head=model.head,
                multiclass_loss=model.multiclass_loss,
                multilabel_loss=model.multilabel_loss,
                init_cfg=getattr(model, "init_cfg", None),
                kl_weight=self.kl_weight,
            )
        return model

    def _identify_classification_layers(self, prefix: str = "model.") -> list[str]:
        """Simple identification of the classification layers. Used for incremental learning."""
        # identify classification layers
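A toy, self-contained sketch of the `__getattribute__` interception above (hypothetical Base/Child classes, caching omitted): any subclass's _create_model is transparently routed through _finalize_model, which is how the KL upgrade is applied without touching the subclasses.

from functools import wraps


class Base:
    kl_weight = 1.0

    def __getattribute__(self, name: str):
        attr = super().__getattribute__(name)
        if name == "_create_model" and callable(attr):
            @wraps(attr)
            def wrapped(*a, **kw):
                return self._finalize_model(attr(*a, **kw))
            return wrapped
        return attr

    def _finalize_model(self, model):
        # stand-in for the KLHLabelClassifier upgrade when kl_weight > 0
        return ("kl-upgraded", model) if self.kl_weight > 0 else model


class Child(Base):
    def _create_model(self):
        return "plain-classifier"


assert Child()._create_model() == ("kl-upgraded", "plain-classifier")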
@@ -13,7 +13,7 @@

from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.backend.native.models.classification.backbones.timm import TimmBackbone
-from otx.backend.native.models.classification.classifier import HLabelClassifier
+from otx.backend.native.models.classification.classifier import HLabelClassifier, KLHLabelClassifier
from otx.backend.native.models.classification.heads import HierarchicalLinearClsHead
from otx.backend.native.models.classification.hlabel_models.base import OTXHlabelClsModel
from otx.backend.native.models.classification.losses.asymmetric_angular_loss_with_ignore import (
@@ -58,6 +58,7 @@ def __init__(
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = HLabelClsMetricCallable,
        torch_compile: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(
            label_info=label_info,
@@ -68,6 +69,7 @@
            scheduler=scheduler,
            metric=metric,
            torch_compile=torch_compile,
            **kwargs,
        )

    def _create_model(self, head_config: dict | None = None) -> nn.Module:  # type: ignore[override]
@@ -89,3 +91,66 @@ def _create_model(self, head_config: dict | None = None) -> nn.Module: # type:
    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
        """Load the previous OTX ckpt according to OTX2.0."""
        return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "hlabel", add_prefix)


class TimmModelHLabelClsWithKL(OTXHlabelClsModel):
    """Timm model for hierarchical label classification with tree path KL divergence loss.

    Args:
        label_info (HLabelInfo): The label information for the classification task.
        data_input_params (DataInputParams): Parameters for the data input, such as the input size.
        model_name (str): The name of the model.
            You can find available models at timm.list_models() or timm.list_pretrained().
        freeze_backbone (bool, optional): Whether to freeze the backbone so that only the
            classification layers are trained. Defaults to False.
        optimizer (OptimizerCallable, optional): The optimizer callable for training the model.
        scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable.
        metric (MetricCallable, optional): The metric callable for evaluating the model.
            Defaults to HLabelClsMetricCallable.
        torch_compile (bool, optional): Whether to compile the model using TorchScript. Defaults to False.
        kl_weight (float, optional): Loss weight for the tree path KL divergence term. Defaults to 1.0.
    """

    def __init__(
        self,
        label_info: HLabelInfo,
        data_input_params: DataInputParams,
        model_name: str = "tf_efficientnetv2_s.in21k",
        freeze_backbone: bool = False,
        optimizer: OptimizerCallable = DefaultOptimizerCallable,
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = HLabelClsMetricCallable,
        torch_compile: bool = False,
        kl_weight: float = 1.0,
    ) -> None:
        self.kl_weight = float(kl_weight)
        super().__init__(
            label_info=label_info,
            data_input_params=data_input_params,
            model_name=model_name,
            freeze_backbone=freeze_backbone,
            optimizer=optimizer,
            scheduler=scheduler,
            metric=metric,
            torch_compile=torch_compile,
        )

    def _create_model(self, head_config: dict | None = None) -> nn.Module:  # type: ignore[override]
        head_config = head_config if head_config is not None else self.label_info.as_head_config_dict()
        backbone = TimmBackbone(model_name=self.model_name)
        copied_head_config = copy(head_config)
        copied_head_config["step_size"] = (
            ceil(self.data_input_params.input_size[0] / 32),
            ceil(self.data_input_params.input_size[1] / 32),
        )
        return KLHLabelClassifier(
            backbone=backbone,
            neck=GlobalAveragePooling(dim=2),
            head=HierarchicalLinearClsHead(**copied_head_config, in_channels=backbone.num_features),
            multiclass_loss=nn.CrossEntropyLoss(),
            multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
            kl_weight=self.kl_weight,
        )

    def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
        """Load the previous OTX ckpt according to OTX2.0."""
        return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "hlabel", add_prefix)
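A hedged usage sketch: hlabel_info and input_params stand for already-built HLabelInfo and DataInputParams objects (their construction is outside this diff); the existing model class in this file can opt in to the same behaviour by passing kl_weight through **kwargs.

model = TimmModelHLabelClsWithKL(
    label_info=hlabel_info,           # assumed pre-built HLabelInfo
    data_input_params=input_params,   # assumed pre-built DataInputParams
    model_name="tf_efficientnetv2_s.in21k",
    kl_weight=0.5,                    # strength of the tree-path KL term
)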
@@ -0,0 +1,52 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Module for defining TreePathKLDivergenceLoss."""

from __future__ import annotations

import torch
from torch import nn
from torch.nn import functional


class TreePathKLDivergenceLoss(nn.Module):
    """KL divergence between the model distribution over concatenated heads and a target distribution.

    Inputs:
        logits_list: list of tensors [B, C_l], ordered from root -> leaf
        targets: LongTensor [B, L] with per-level GT indices (L == len(logits_list))

    The target distribution places 1/L probability on the GT index for each level,
    and 0 elsewhere, then uses KLDivLoss(log_softmax(logits), target_probs).
    """

    def __init__(self, reduction: str | None = "batchmean", loss_weight: float = 1.0):
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.kl_div = nn.KLDivLoss(reduction=self.reduction)

    def forward(self, logits_list: list[torch.Tensor], targets: torch.Tensor) -> torch.Tensor:
        """Calculate tree_path KL Divergence loss."""
        if not (isinstance(logits_list, (list, tuple)) and len(logits_list) > 0):
            msg = "logits_list must be non-empty"
            raise ValueError(msg)
        num_levels = len(logits_list)

        # concat logits across all levels
        dims = [t.size(1) for t in logits_list]
        logits_concat = torch.cat(logits_list, dim=1)  # [B, sum(C_l)]
        log_probs = functional.log_softmax(logits_concat, dim=1)  # [B, sum(C_l)]

        # build sparse target distribution with 1/L at each GT index
        batch = log_probs.size(0)
        tgt = torch.zeros_like(log_probs)  # [B, sum(C_l)]
        offset = 0
        for num_c, tgt_l in zip(dims, targets.T):  # level-by-level
            idx_rows = torch.arange(batch, device=log_probs.device)
            tgt[idx_rows, offset + tgt_l] = 1.0 / num_levels
            offset += num_c

        kl = self.kl_div(log_probs, tgt)
        return self.loss_weight * kl
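A quick numeric check of the target construction (with the class above in scope; values are illustrative): two levels with 2 and 3 classes, so each ground-truth index receives probability 1/L = 0.5.

import torch

loss_fn = TreePathKLDivergenceLoss(reduction="batchmean", loss_weight=1.0)
logits_list = [torch.zeros(1, 2), torch.zeros(1, 3)]  # uniform model distribution over 5 classes
targets = torch.tensor([[1, 0]])  # level-0 GT index 1, level-1 GT index 0

# target distribution over the concatenated logits: [0.0, 0.5, 0.5, 0.0, 0.0]
# model distribution (softmax of zeros):            [0.2, 0.2, 0.2, 0.2, 0.2]
# KL(batchmean) = 2 * 0.5 * ln(0.5 / 0.2) ≈ 0.916
print(loss_fn(logits_list, targets))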