@@ -1,3 +1,5 @@
+from typing import Optional, Sequence
+
 import torch
 from torch import Tensor
 from torch import nn
@@ -18,17 +20,20 @@ class FocalLoss(nn.Module):
1820 """
1921
2022 def __init__ (self ,
21- alpha : Tensor = None ,
23+ alpha : Optional [ Tensor ] = None ,
2224 gamma : float = 0. ,
2325 reduction : str = 'mean' ,
2426 ignore_index : int = - 100 ):
2527 """Constructor.
28+
2629 Args:
27- alpha (Tensor): Weights for each class.
28- gamma (float): A constant, as described in the paper.
30+ alpha (Tensor, optional): Weights for each class. Defaults to None.
31+ gamma (float, optional): A constant, as described in the paper.
32+ Defaults to 0.
2933 reduction (str, optional): 'mean', 'sum' or 'none'.
3034 Defaults to 'mean'.
3135 ignore_index (int, optional): class label to ignore.
36+ Defaults to -100.
3237 """
3338 if reduction not in ('mean' , 'sum' , 'none' ):
3439 raise ValueError (
@@ -65,6 +70,7 @@ def forward(self, x: Tensor, y: Tensor) -> Tensor:
         x = x[unignored_mask]

         # compute weighted cross entropy term: -alpha * log(pt)
+        # (alpha is already part of self.nll_loss)
         log_p = F.log_softmax(x, dim=-1)
         ce = self.nll_loss(log_p, y)

@@ -87,15 +93,38 @@ def forward(self, x: Tensor, y: Tensor) -> Tensor:
         return loss


-def focal_loss(alpha=None, gamma=0., reduction='mean', ignore_index=-100,
-               device='cpu', dtype=torch.float32):
-    if not ((alpha is None) or isinstance(alpha, torch.Tensor)):
-        alpha = torch.tensor(alpha, device=device, dtype=dtype)
-
-    fl = FocalLoss(
-        alpha=alpha,
-        gamma=gamma,
-        reduction=reduction,
-        ignore_index=ignore_index
-    )
-    return fl
+def focal_loss(alpha: Optional[Sequence] = None,
+               gamma: float = 0.,
+               reduction: str = 'mean',
+               ignore_index: int = -100,
+               device='cpu',
+               dtype=torch.float32) -> FocalLoss:
+    """Factory function for FocalLoss.
+
+    Args:
+        alpha (Sequence, optional): Weights for each class. Will be converted
+            to a Tensor if not None. Defaults to None.
+        gamma (float, optional): A constant, as described in the paper.
+            Defaults to 0.
+        reduction (str, optional): 'mean', 'sum' or 'none'.
+            Defaults to 'mean'.
+        ignore_index (int, optional): class label to ignore.
+            Defaults to -100.
+        device (str, optional): Device to move alpha to. Defaults to 'cpu'.
+        dtype (torch.dtype, optional): dtype to cast alpha to.
+            Defaults to torch.float32.
+
+    Returns:
+        A FocalLoss object
+    """
+    if alpha is not None:
+        if not isinstance(alpha, Tensor):
+            alpha = torch.tensor(alpha)
+        alpha = alpha.to(device=device, dtype=dtype)
+
+    fl = FocalLoss(
+        alpha=alpha,
+        gamma=gamma,
+        reduction=reduction,
+        ignore_index=ignore_index)
+    return fl
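
Usage sketch for the new factory (not part of the diff): a minimal example assuming this file is importable as a module named focal_loss; the class weights and inputs below are made up for illustration.

import torch
from focal_loss import focal_loss  # hypothetical module name for this file

# made-up per-class weights; gamma=2 is the value recommended in the paper
criterion = focal_loss(alpha=[0.25, 0.75], gamma=2., reduction='mean')

logits = torch.randn(8, 2, requires_grad=True)  # raw scores, shape (batch_size, C)
targets = torch.randint(0, 2, (8,))             # class labels, shape (batch_size,)
loss = criterion(logits, targets)
loss.backward()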