diff --git a/library/src/otx/backend/native/models/classification/classifier/__init__.py b/library/src/otx/backend/native/models/classification/classifier/__init__.py
index 40b9ad9ae7..d38f562265 100644
--- a/library/src/otx/backend/native/models/classification/classifier/__init__.py
+++ b/library/src/otx/backend/native/models/classification/classifier/__init__.py
@@ -4,6 +4,6 @@
 """Head modules for OTX custom model."""
 
 from .base_classifier import ImageClassifier
-from .h_label_classifier import HLabelClassifier
+from .h_label_classifier import HLabelClassifier, KLHLabelClassifier
 
-__all__ = ["ImageClassifier", "HLabelClassifier"]
+__all__ = ["ImageClassifier", "HLabelClassifier", "KLHLabelClassifier"]
diff --git a/library/src/otx/backend/native/models/classification/classifier/h_label_classifier.py b/library/src/otx/backend/native/models/classification/classifier/h_label_classifier.py
index 1649785815..1a6cf0ee81 100644
--- a/library/src/otx/backend/native/models/classification/classifier/h_label_classifier.py
+++ b/library/src/otx/backend/native/models/classification/classifier/h_label_classifier.py
@@ -11,6 +11,7 @@
 import torch
 
 from otx.backend.native.models.classification.heads.hlabel_cls_head import HierarchicalClsHead
+from otx.backend.native.models.classification.losses.tree_path_kl_divergence_loss import TreePathKLDivergenceLoss
 from otx.backend.native.models.classification.utils.ignored_labels import get_valid_label_mask
 
 from .base_classifier import ImageClassifier
