1+ from typing import Optional , Sequence
2+
13import torch
24from torch import Tensor
35from torch import nn
@@ -18,30 +20,40 @@ class FocalLoss(nn.Module):
1820 """
1921
2022 def __init__ (self ,
21- alpha : Tensor = None ,
23+ alpha : Optional [ Tensor ] = None ,
2224 gamma : float = 0. ,
2325 reduction : str = 'mean' ,
2426 ignore_index : int = - 100 ):
2527 """Constructor.
28+
2629 Args:
27- alpha (Tensor): Weights for each class.
28- gamma (float): A constant, as described in the paper.
30+ alpha (Tensor, optional): Weights for each class. Defaults to None.
31+ gamma (float, optional): A constant, as described in the paper.
32+ Defaults to 0.
2933 reduction (str, optional): 'mean', 'sum' or 'none'.
3034 Defaults to 'mean'.
3135 ignore_index (int, optional): class label to ignore.
36+ Defaults to -100.
3237 """
38+ if reduction not in ('mean' , 'sum' , 'none' ):
39+ raise ValueError (
40+ 'Reduction must be one of: "mean", "sum", "none".' )
41+
3342 super ().__init__ ()
43+ self .alpha = alpha
3444 self .gamma = gamma
45+ self .ignore_index = ignore_index
46+ self .reduction = reduction
47+
3548 self .nll_loss = nn .NLLLoss (
3649 weight = alpha , reduction = 'none' , ignore_index = ignore_index )
3750
38- self .ignore_index = ignore_index
39-
40- if reduction in ('mean' , 'sum' , 'none' ):
41- self .reduction = reduction
42- else :
43- raise ValueError (
44- 'Reduction must be one of: "mean", "sum", "none".' )
51+ def __repr__ (self ):
52+ arg_keys = ['alpha' , 'gamma' , 'ignore_index' , 'reduction' ]
53+ arg_vals = [self .__dict__ [k ] for k in arg_keys ]
54+ arg_strs = [f'{ k } ={ v } ' for k , v in zip (arg_keys , arg_vals )]
55+ arg_str = ', ' .join (arg_strs )
56+ return f'{ type (self ).__name__ } ({ arg_str } )'
4557
4658 def forward (self , x : Tensor , y : Tensor ) -> Tensor :
4759 if x .ndim > 2 :
@@ -58,6 +70,7 @@ def forward(self, x: Tensor, y: Tensor) -> Tensor:
5870 x = x [unignored_mask ]
5971
6072 # compute weighted cross entropy term: -alpha * log(pt)
73+ # (alpha is already part of self.nll_loss)
6174 log_p = F .log_softmax (x , dim = - 1 )
6275 ce = self .nll_loss (log_p , y )
6376
@@ -80,15 +93,38 @@ def forward(self, x: Tensor, y: Tensor) -> Tensor:
8093 return loss
8194
8295
83- def focal_loss (alpha = None , gamma = 0. , reduction = 'mean' , ignore_index = - 100 ,
84- device = 'cpu' , dtype = torch .float32 ):
85- if not ((alpha is None ) or isinstance (alpha , torch .Tensor )):
86- alpha = torch .tensor (alpha , device = device , dtype = dtype )
87-
88- fl = FocalLoss (
89- alpha = alpha ,
90- gamma = gamma ,
91- reduction = reduction ,
92- ignore_index = ignore_index
93- )
94- return fl
96+ def focal_loss (alpha : Optional [Sequence ] = None ,
97+ gamma : float = 0. ,
98+ reduction : str = 'mean' ,
99+ ignore_index : int = - 100 ,
100+ device = 'cpu' ,
101+ dtype = torch .float32 ) -> FocalLoss :
102+ """Factory function for FocalLoss.
103+
104+ Args:
105+ alpha (Sequence, optional): Weights for each class. Will be converted
106+ to a Tensor if not None. Defaults to None.
107+ gamma (float, optional): A constant, as described in the paper.
108+ Defaults to 0.
109+ reduction (str, optional): 'mean', 'sum' or 'none'.
110+ Defaults to 'mean'.
111+ ignore_index (int, optional): class label to ignore.
112+ Defaults to -100.
113+ device (str, optional): Device to move alpha to. Defaults to 'cpu'.
114+ dtype (torch.dtype, optional): dtype to cast alpha to.
115+ Defaults to torch.float32.
116+
117+ Returns:
118+ A FocalLoss object
119+ """
120+ if alpha is not None :
121+ if not isinstance (alpha , Tensor ):
122+ alpha = torch .tensor (alpha )
123+ alpha = alpha .to (device = device , dtype = dtype )
124+
125+ fl = FocalLoss (
126+ alpha = alpha ,
127+ gamma = gamma ,
128+ reduction = reduction ,
129+ ignore_index = ignore_index )
130+ return fl
0 commit comments