11import torch
2- import torch .nn as nn
2+ import torch .nn . functional as F
33
4+ from ...utils import tensor_one_hot
45from ..weighted_base_loss import WeightedBaseLoss
56
7+ __all__ = ["CELoss" ]
8+
69
710class CELoss (WeightedBaseLoss ):
811 def __init__ (
9- self , edge_weight : float = None , class_weights : torch .Tensor = None , ** kwargs
12+ self ,
13+ apply_sd : bool = False ,
14+ apply_ls : bool = False ,
15+ apply_svls : bool = False ,
16+ edge_weight : float = None ,
17+ class_weights : torch .Tensor = None ,
18+ ** kwargs ,
1019 ) -> None :
1120 """Cross-Entropy loss with weighting.
1221
1322 Parameters
1423 ----------
15- edge_weight : float, default=none
24+ apply_sd : bool, default=False
25+ If True, Spectral decoupling regularization will be applied to the
26+ loss matrix.
27+ apply_ls : bool, default=False
28+ If True, Label smoothing will be applied to the target.
29+ apply_svls : bool, default=False
30+ If True, spatially varying label smoothing will be applied to the target
31+ edge_weight : float, default=None
1632 Weight that is added to object borders.
1733 class_weights : torch.Tensor, default=None
1834 Class weights. A tensor of shape (n_classes,).
1935 """
20- super ().__init__ (class_weights , edge_weight )
21- self .loss = nn . CrossEntropyLoss ( reduction = "none" , weight = class_weights )
36+ super ().__init__ (apply_sd , apply_ls , apply_svls , class_weights , edge_weight )
37+ self .eps = 1e-8
2238
2339 def forward (
2440 self ,
2541 yhat : torch .Tensor ,
2642 target : torch .Tensor ,
2743 target_weight : torch .Tensor = None ,
28- ** kwargs
44+ ** kwargs ,
2945 ) -> torch .Tensor :
3046 """Compute the cross entropy loss.
3147
@@ -43,7 +59,28 @@ def forward(
4359 torch.Tensor:
4460 Computed CE loss (scalar).
4561 """
46- loss = self .loss (yhat , target ) # (B, H, W)
62+ input_soft = F .softmax (yhat , dim = 1 ) + self .eps # (B, C, H, W)
63+ num_classes = yhat .shape [1 ]
64+ target_one_hot = tensor_one_hot (target , num_classes ) # (B, C, H, W)
65+ assert target_one_hot .shape == yhat .shape
66+
67+ if self .apply_svls :
68+ target_one_hot = self .apply_svls_to_target (
69+ target_one_hot , num_classes , ** kwargs
70+ )
71+
72+ if self .apply_ls :
73+ target_one_hot = self .apply_ls_to_target (
74+ target_one_hot , num_classes , ** kwargs
75+ )
76+
77+ loss = - torch .sum (target_one_hot * torch .log (input_soft ), dim = 1 ) # (B, H, W)
78+
79+ if self .apply_sd :
80+ loss = self .apply_spectral_decouple (loss , yhat )
81+
82+ if self .class_weights is not None :
83+ loss = self .apply_class_weights (loss , target )
4784
4885 if self .edge_weight is not None :
4986 loss = self .apply_edge_weights (loss , target_weight )
0 commit comments