@@ -143,3 +144,87 @@ def _forward_explain(self, images: torch.Tensor) -> dict[str, torch.Tensor | lis
         outputs["preds"] = preds
 
         return outputs
+
+
+class KLHLabelClassifier(HLabelClassifier):
+    """Hierarchical label classifier with tree path KL divergence loss.
+
+    Args:
+        backbone (nn.Module): Backbone network.
+        neck (nn.Module | None): Neck network.
+        head (nn.Module): Head network.
+        multiclass_loss (nn.Module): Multiclass loss function.
+        multilabel_loss (nn.Module | None, optional): Multilabel loss function.
+        init_cfg (dict | list[dict] | None, optional): Initialization configuration.
+        kl_weight (float): Loss weight for the tree path KL divergence loss.
+
+    Attributes:
+        multiclass_loss (nn.Module): Multiclass loss function.
+        multilabel_loss (nn.Module | None): Multilabel loss function.
+        is_ignored_label_loss (bool): Flag indicating if ignored label loss is used.
+
+    Methods:
+        loss(inputs, labels, **kwargs): Calculate losses from a batch of inputs and data samples.
+    """
+
+    def __init__(self, *args, kl_weight: float = 1.0, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.kl_weight = kl_weight
+        self.kl_loss = TreePathKLDivergenceLoss(reduction="batchmean", loss_weight=1.0)
+
+    def loss(self, inputs: torch.Tensor, labels: torch.Tensor, **kwargs) -> torch.Tensor:
+        """Calculate losses from a batch of inputs and data samples.
+
+        Args:
+            inputs (torch.Tensor): The input tensor with shape
+                (N, C, ...) in general.
+            labels (torch.Tensor): The annotation data of
+                every sample.
+
+        Returns:
+            torch.Tensor: The aggregated loss.
+        """
+        cls_scores = self.extract_feat(inputs, stage="head")
+        loss_score = torch.tensor(0.0, device=cls_scores.device)
+        logits_list = []
+        target_list = []
+        num_effective_heads_in_batch = 0
+        for i in range(self.head.num_multiclass_heads):
+            if i not in self.head.empty_multiclass_head_indices:
+                head_gt = labels[:, i]
+                logit_range = self.head._get_head_idx_to_logits_range(i)  # noqa: SLF001
+                head_logits = cls_scores[:, logit_range[0] : logit_range[1]]
+                valid_mask = head_gt >= 0
+                head_gt = head_gt[valid_mask]
+                if len(head_gt) > 0:
+                    head_logits = head_logits[valid_mask]
+                    logits_list.append(head_logits)
+                    target_list.append(head_gt)
+                    ce = self.multiclass_loss(head_logits, head_gt)
+                    loss_score += ce
+                    num_effective_heads_in_batch += 1
+
+        if num_effective_heads_in_batch > 0:
+            loss_score /= num_effective_heads_in_batch
+
+        if len(logits_list) > 1:
+            kl_loss = self.kl_loss(logits_list, torch.stack(target_list, dim=1))
+            loss_score += self.kl_weight * kl_loss
+
+        # Multilabel logic (preserved as-is)
+        if self.head.num_multilabel_classes > 0:
+            head_gt = labels[:, self.head.num_multiclass_heads :]
+            head_logits = cls_scores[:, self.head.num_single_label_classes :]
+            valid_mask = head_gt > 0
+            head_gt = head_gt[valid_mask]
+            if len(head_gt) > 0 and self.multilabel_loss is not None:
+                head_logits = head_logits[valid_mask]
+                imgs_info = kwargs.pop("imgs_info", None)
+                if imgs_info is not None and self.is_ignored_label_loss:
+                    valid_label_mask = get_valid_label_mask(imgs_info, self.head.num_classes).to(head_logits.device)
+                    valid_label_mask = valid_label_mask[:, self.head.num_single_label_classes :]
+                    valid_label_mask = valid_label_mask[valid_mask]
+                    kwargs["valid_label_mask"] = valid_label_mask
+                loss_score += self.multilabel_loss(head_logits, head_gt, **kwargs)
+
+        return loss_score
diff --git a/library/src/otx/backend/native/models/classification/hlabel_models/base.py b/library/src/otx/backend/native/models/classification/hlabel_models/base.py
index 4697ff5fa8..8c13d80a21 100644
--- a/library/src/otx/backend/native/models/classification/hlabel_models/base.py
+++ b/library/src/otx/backend/native/models/classification/hlabel_models/base.py
@@ -7,6 +7,7 @@
 
 from abc import abstractmethod
 from copy import deepcopy
+from functools import wraps
 from typing import TYPE_CHECKING, Any
 
 import torch
@@ -15,6 +16,7 @@
 from otx.backend.native.exporter.base import OTXModelExporter
 from otx.backend.native.exporter.native import OTXNativeModelExporter
 from otx.backend.native.models.base import DataInputParams, DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel
+from otx.backend.native.models.classification.classifier import KLHLabelClassifier
 from otx.backend.native.schedulers import LRSchedulerListCallable
 from otx.data.entity.base import OTXBatchLossEntity
 from otx.data.entity.torch import OTXDataBatch, OTXPredBatch
@@ -45,6 +47,7 @@ class OTXHlabelClsModel(OTXModel):
             Defaults to DefaultSchedulerCallable.
         metric (MetricCallable, optional): Callable for the metric. Defaults to HLabelClsMetricCallable.
        torch_compile (bool, optional): Flag to indicate whether to use torch.compile. Defaults to False.
+        kl_weight (float, optional): Weight of the tree-path KL divergence loss. Defaults to 0.0 (cross-entropy only).
""" label_info: HLabelInfo @@ -59,7 +62,9 @@ def __init__( scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = HLabelClsMetricCallable, torch_compile: bool = False, + kl_weight: float = 0.0, ) -> None: + self.kl_weight = kl_weight super().__init__( label_info=label_info, data_input_params=data_input_params, @@ -70,16 +75,46 @@ def __init__( metric=metric, torch_compile=torch_compile, ) - if freeze_backbone: classification_layers = self._identify_classification_layers() for name, param in self.named_parameters(): param.requires_grad = name in classification_layers + def __getattribute__(self, name: str): + attr = super().__getattribute__(name) + if name == "_create_model" and callable(attr): + cache_name = "__cm_cached__" + cache = super().__getattribute__("__dict__").get(cache_name) + if cache: + return cache + + @wraps(attr) + def wrapped(*a, **kw) -> nn.Module: + model = attr(*a, **kw) + return self._finalize_model(model) + + self.__dict__[cache_name] = wrapped + return wrapped + return attr + @abstractmethod def _create_model(self, head_config: dict | None = None) -> nn.Module: # type: ignore[override] """Create a PyTorch model for this class.""" + def _finalize_model(self, model: nn.Module) -> nn.Module: + """Run after child _create_model(); upgrade to KL if enabled.""" + if self.kl_weight > 0: + return KLHLabelClassifier( + backbone=model.backbone, + neck=model.neck, + head=model.head, + multiclass_loss=model.multiclass_loss, + multilabel_loss=model.multilabel_loss, + init_cfg=getattr(model, "init_cfg", None), + kl_weight=self.kl_weight, + ) + return model + def _identify_classification_layers(self, prefix: str = "model.") -> list[str]: """Simple identification of the classification layers. Used for incremental learning.""" # identify classification layers diff --git a/library/src/otx/backend/native/models/classification/hlabel_models/timm_model.py b/library/src/otx/backend/native/models/classification/hlabel_models/timm_model.py index ba6f2c67fe..9d07611be6 100644 --- a/library/src/otx/backend/native/models/classification/hlabel_models/timm_model.py +++ b/library/src/otx/backend/native/models/classification/hlabel_models/timm_model.py @@ -46,6 +46,7 @@ class TimmModelHLabelCls(OTXHlabelClsModel): metric (MetricCallable, optional): The metric callable for evaluating the model. Defaults to HLabelClsMetricCallable. torch_compile (bool, optional): Whether to compile the model using TorchScript. Defaults to False. + kl_weight: The weight of tree-path KL divergence loss. Defaults to zero, use CrossEntropy only. 
""" def __init__( @@ -58,6 +59,7 @@ def __init__( scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, metric: MetricCallable = HLabelClsMetricCallable, torch_compile: bool = False, + kl_weight: float = 0.0, ) -> None: super().__init__( label_info=label_info, @@ -68,6 +70,7 @@ def __init__( scheduler=scheduler, metric=metric, torch_compile=torch_compile, + kl_weight=kl_weight, ) def _create_model(self, head_config: dict | None = None) -> nn.Module: # type: ignore[override] diff --git a/library/src/otx/backend/native/models/classification/losses/tree_path_kl_divergence_loss.py b/library/src/otx/backend/native/models/classification/losses/tree_path_kl_divergence_loss.py new file mode 100644 index 0000000000..d0ea1442bf --- /dev/null +++ b/library/src/otx/backend/native/models/classification/losses/tree_path_kl_divergence_loss.py @@ -0,0 +1,52 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Module for defining TreePathKLDivergenceLoss.""" + +from __future__ import annotations + +import torch +from torch import nn +from torch.nn import functional + + +class TreePathKLDivergenceLoss(nn.Module): + """KL divergence between model distribution over concatenated heads and a target distribution. + + Inputs: + logits_list: list of tensors [B, C_l], ordered from root -> leaf + targets: LongTensor [B, L] with per-level GT indices (L == len(logits_list)) + + The target distribution places 1/L probability on the GT index for each level, + and 0 elsewhere, then uses KLDivLoss(log_softmax(logits), target_probs). + """ + + def __init__(self, reduction: str | None = "batchmean", loss_weight: float = 1.0): + super().__init__() + self.reduction = reduction + self.loss_weight = loss_weight + self.kl_div = nn.KLDivLoss(reduction=self.reduction) + + def forward(self, logits_list: list[torch.Tensor], targets: torch.Tensor) -> torch.Tensor: + """Calculate tree_path KL Divergence loss.""" + if not (isinstance(logits_list, (list, tuple)) and len(logits_list) > 0): + msg = "logits_list must be non-empty" + raise ValueError(msg) + num_levels = len(logits_list) + + # concat logits across all levels + dims = [t.size(1) for t in logits_list] + logits_concat = torch.cat(logits_list, dim=1) # [B, sum(C_l)] + log_probs = functional.log_softmax(logits_concat, dim=1) # [B, sum(C_l)] + + # build sparse target distribution with 1/L at each GT index + batch = log_probs.size(0) + tgt = torch.zeros_like(log_probs) # [B, sum(C_l)] + offset = 0 + for num_c, tgt_l in zip(dims, targets.T): # level-by-level + idx_rows = torch.arange(batch, device=log_probs.device) + tgt[idx_rows, offset + tgt_l] = 1.0 / num_levels + offset += num_c + + kl = self.kl_div(log_probs, tgt) + return self.loss_weight * kl diff --git a/library/src/otx/backend/native/utils/utils.py b/library/src/otx/backend/native/utils/utils.py index 593d1f261f..9468723d2e 100644 --- a/library/src/otx/backend/native/utils/utils.py +++ b/library/src/otx/backend/native/utils/utils.py @@ -86,7 +86,7 @@ def mock_modules_for_chkpt() -> Iterator[None]: sys.modules["otx.core.types.task"] = otx.types.task sys.modules["otx.core.types.label"] = otx.types.label sys.modules["otx.core.model"] = otx.backend.native.models # type: ignore[attr-defined] - sys.modules["otx.core.metrics"] = otx.metrics + # sys.modules["otx.core.metrics"] = otx.metrics yield finally: diff --git a/library/src/otx/recipe/classification/h_label_cls/efficientnet_v2_kl.yaml b/library/src/otx/recipe/classification/h_label_cls/efficientnet_v2_kl.yaml 
new file mode 100644
index 0000000000..a460fac282
--- /dev/null
+++ b/library/src/otx/recipe/classification/h_label_cls/efficientnet_v2_kl.yaml
@@ -0,0 +1,125 @@
+task: H_LABEL_CLS
+model:
+  class_path: otx.backend.native.models.classification.hlabel_models.timm_model.TimmModelHLabelCls
+  init_args:
+    kl_weight: 2.0
+    model_name: tf_efficientnetv2_s.in21k
+
+    optimizer:
+      class_path: torch.optim.SGD
+      init_args:
+        lr: 0.0071
+        momentum: 0.9
+        weight_decay: 0.0001
+
+    scheduler:
+      class_path: otx.backend.native.schedulers.LinearWarmupSchedulerCallable
+      init_args:
+        num_warmup_steps: 0
+        main_scheduler_callable:
+          class_path: lightning.pytorch.cli.ReduceLROnPlateau
+          init_args:
+            mode: max
+            factor: 0.5
+            patience: 3
+            monitor: val/accuracy
+
+engine:
+  device: auto
+
+callback_monitor: val/accuracy
+
+data: ../../_base_/data/classification.yaml
+
+callbacks:
+  - class_path: otx.backend.native.callbacks.adaptive_early_stopping.EarlyStoppingWithWarmup
+    init_args:
+      warmup_iters: 750
+      patience: 5
+      mode: max
+      monitor: val/accuracy
+  - class_path: lightning.pytorch.callbacks.ModelCheckpoint
+    init_args:
+      dirpath: "" # use engine.work_dir
+      monitor: val/accuracy
+      mode: max
+      save_top_k: 1
+      save_last: true
+      auto_insert_metric_name: false
+      filename: "checkpoints/epoch_{epoch:03d}"
+
+overrides:
+  reset:
+    - data.train_subset.transforms
+
+  max_epochs: 90
+
+  data:
+    task: H_LABEL_CLS
+    data_format: datumaro
+    train_subset:
+      transforms:
+        - class_path: otx.data.transform_libs.torchvision.EfficientNetRandomCrop
+          init_args:
+            scale: $(input_size)
+            crop_ratio_range:
+              - 0.08
+              - 1.0
+            aspect_ratio_range:
+              - 0.75
+              - 1.34
+        - class_path: torchvision.transforms.v2.RandomPhotometricDistort
+          enable: false
+          init_args:
+            brightness:
+              - 0.875
+              - 1.125
+            contrast:
+              - 0.5
+              - 1.5
+            saturation:
+              - 0.5
+              - 1.5
+            hue:
+              - -0.05
+              - 0.05
+            p: 0.5
+        - class_path: otx.data.transform_libs.torchvision.RandomAffine
+          enable: false
+          init_args:
+            max_rotate_degree: 10.0
+            max_translate_ratio: 0.1
+            scaling_ratio_range:
+              - 0.5
+              - 1.5
+            max_shear_degree: 2.0
+        - class_path: otx.data.transform_libs.torchvision.RandomFlip
+          enable: true
+          init_args:
+            probability: 0.5
+        - class_path: torchvision.transforms.v2.RandomVerticalFlip
+          enable: false
+          init_args:
+            p: 0.5
+        - class_path: otx.data.transform_libs.torchvision.RandomGaussianBlur
+          enable: false
+          init_args:
+            kernel_size: 5
+            sigma:
+              - 0.1
+              - 2.0
+            probability: 0.5
+        - class_path: torchvision.transforms.v2.ToDtype
+          init_args:
+            dtype: ${as_torch_dtype:torch.float32}
+            scale: false
+        - class_path: otx.data.transform_libs.torchvision.RandomGaussianNoise
+          enable: false
+          init_args:
+            mean: 0.0
+            sigma: 0.1
+            probability: 0.5
+        - class_path: torchvision.transforms.v2.Normalize
+          init_args:
+            mean: [123.675, 116.28, 103.53]
+            std: [58.395, 57.12, 57.375]
diff --git a/library/tests/unit/backend/native/models/classification/classifier/test_kl_hlabel_classifier.py b/library/tests/unit/backend/native/models/classification/classifier/test_kl_hlabel_classifier.py
new file mode 100644
index 0000000000..31801bc513
--- /dev/null
+++ b/library/tests/unit/backend/native/models/classification/classifier/test_kl_hlabel_classifier.py
@@ -0,0 +1,136 @@
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import torch
+from torch import nn
+
+from otx.backend.native.models.classification.backbones import EfficientNetBackbone
+from otx.backend.native.models.classification.classifier import HLabelClassifier, KLHLabelClassifier
+from otx.backend.native.models.classification.heads import LinearClsHead, MultiLabelLinearClsHead
+from otx.backend.native.models.classification.heads.hlabel_cls_head import HierarchicalClsHead
+from otx.backend.native.models.classification.losses import AsymmetricAngularLossWithIgnore
+from otx.backend.native.models.classification.necks.gap import GlobalAveragePooling
+
+
+class HierHeadForTest(HierarchicalClsHead):
+    """Lightweight hierarchical head for tests, compatible with H/KLH classifiers."""
+
+    def __init__(self, in_channels: int, head_class_sizes=(3, 3)):
+        # e.g., two heads with 3 classes each -> total classes = 6
+        self.head_class_sizes = list(head_class_sizes)
+        num_multiclass_heads = len(self.head_class_sizes)
+        num_multilabel_classes = 0
+        num_single_label_classes = sum(self.head_class_sizes)
+        num_classes = num_single_label_classes
+
+        # Build per-head logit ranges, e.g. [(0,3), (3,6)]
+        start = 0
+        ranges = {}
+        for idx, k in enumerate(self.head_class_sizes):
+            ranges[str(idx)] = (start, start + k)
+            start += k
+
+        empty_multiclass_head_indices = []
+
+        # Call the real base class with all required args
+        super().__init__(
+            num_classes=num_classes,
+            in_channels=in_channels,
+            num_multiclass_heads=num_multiclass_heads,
+            num_multilabel_classes=num_multilabel_classes,
+            head_idx_to_logits_range=ranges,
+            num_single_label_classes=num_single_label_classes,
+            empty_multiclass_head_indices=empty_multiclass_head_indices,
+        )
+
+        # Simple linear head over pooled features -> logits
+        self.classifier = nn.Linear(in_channels, num_classes)
+        self._head_ranges = ranges
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if isinstance(x, (tuple, list)):
+            x = x[0]
+        return self.classifier(x)
+
+
+class TestKLHLabelClassifier:
+    @pytest.fixture(
+        params=[
+            (LinearClsHead, nn.CrossEntropyLoss, "fxt_multiclass_cls_batch_data_entity"),
+            (MultiLabelLinearClsHead, AsymmetricAngularLossWithIgnore, "fxt_multilabel_cls_batch_data_entity"),
+        ],
+        ids=["multiclass", "multilabel"],
+    )
+    def fxt_model_and_inputs(self, request):
+        head_class_sizes = (3, 3)
+        input_fxt_name = "fxt_multiclass_cls_batch_data_entity"
+        backbone = EfficientNetBackbone(model_name="efficientnet_b0")
+        neck = GlobalAveragePooling(dim=2)
+        head = HierHeadForTest(in_channels=backbone.num_features, head_class_sizes=head_class_sizes)
+        loss = nn.CrossEntropyLoss()
+        fxt_input = request.getfixturevalue(input_fxt_name)
+        level = len(head_class_sizes)
+        fxt_labels = torch.stack(fxt_input.labels)
+        fxt_labels = fxt_labels.repeat(1, level)
+        return (backbone, neck, head, loss, fxt_input.images, fxt_labels)
+
+    def test_forward(self, fxt_model_and_inputs):
+        backbone, neck, head, loss, images, labels = fxt_model_and_inputs
+
+        model = KLHLabelClassifier(
+            backbone=backbone,
+            neck=neck,
+            head=head,
+            multiclass_loss=loss,
+            kl_weight=1,
+        )
+
+        output = model(images, labels, mode="loss")
+        assert isinstance(output, torch.Tensor)
+
+    def test_klh_loss_greater_than_hlabel(self, fxt_model_and_inputs):
+        """KLHLabelClassifier should have strictly larger loss than HLabelClassifier
+        when kl_weight > 0 and there are >= 2 multiclass heads."""
+        backbone, neck, head, loss, images, labels = fxt_model_and_inputs
+        h_model = HLabelClassifier(
+            backbone=backbone,
+            neck=neck,
+            head=head,
+            multiclass_loss=loss,
+        )
+        kl_h_model = KLHLabelClassifier(
+            backbone=backbone,
+            neck=neck,
+            head=head,
+            multiclass_loss=loss,
+            kl_weight=2.0,
+        )
+
+        h_loss = h_model.loss(images, labels)
+        klh_loss = kl_h_model.loss(images, labels)
+
+        print(f"HLabel loss: {h_loss.item():.6f} | KLH loss: {klh_loss.item():.6f}")
+        assert klh_loss > h_loss, "Expected KLH loss to be greater due to added KL term"
+
+    def test_klh_weight_zero_match_hlabel(self, fxt_model_and_inputs):
+        """With kl_weight == 0, KLH loss should match the HLabel loss (within tolerance)."""
+        backbone, neck, head, loss, images, labels = fxt_model_and_inputs
+        h_model = HLabelClassifier(
+            backbone=backbone,
+            neck=neck,
+            head=head,
+            multiclass_loss=loss,
+        )
+        kl_h_model = KLHLabelClassifier(
+            backbone=backbone,
+            neck=neck,
+            head=head,
+            multiclass_loss=loss,
+            kl_weight=0,
+        )
+        h_loss = h_model.loss(images, labels)
+        klh_loss = kl_h_model.loss(images, labels)
+
+        print(f"[kl=0] HLabel loss: {h_loss.item():.6f} | KLH loss: {klh_loss.item():.6f}")
+        assert torch.allclose(klh_loss, h_loss, atol=1e-6), "With kl_weight=0, losses should match"
diff --git a/library/tests/unit/backend/native/models/classification/losses/test_tree_path_kl_divergence.py b/library/tests/unit/backend/native/models/classification/losses/test_tree_path_kl_divergence.py
new file mode 100644
index 0000000000..673a22fb7c
--- /dev/null
+++ b/library/tests/unit/backend/native/models/classification/losses/test_tree_path_kl_divergence.py
@@ -0,0 +1,135 @@
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import torch
+from torch.nn import functional
+
+from otx.backend.native.models.classification.losses.tree_path_kl_divergence_loss import TreePathKLDivergenceLoss
+
+
+@pytest.mark.parametrize(
+    ("levels", "classes_per_level"),
+    [
+        (2, [3, 5]),
+        (3, [2, 3, 4]),
+    ],
+)
+def test_forward_scalar_and_finite(levels, classes_per_level):
+    torch.manual_seed(0)
+    batch = 4
+    logits_list = [torch.randn(batch, c) for c in classes_per_level]
+    targets = torch.stack([torch.randint(0, c, (batch,)) for c in classes_per_level], dim=1)
+
+    loss_fn = TreePathKLDivergenceLoss(reduction="batchmean")
+    loss = loss_fn(logits_list, targets)
+    assert loss.ndim == 0
+    assert torch.isfinite(loss)
+    assert loss.item() >= -1e-7
+
+
+def test_backward_produces_grads():
+    batch = 3
+    channel = [4, 6]
+    logits_list = [torch.randn(batch, c, requires_grad=True) for c in channel]
+    targets = torch.stack([torch.randint(0, c, (batch,)) for c in channel], dim=1)
+
+    loss = TreePathKLDivergenceLoss()(logits_list, targets)
+    loss.backward()
+    for logit in logits_list:
+        assert logit.grad is not None
+        assert torch.isfinite(logit.grad).all()
+
+
+def test_alignment_vs_misalignment_loss():
+    batch = 2
+    channel0, channel1 = 3, 4
+    targets = torch.tensor([[0, 1], [2, 3]])
+
+    # Aligned: boost GT logits
+    aligned0 = torch.zeros(batch, channel0)
+    aligned1 = torch.zeros(batch, channel1)
+    aligned0[torch.arange(batch), targets[:, 0]] = 5.0
+    aligned1[torch.arange(batch), targets[:, 1]] = 5.0
+
+    # Misaligned: boost wrong logits
+    mis0 = torch.zeros(batch, channel0)
+    mis1 = torch.zeros(batch, channel1)
+    mis0[torch.arange(batch), (targets[:, 0] + 1) % channel0] = 5.0
+    mis1[torch.arange(batch), (targets[:, 1] + 1) % channel1] = 5.0
+
+    loss_fn = TreePathKLDivergenceLoss()
+    loss_aligned = loss_fn([aligned0, aligned1], targets)
+    loss_misaligned = loss_fn([mis0, mis1], targets)
+    assert loss_aligned < loss_misaligned
+
+
+def test_single_level_exact_value():
+    """
+    With a single level, KL reduces to CE between predicted softmax and one-hot target.
+    We check the exact value against functional.cross_entropy.
+ """ + + logits = torch.tensor([[2.0, 0.0, -1.0], [0.5, 1.0, -0.5]]) + targets = torch.tensor([[0], [2]]) # shape [B,1] + + # TreePathKLP + loss_fn = TreePathKLDivergenceLoss(reduction="batchmean") + kl_loss = loss_fn([logits], targets) + + # CrossEntropy with one-hot is same as NLLLoss(log_softmax) + ce_loss = functional.cross_entropy(logits, targets.view(-1), reduction="mean") + + assert torch.allclose(kl_loss, ce_loss, atol=1e-6) + + +def test_multi_level_exact_value_batchmean(): + """ + Exact numerical check for L=2 levels with 'batchmean' reduction. + + Loss per sample (PyTorch KLDivLoss): + KL(p || q) = sum_j p_j * (log(p_j) - log(q_j)) + where input to KLDivLoss is log(q_j) (our model log_probs), + and the target is p_j (our constructed target distribution). + With reduction='batchmean', PyTorch divides the total sum by batch size. + """ + + # Use double for better numerical agreement + batch = 2 + l0, l1 = 2, 3 + logits0 = torch.tensor([[2.0, -1.0], [0.0, 1.0]], dtype=torch.float64) # [B, l0] + logits1 = torch.tensor([[0.5, 0.0, -0.5], [-1.0, 2.0, 0.5]], dtype=torch.float64) # [B, l1] + + # Ground-truth indices per level + # sample 0: level0->0, level1->1 + # sample 1: level0->1, level1->2 + targets = torch.tensor([[0, 1], [1, 2]], dtype=torch.long) # [B, 2] + level = 2 # number of levels + + # Model log probs over concatenated heads + concat = torch.cat([logits0, logits1], dim=1) # [B, l0+l1] + log_q = functional.log_softmax(concat, dim=1) # log(q_j) + + # Build target distribution p: 1/level at each GT index, 0 elsewhere + p = torch.zeros_like(log_q, dtype=torch.float64) + offset = 0 + for num_c, tgt_l in zip([l0, l1], targets.T): + rows = torch.arange(batch) + p[rows, offset + tgt_l] = 1.0 / level + offset += num_c + + # Manual KL with 'batchmean' reduction: + # sum_i sum_j p_ij * (log p_ij - log q_ij) / batch + # (avoid log(0) by masking since p is sparse) + mask = p > 0 + log_p = torch.zeros_like(p) + log_p[mask] = torch.log(p[mask]) + manual_kl = (p * (log_p - log_q)).sum() / batch + + # Loss under test (must match manual) + loss_fn = TreePathKLDivergenceLoss(reduction="batchmean") + test_kl = loss_fn([logits0.float(), logits1.float()], targets) + + assert torch.allclose( + test_kl.double(), manual_kl, atol=1e-8 + ), f"manual={manual_kl.item():.12f} vs loss={test_kl.item():.12f}" diff --git a/library/tests/unit/backend/native/utils/test_api.py b/library/tests/unit/backend/native/utils/test_api.py index 2779a8b43a..c185710e6f 100644 --- a/library/tests/unit/backend/native/utils/test_api.py +++ b/library/tests/unit/backend/native/utils/test_api.py @@ -30,6 +30,7 @@ def test_list_models_pattern() -> None: target = [ "efficientnet_b0", "efficientnet_v2", + "efficientnet_v2_kl", "maskrcnn_efficientnetb2b", "maskrcnn_efficientnetb2b_tile", "tv_efficientnet_b3",