@@ -0,0 +1,53 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Module for defining TreePathKLDivergenceLoss."""

from __future__ import annotations
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F


class TreePathKLDivergenceLoss(nn.Module):
"""
    KL divergence between the model's distribution over the concatenated heads
    and a target distribution that allocates equal mass to the ground-truth
    class at each hierarchy level.

Inputs:
logits_list: list of tensors [B, C_l], ordered from root -> leaf
targets: LongTensor [B, L] with per-level GT indices (L == len(logits_list))

The target distribution places 1/L probability on the GT index for each level,
and 0 elsewhere, then uses KLDivLoss(log_softmax(logits), target_probs).
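
    Illustrative example (arbitrary shapes and values): with two levels of
    sizes [2, 3] and a sample whose GT path is (1, 0), the target row over the
    five concatenated classes is [0.0, 0.5, 0.5, 0.0, 0.0], i.e. 1/L = 0.5 at
    each GT index.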
"""

def __init__(self, reduction: str = "batchmean", loss_weight: float = 1.0):
super().__init__()
self.reduction = reduction
self.loss_weight = loss_weight
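        # nn.KLDivLoss expects log-probabilities as input and probabilities as
        # target (log_target defaults to False), which forward() provides below.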
self.kl_div = nn.KLDivLoss(reduction=self.reduction)

def forward(self, logits_list: List[torch.Tensor], targets: torch.Tensor) -> torch.Tensor:
assert isinstance(logits_list, (list, tuple)) and len(logits_list) > 0, "logits_list must be non-empty"
num_levels = len(logits_list)

# concat logits across all levels
dims = [t.size(1) for t in logits_list]
logits_concat = torch.cat(logits_list, dim=1) # [B, sum(C_l)]
log_probs = F.log_softmax(logits_concat, dim=1) # [B, sum(C_l)]

# build sparse target distribution with 1/L at each GT index
B = log_probs.size(0)
tgt = torch.zeros_like(log_probs) # [B, sum(C_l)]
offset = 0
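        # targets.T yields one [B] vector of GT indices per level; each index is
        # local to its level's head, hence the running offset into the concat dim.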
for num_c, tgt_l in zip(dims, targets.T): # level-by-level
idx_rows = torch.arange(B, device=log_probs.device)
tgt[idx_rows, offset + tgt_l] = 1.0 / num_levels
offset += num_c

kl = self.kl_div(log_probs, tgt)
return self.loss_weight * kl
@@ -0,0 +1,134 @@
import torch
import pytest
import torch.nn.functional as F

from otx.backend.native.models.classification.losses.tree_path_KL_divergence_loss import TreePathKLDivergenceLoss


@pytest.mark.parametrize("levels,classes_per_level", [
(2, [3, 5]),
(3, [2, 3, 4]),
])
def test_forward_scalar_and_finite(levels, classes_per_level):
torch.manual_seed(0)
B = 4
logits_list = [torch.randn(B, c) for c in classes_per_level]
targets = torch.stack([torch.randint(0, c, (B,)) for c in classes_per_level], dim=1)

loss_fn = TreePathKLDivergenceLoss(reduction="batchmean")
loss = loss_fn(logits_list, targets)
assert loss.ndim == 0
assert torch.isfinite(loss)
assert loss.item() >= -1e-7


def test_backward_produces_grads():
B = 3
C = [4, 6]
logits_list = [torch.randn(B, c, requires_grad=True) for c in C]
targets = torch.stack([torch.randint(0, c, (B,)) for c in C], dim=1)

loss = TreePathKLDivergenceLoss()(logits_list, targets)
loss.backward()
for logit in logits_list:
assert logit.grad is not None
assert torch.isfinite(logit.grad).all()


def test_alignment_vs_misalignment_loss():
B = 2
C0, C1 = 3, 4
targets = torch.tensor([[0, 1], [2, 3]])

# Aligned: boost GT logits
aligned0 = torch.zeros(B, C0)
aligned1 = torch.zeros(B, C1)
aligned0[torch.arange(B), targets[:, 0]] = 5.0
aligned1[torch.arange(B), targets[:, 1]] = 5.0

# Misaligned: boost wrong logits
mis0 = torch.zeros(B, C0)
mis1 = torch.zeros(B, C1)
mis0[torch.arange(B), (targets[:, 0] + 1) % C0] = 5.0
mis1[torch.arange(B), (targets[:, 1] + 1) % C1] = 5.0

loss_fn = TreePathKLDivergenceLoss()
loss_aligned = loss_fn([aligned0, aligned1], targets)
loss_misaligned = loss_fn([mis0, mis1], targets)
assert loss_aligned < loss_misaligned


def test_single_level_exact_value():
"""
With a single level, KL reduces to CE between predicted softmax and one-hot target.
We check exact value against F.cross_entropy.
"""

B, C = 2, 3
logits = torch.tensor([[2.0, 0.0, -1.0],
[0.5, 1.0, -0.5]])
targets = torch.tensor([[0], [2]]) # shape [B,1]

# TreePathKL
loss_fn = TreePathKLDivergenceLoss(reduction="batchmean")
kl_loss = loss_fn([logits], targets)

# CrossEntropy with one-hot is same as NLLLoss(log_softmax)
ce_loss = F.cross_entropy(logits, targets.view(-1), reduction="mean")

    assert torch.allclose(kl_loss, ce_loss, atol=1e-6)


def test_multi_level_exact_value_batchmean():
"""
Exact numerical check for L=2 levels with 'batchmean' reduction.

Loss per sample (PyTorch KLDivLoss):
KL(p || q) = sum_j p_j * (log(p_j) - log(q_j))
where input to KLDivLoss is log(q_j) (our model log_probs),
and the target is p_j (our constructed target distribution).
With reduction='batchmean', PyTorch divides the total sum by batch size.
"""

# Use double for better numerical agreement
B = 2
l0, l1 = 2, 3
logits0 = torch.tensor([[2.0, -1.0],
[0.0, 1.0]], dtype=torch.float64) # [B, l0]
logits1 = torch.tensor([[ 0.5, 0.0, -0.5],
[-1.0, 2.0, 0.5]], dtype=torch.float64) # [B, l1]
logits_list = [logits0, logits1]

# Ground-truth indices per level
# sample 0: level0->0, level1->1
# sample 1: level0->1, level1->2
targets = torch.tensor([[0, 1],
[1, 2]], dtype=torch.long) # [B, 2]
L = 2 # number of levels

# Model log probs over concatenated heads
concat = torch.cat([logits0, logits1], dim=1) # [B, l0+l1]
log_q = F.log_softmax(concat, dim=1) # log(q_j)

# Build target distribution p: 1/L at each GT index, 0 elsewhere
p = torch.zeros_like(log_q, dtype=torch.float64)
offset = 0
for num_c, tgt_l in zip([l0, l1], targets.T):
rows = torch.arange(B)
p[rows, offset + tgt_l] = 1.0 / L
offset += num_c

# Manual KL with 'batchmean' reduction:
# sum_i sum_j p_ij * (log p_ij - log q_ij) / B
# (avoid log(0) by masking since p is sparse)
mask = p > 0
log_p = torch.zeros_like(p)
log_p[mask] = torch.log(p[mask])
manual_kl = (p * (log_p - log_q)).sum() / B

# Loss under test (must match manual)
loss_fn = TreePathKLDivergenceLoss(reduction="batchmean")
    test_kl = loss_fn([logits0, logits1], targets)

    assert torch.allclose(test_kl, manual_kl, atol=1e-8), (
f"manual={manual_kl.item():.12f} vs loss={test_kl.item():.12f}"
)