library/src/otx/backend/native/models/classification/classifier/__init__.py
@@ -4,6 +4,6 @@
"""Head modules for OTX custom model."""

from .base_classifier import ImageClassifier
from .h_label_classifier import HLabelClassifier
from .h_label_classifier import HLabelClassifier, KLHLabelClassifier

__all__ = ["ImageClassifier", "HLabelClassifier"]
__all__ = ["ImageClassifier", "HLabelClassifier", "KLHLabelClassifier"]
library/src/otx/backend/native/models/classification/classifier/h_label_classifier.py
@@ -11,6 +11,7 @@
import torch

from otx.backend.native.models.classification.heads.hlabel_cls_head import HierarchicalClsHead
from otx.backend.native.models.classification.losses.tree_path_kl_divergence_loss import TreePathKLDivergenceLoss
from otx.backend.native.models.classification.utils.ignored_labels import get_valid_label_mask

from .base_classifier import ImageClassifier
@@ -143,3 +144,87 @@ def _forward_explain(self, images: torch.Tensor) -> dict[str, torch.Tensor | list[torch.Tensor]]:
        outputs["preds"] = preds

        return outputs


class KLHLabelClassifier(HLabelClassifier):
    """Hierarchical label classifier with tree-path KL divergence loss.

    Args:
        backbone (nn.Module): Backbone network.
        neck (nn.Module | None): Neck network.
        head (nn.Module): Head network.
        multiclass_loss (nn.Module): Multiclass loss function.
        multilabel_loss (nn.Module | None, optional): Multilabel loss function.
        init_cfg (dict | list[dict] | None, optional): Initialization configuration.
        kl_weight (float): Loss weight for the tree-path KL divergence loss.

    Attributes:
        multiclass_loss (nn.Module): Multiclass loss function.
        multilabel_loss (nn.Module | None): Multilabel loss function.
        is_ignored_label_loss (bool): Flag indicating if ignored label loss is used.

    Methods:
        loss(inputs, labels, **kwargs): Calculate losses from a batch of inputs and data samples.
    """

    def __init__(self, *args, kl_weight: float = 1.0, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.kl_weight = kl_weight
        self.kl_loss = TreePathKLDivergenceLoss(reduction="batchmean", loss_weight=1.0)

    def loss(self, inputs: torch.Tensor, labels: torch.Tensor, **kwargs) -> torch.Tensor:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            labels (torch.Tensor): The annotation data of
                every sample.

        Returns:
            torch.Tensor: The aggregated loss.
        """
        cls_scores = self.extract_feat(inputs, stage="head")
        loss_score = torch.tensor(0.0, device=cls_scores.device)
        logits_list = []
        target_list = []
        num_effective_heads_in_batch = 0
        for i in range(self.head.num_multiclass_heads):
            if i not in self.head.empty_multiclass_head_indices:
                head_gt = labels[:, i]
                logit_range = self.head._get_head_idx_to_logits_range(i)  # noqa: SLF001
                head_logits = cls_scores[:, logit_range[0] : logit_range[1]]
                valid_mask = head_gt >= 0
                head_gt = head_gt[valid_mask]
                if len(head_gt) > 0:
                    head_logits = head_logits[valid_mask]
                    logits_list.append(head_logits)
                    target_list.append(head_gt)
                    ce = self.multiclass_loss(head_logits, head_gt)
                    loss_score += ce
                    num_effective_heads_in_batch += 1

        if num_effective_heads_in_batch > 0:
            loss_score /= num_effective_heads_in_batch

        if len(logits_list) > 1:
            kl_loss = self.kl_loss(logits_list, torch.stack(target_list, dim=1))
            loss_score += self.kl_weight * kl_loss

        # Multilabel logic (preserved as-is)
        if self.head.num_multilabel_classes > 0:
            head_gt = labels[:, self.head.num_multiclass_heads :]
            head_logits = cls_scores[:, self.head.num_single_label_classes :]
            valid_mask = head_gt > 0
            head_gt = head_gt[valid_mask]
            if len(head_gt) > 0 and self.multilabel_loss is not None:
                head_logits = head_logits[valid_mask]
                imgs_info = kwargs.pop("imgs_info", None)
                if imgs_info is not None and self.is_ignored_label_loss:
                    valid_label_mask = get_valid_label_mask(imgs_info, self.head.num_classes).to(head_logits.device)
                    valid_label_mask = valid_label_mask[:, self.head.num_single_label_classes :]
                    valid_label_mask = valid_label_mask[valid_mask]
                    kwargs["valid_label_mask"] = valid_label_mask
                loss_score += self.multilabel_loss(head_logits, head_gt, **kwargs)

        return loss_score
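
For a concrete picture of the composition implemented by `loss()` above (average per-head cross-entropy plus `kl_weight` times the tree-path KL term), here is a minimal sketch. The shapes and label values are made up, and the import path is the one added in this diff, so it only resolves on this branch.

```python
import torch
import torch.nn.functional as F

# Import path as added in this diff; requires this branch of OTX to be installed.
from otx.backend.native.models.classification.losses.tree_path_kl_divergence_loss import (
    TreePathKLDivergenceLoss,
)

torch.manual_seed(0)
# Hypothetical toy batch: 4 samples, two multiclass heads with 3 and 5 classes.
head_logits = [torch.randn(4, 3), torch.randn(4, 5)]
labels = torch.tensor([[0, 2], [1, 4], [2, 0], [0, 3]])  # per-level GT indices; -1 would mean "no label"

# Per-head cross-entropy, averaged over the heads that have valid labels (mirrors the loop above).
ce_terms = []
for logits, gt in zip(head_logits, labels.T):
    valid = gt >= 0
    if valid.any():
        ce_terms.append(F.cross_entropy(logits[valid], gt[valid]))
loss = torch.stack(ce_terms).mean()

# Add the tree-path KL term over the concatenated head logits, scaled by kl_weight.
kl_weight = 1.0  # default of KLHLabelClassifier; the recipe in this PR sets 2.0
loss = loss + kl_weight * TreePathKLDivergenceLoss(reduction="batchmean")(head_logits, labels)
```
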
@@ -7,6 +7,7 @@

from abc import abstractmethod
from copy import deepcopy
from functools import wraps
from typing import TYPE_CHECKING, Any

import torch
@@ -15,6 +16,7 @@
from otx.backend.native.exporter.base import OTXModelExporter
from otx.backend.native.exporter.native import OTXNativeModelExporter
from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel
from otx.backend.native.models.classification.classifier import KLHLabelClassifier
from otx.backend.native.schedulers import LRSchedulerListCallable
from otx.data.entity.base import OTXBatchLossEntity
from otx.data.entity.torch import OTXDataBatch, OTXPredBatch
@@ -45,6 +47,7 @@ class OTXHlabelClsModel(OTXModel):
            Defaults to DefaultSchedulerCallable.
        metric (MetricCallable, optional): Callable for the metric. Defaults to HLabelClsMetricCallable.
        torch_compile (bool, optional): Flag to indicate whether to use torch.compile. Defaults to False.
        kl_weight (float, optional): Weight of the tree-path KL divergence loss. Defaults to 0.0 (cross-entropy only).
    """

    label_info: HLabelInfo
@@ -59,7 +62,9 @@ def __init__(
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = HLabelClsMetricCallable,
        torch_compile: bool = False,
        kl_weight: float = 0.0,
    ) -> None:
        self.kl_weight = kl_weight
        super().__init__(
            label_info=label_info,
            data_input_params=data_input_params,
@@ -70,16 +75,46 @@ def __init__(
            metric=metric,
            torch_compile=torch_compile,
        )

        if freeze_backbone:
            classification_layers = self._identify_classification_layers()
            for name, param in self.named_parameters():
                param.requires_grad = name in classification_layers

    def __getattribute__(self, name: str):
        attr = super().__getattribute__(name)
        if name == "_create_model" and callable(attr):
            # Wrap _create_model once so that whatever the subclass builds is passed
            # through _finalize_model; the wrapper is cached on the instance.
            cache_name = "__cm_cached__"
            cache = super().__getattribute__("__dict__").get(cache_name)
            if cache:
                return cache

            @wraps(attr)
            def wrapped(*a, **kw) -> nn.Module:
                model = attr(*a, **kw)
                return self._finalize_model(model)

            self.__dict__[cache_name] = wrapped
            return wrapped
        return attr

    @abstractmethod
    def _create_model(self, head_config: dict | None = None) -> nn.Module:  # type: ignore[override]
        """Create a PyTorch model for this class."""

    def _finalize_model(self, model: nn.Module) -> nn.Module:
        """Run after the child's _create_model(); upgrade to KLHLabelClassifier when kl_weight > 0."""
        if self.kl_weight > 0:
            return KLHLabelClassifier(
                backbone=model.backbone,
                neck=model.neck,
                head=model.head,
                multiclass_loss=model.multiclass_loss,
                multilabel_loss=model.multilabel_loss,
                init_cfg=getattr(model, "init_cfg", None),
                kl_weight=self.kl_weight,
            )
        return model

    def _identify_classification_layers(self, prefix: str = "model.") -> list[str]:
        """Simple identification of the classification layers. Used for incremental learning."""
        # identify classification layers
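
The `__getattribute__` override above intercepts access to `_create_model` so that whatever a subclass builds is routed through `_finalize_model` (and rebuilt as a `KLHLabelClassifier` when `kl_weight > 0`), with the wrapper cached on the instance. Below is a stripped-down, standalone sketch of the same interception pattern; the class and method names are illustrative and not part of OTX, and the caching is omitted for brevity.

```python
from functools import wraps


class Base:
    """Intercept a factory method so its result is always post-processed."""

    def __getattribute__(self, name: str):
        attr = super().__getattribute__(name)
        if name == "create" and callable(attr):

            @wraps(attr)
            def wrapped(*args, **kwargs):
                # Post-process whatever the subclass's create() returned.
                return self.finalize(attr(*args, **kwargs))

            return wrapped
        return attr

    def finalize(self, obj: str) -> str:
        return f"finalized({obj})"


class Child(Base):
    def create(self) -> str:
        return "raw"


print(Child().create())  # -> finalized(raw)
```

Compared with asking every subclass to call `_finalize_model` explicitly, this keeps existing `_create_model` implementations untouched.
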
library/src/otx/backend/native/models/classification/hlabel_models/timm_model.py
@@ -46,6 +46,7 @@ class TimmModelHLabelCls(OTXHlabelClsModel):
        metric (MetricCallable, optional): The metric callable for evaluating the model.
            Defaults to HLabelClsMetricCallable.
        torch_compile (bool, optional): Whether to compile the model using TorchScript. Defaults to False.
        kl_weight (float, optional): Weight of the tree-path KL divergence loss. Defaults to 0.0 (cross-entropy only).
    """

    def __init__(
@@ -58,6 +59,7 @@ def __init__(
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = HLabelClsMetricCallable,
        torch_compile: bool = False,
        kl_weight: float = 0.0,
    ) -> None:
        super().__init__(
            label_info=label_info,
@@ -68,6 +70,7 @@ def __init__(
            scheduler=scheduler,
            metric=metric,
            torch_compile=torch_compile,
            kl_weight=kl_weight,
        )

    def _create_model(self, head_config: dict | None = None) -> nn.Module:  # type: ignore[override]
library/src/otx/backend/native/models/classification/losses/tree_path_kl_divergence_loss.py
@@ -0,0 +1,52 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Module for defining TreePathKLDivergenceLoss."""

from __future__ import annotations

import torch
from torch import nn
from torch.nn import functional


class TreePathKLDivergenceLoss(nn.Module):
"""KL divergence between model distribution over concatenated heads and a target distribution.

Inputs:
logits_list: list of tensors [B, C_l], ordered from root -> leaf
targets: LongTensor [B, L] with per-level GT indices (L == len(logits_list))

The target distribution places 1/L probability on the GT index for each level,
and 0 elsewhere, then uses KLDivLoss(log_softmax(logits), target_probs).
"""

def __init__(self, reduction: str | None = "batchmean", loss_weight: float = 1.0):
super().__init__()
self.reduction = reduction
self.loss_weight = loss_weight
self.kl_div = nn.KLDivLoss(reduction=self.reduction)

def forward(self, logits_list: list[torch.Tensor], targets: torch.Tensor) -> torch.Tensor:
"""Calculate tree_path KL Divergence loss."""
if not (isinstance(logits_list, (list, tuple)) and len(logits_list) > 0):
msg = "logits_list must be non-empty"
raise ValueError(msg)
num_levels = len(logits_list)

# concat logits across all levels
dims = [t.size(1) for t in logits_list]
logits_concat = torch.cat(logits_list, dim=1) # [B, sum(C_l)]
log_probs = functional.log_softmax(logits_concat, dim=1) # [B, sum(C_l)]

# build sparse target distribution with 1/L at each GT index
batch = log_probs.size(0)
tgt = torch.zeros_like(log_probs) # [B, sum(C_l)]
offset = 0
for num_c, tgt_l in zip(dims, targets.T): # level-by-level
idx_rows = torch.arange(batch, device=log_probs.device)
tgt[idx_rows, offset + tgt_l] = 1.0 / num_levels
offset += num_c

kl = self.kl_div(log_probs, tgt)
return self.loss_weight * kl
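
To make the target construction concrete, here is a small worked example with made-up values. For two levels with 3 and 5 classes, the 1/L mass (0.5 here) lands on the ground-truth index of each level inside the concatenated 8-way distribution.

```python
import torch
from torch.nn import functional as F

# Toy values: two levels (root -> leaf) with 3 and 5 classes, batch of 2,
# ground-truth paths (1, 4) and (0, 2).
logits_list = [torch.randn(2, 3), torch.randn(2, 5)]
targets = torch.tensor([[1, 4], [0, 2]])

log_probs = F.log_softmax(torch.cat(logits_list, dim=1), dim=1)  # [2, 8]
tgt = torch.zeros_like(log_probs)
offset = 0
for num_c, tgt_l in zip([3, 5], targets.T):
    tgt[torch.arange(2), offset + tgt_l] = 0.5  # 1 / num_levels
    offset += num_c
# First row of tgt: [0, 0.5, 0 | 0, 0, 0, 0, 0.5], i.e. equal mass along the tree path.

kl = F.kl_div(log_probs, tgt, reduction="batchmean")
```
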
2 changes: 1 addition & 1 deletion library/src/otx/backend/native/utils/utils.py
@@ -86,7 +86,7 @@ def mock_modules_for_chkpt() -> Iterator[None]:
        sys.modules["otx.core.types.task"] = otx.types.task
        sys.modules["otx.core.types.label"] = otx.types.label
        sys.modules["otx.core.model"] = otx.backend.native.models  # type: ignore[attr-defined]
        sys.modules["otx.core.metrics"] = otx.metrics
        # sys.modules["otx.core.metrics"] = otx.metrics

        yield
    finally:
@@ -0,0 +1,125 @@
task: H_LABEL_CLS
model:
  class_path: otx.backend.native.models.classification.hlabel_models.timm_model.TimmModelHLabelCls
  init_args:
    kl_weight: 2.0
    model_name: tf_efficientnetv2_s.in21k

optimizer:
  class_path: torch.optim.SGD
  init_args:
    lr: 0.0071
    momentum: 0.9
    weight_decay: 0.0001

scheduler:
  class_path: otx.backend.native.schedulers.LinearWarmupSchedulerCallable
  init_args:
    num_warmup_steps: 0
    main_scheduler_callable:
      class_path: lightning.pytorch.cli.ReduceLROnPlateau
      init_args:
        mode: max
        factor: 0.5
        patience: 3
        monitor: val/accuracy

engine:
  device: auto

callback_monitor: val/accuracy

data: ../../_base_/data/classification.yaml

callbacks:
  - class_path: otx.backend.native.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
    init_args:
      warmup_iters: 750
      patience: 5
      mode: max
      monitor: val/accuracy
  - class_path: lightning.pytorch.callbacks.ModelCheckpoint
    init_args:
      dirpath: "" # use engine.work_dir
      monitor: val/accuracy
      mode: max
      save_top_k: 1
      save_last: true
      auto_insert_metric_name: false
      filename: "checkpoints/epoch_{epoch:03d}"

overrides:
  reset:
    - data.train_subset.transforms

  max_epochs: 90

  data:
    task: H_LABEL_CLS
    data_format: datumaro
    train_subset:
      transforms:
        - class_path: otx.data.transform_libs.torchvision.EfficientNetRandomCrop
          init_args:
            scale: $(input_size)
            crop_ratio_range:
              - 0.08
              - 1.0
            aspect_ratio_range:
              - 0.75
              - 1.34
        - class_path: torchvision.transforms.v2.RandomPhotometricDistort
          enable: false
          init_args:
            brightness:
              - 0.875
              - 1.125
            contrast:
              - 0.5
              - 1.5
            saturation:
              - 0.5
              - 1.5
            hue:
              - -0.05
              - 0.05
            p: 0.5
        - class_path: otx.data.transform_libs.torchvision.RandomAffine
          enable: false
          init_args:
            max_rotate_degree: 10.0
            max_translate_ratio: 0.1
            scaling_ratio_range:
              - 0.5
              - 1.5
            max_shear_degree: 2.0
        - class_path: otx.data.transform_libs.torchvision.RandomFlip
          enable: true
          init_args:
            probability: 0.5
        - class_path: torchvision.transforms.v2.RandomVerticalFlip
          enable: false
          init_args:
            p: 0.5
        - class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur
          enable: false
          init_args:
            kernel_size: 5
            sigma:
              - 0.1
              - 2.0
            probability: 0.5
        - class_path: torchvision.transforms.v2.ToDtype
          init_args:
            dtype: ${as_torch_dtype:torch.float32}
            scale: false
        - class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise
          enable: false
          init_args:
            mean: 0.0
            sigma: 0.1
            probability: 0.5
        - class_path: torchvision.transforms.v2.Normalize
          init_args:
            mean: [123.675, 116.28, 103.53]
            std: [58.395, 57.12, 57.375]