Commit d66a6b0

refactoring

1 parent 7788b98 commit d66a6b0

File tree

2 files changed: 46 additions, 15 deletions

.gitignore
Lines changed: 2 additions & 0 deletions

@@ -1,2 +1,4 @@
 
 *.pyc
+
+.vscode/

focal_loss.py
Lines changed: 44 additions & 15 deletions

@@ -1,3 +1,5 @@
+from typing import Optional, Sequence
+
 import torch
 from torch import Tensor
 from torch import nn
@@ -18,17 +20,20 @@ class FocalLoss(nn.Module):
     """
 
     def __init__(self,
-                 alpha: Tensor = None,
+                 alpha: Optional[Tensor] = None,
                  gamma: float = 0.,
                  reduction: str = 'mean',
                  ignore_index: int = -100):
         """Constructor.
+
         Args:
-            alpha (Tensor): Weights for each class.
-            gamma (float): A constant, as described in the paper.
+            alpha (Tensor, optional): Weights for each class. Defaults to None.
+            gamma (float, optional): A constant, as described in the paper.
+                Defaults to 0.
             reduction (str, optional): 'mean', 'sum' or 'none'.
                 Defaults to 'mean'.
             ignore_index (int, optional): class label to ignore.
+                Defaults to -100.
         """
         if reduction not in ('mean', 'sum', 'none'):
             raise ValueError(
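A quick sketch of the constructor contract after this hunk, assuming FocalLoss is imported from this module (the weight values below are made up for illustration):

import torch
from focal_loss import FocalLoss

# alpha=None is now an explicitly typed, valid default: no class weighting.
unweighted = FocalLoss(gamma=2.)
# Or pass a per-class weight Tensor, one entry per class.
weighted = FocalLoss(alpha=torch.tensor([0.25, 0.75]),
                     gamma=2.,
                     reduction='mean',
                     ignore_index=-100)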
@@ -65,6 +70,7 @@ def forward(self, x: Tensor, y: Tensor) -> Tensor:
         x = x[unignored_mask]
 
         # compute weighted cross entropy term: -alpha * log(pt)
+        # (alpha is already part of self.nll_loss)
         log_p = F.log_softmax(x, dim=-1)
         ce = self.nll_loss(log_p, y)
 
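The new comment is worth unpacking: nn.NLLLoss multiplies each sample's -log(pt) by weight[class] when a weight vector is given, so if the constructor builds self.nll_loss as nn.NLLLoss(weight=alpha, ...) (as the comment implies; that line is not shown in this hunk), forward() needs no separate alpha factor. A minimal check of that equivalence:

import torch
import torch.nn.functional as F
from torch import nn

x = torch.randn(4, 3)                  # logits: 4 samples, 3 classes
y = torch.tensor([0, 2, 1, 2])         # target class indices
alpha = torch.tensor([0.2, 0.3, 0.5])  # per-class weights

log_p = F.log_softmax(x, dim=-1)
ce = nn.NLLLoss(weight=alpha, reduction='none')(log_p, y)

# ce already equals -alpha[y] * log(pt); no extra alpha multiplication.
manual = -alpha[y] * log_p[torch.arange(4), y]
assert torch.allclose(ce, manual)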

@@ -87,15 +93,38 @@ def forward(self, x: Tensor, y: Tensor) -> Tensor:
         return loss
 
 
-def focal_loss(alpha=None, gamma=0., reduction='mean', ignore_index=-100,
-               device='cpu', dtype=torch.float32):
-    if not ((alpha is None) or isinstance(alpha, torch.Tensor)):
-        alpha = torch.tensor(alpha, device=device, dtype=dtype)
-
-    fl = FocalLoss(
-        alpha=alpha,
-        gamma=gamma,
-        reduction=reduction,
-        ignore_index=ignore_index
-    )
-    return fl
+def focal_loss(alpha: Optional[Sequence] = None,
+               gamma: float = 0.,
+               reduction: str = 'mean',
+               ignore_index: int = -100,
+               device='cpu',
+               dtype=torch.float32) -> FocalLoss:
+    """Factory function for FocalLoss.
+
+    Args:
+        alpha (Sequence, optional): Weights for each class. Will be converted
+            to a Tensor if not None. Defaults to None.
+        gamma (float, optional): A constant, as described in the paper.
+            Defaults to 0.
+        reduction (str, optional): 'mean', 'sum' or 'none'.
+            Defaults to 'mean'.
+        ignore_index (int, optional): class label to ignore.
+            Defaults to -100.
+        device (str, optional): Device to move alpha to. Defaults to 'cpu'.
+        dtype (torch.dtype, optional): dtype to cast alpha to.
+            Defaults to torch.float32.
+
+    Returns:
+        A FocalLoss object
+    """
+    if alpha is not None:
+        if not isinstance(alpha, Tensor):
+            alpha = torch.tensor(alpha)
+        alpha = alpha.to(device=device, dtype=dtype)
+
+    fl = FocalLoss(
+        alpha=alpha,
+        gamma=gamma,
+        reduction=reduction,
+        ignore_index=ignore_index)
+    return fl
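A usage sketch for the refactored factory (focal_loss and FocalLoss come from this file; the tensors below are made up for illustration):

import torch
from focal_loss import focal_loss

# A plain Python sequence is now accepted for alpha and converted internally.
criterion = focal_loss(alpha=[0.25, 0.5, 0.25], gamma=2., device='cpu',
                       dtype=torch.float32)

logits = torch.randn(8, 3, requires_grad=True)  # (batch, num_classes)
targets = torch.randint(0, 3, (8,))             # class indices
loss = criterion(logits, targets)               # calls FocalLoss.forward
loss.backward()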
