Skip to content

Commit 70da669

Browse files
authored
Merge pull request #3 from AdeelH/repr
Add __repr__ + refactoring + docs + type hints
2 parents 0ce46b2 + d66a6b0 commit 70da669

File tree

3 files changed

+65
-22
lines changed

3 files changed

+65
-22
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11

22
*.pyc
3+
4+
.vscode/

focal_loss.py

Lines changed: 58 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional, Sequence
2+
13
import torch
24
from torch import Tensor
35
from torch import nn
@@ -18,30 +20,40 @@ class FocalLoss(nn.Module):
1820
"""
1921

2022
def __init__(self,
21-
alpha: Tensor = None,
23+
alpha: Optional[Tensor] = None,
2224
gamma: float = 0.,
2325
reduction: str = 'mean',
2426
ignore_index: int = -100):
2527
"""Constructor.
28+
2629
Args:
27-
alpha (Tensor): Weights for each class.
28-
gamma (float): A constant, as described in the paper.
30+
alpha (Tensor, optional): Weights for each class. Defaults to None.
31+
gamma (float, optional): A constant, as described in the paper.
32+
Defaults to 0.
2933
reduction (str, optional): 'mean', 'sum' or 'none'.
3034
Defaults to 'mean'.
3135
ignore_index (int, optional): class label to ignore.
36+
Defaults to -100.
3237
"""
38+
if reduction not in ('mean', 'sum', 'none'):
39+
raise ValueError(
40+
'Reduction must be one of: "mean", "sum", "none".')
41+
3342
super().__init__()
43+
self.alpha = alpha
3444
self.gamma = gamma
45+
self.ignore_index = ignore_index
46+
self.reduction = reduction
47+
3548
self.nll_loss = nn.NLLLoss(
3649
weight=alpha, reduction='none', ignore_index=ignore_index)
3750

38-
self.ignore_index = ignore_index
39-
40-
if reduction in ('mean', 'sum', 'none'):
41-
self.reduction = reduction
42-
else:
43-
raise ValueError(
44-
'Reduction must be one of: "mean", "sum", "none".')
51+
def __repr__(self):
52+
arg_keys = ['alpha', 'gamma', 'ignore_index', 'reduction']
53+
arg_vals = [self.__dict__[k] for k in arg_keys]
54+
arg_strs = [f'{k}={v}' for k, v in zip(arg_keys, arg_vals)]
55+
arg_str = ', '.join(arg_strs)
56+
return f'{type(self).__name__}({arg_str})'
4557

4658
def forward(self, x: Tensor, y: Tensor) -> Tensor:
4759
if x.ndim > 2:
@@ -58,6 +70,7 @@ def forward(self, x: Tensor, y: Tensor) -> Tensor:
5870
x = x[unignored_mask]
5971

6072
# compute weighted cross entropy term: -alpha * log(pt)
73+
# (alpha is already part of self.nll_loss)
6174
log_p = F.log_softmax(x, dim=-1)
6275
ce = self.nll_loss(log_p, y)
6376

@@ -80,15 +93,38 @@ def forward(self, x: Tensor, y: Tensor) -> Tensor:
8093
return loss
8194

8295

83-
def focal_loss(alpha=None, gamma=0., reduction='mean', ignore_index=-100,
84-
device='cpu', dtype=torch.float32):
85-
if not ((alpha is None) or isinstance(alpha, torch.Tensor)):
86-
alpha = torch.tensor(alpha, device=device, dtype=dtype)
87-
88-
fl = FocalLoss(
89-
alpha=alpha,
90-
gamma=gamma,
91-
reduction=reduction,
92-
ignore_index=ignore_index
93-
)
94-
return fl
96+
def focal_loss(alpha: Optional[Sequence] = None,
97+
gamma: float = 0.,
98+
reduction: str = 'mean',
99+
ignore_index: int = -100,
100+
device='cpu',
101+
dtype=torch.float32) -> FocalLoss:
102+
"""Factory function for FocalLoss.
103+
104+
Args:
105+
alpha (Sequence, optional): Weights for each class. Will be converted
106+
to a Tensor if not None. Defaults to None.
107+
gamma (float, optional): A constant, as described in the paper.
108+
Defaults to 0.
109+
reduction (str, optional): 'mean', 'sum' or 'none'.
110+
Defaults to 'mean'.
111+
ignore_index (int, optional): class label to ignore.
112+
Defaults to -100.
113+
device (str, optional): Device to move alpha to. Defaults to 'cpu'.
114+
dtype (torch.dtype, optional): dtype to cast alpha to.
115+
Defaults to torch.float32.
116+
117+
Returns:
118+
A FocalLoss object
119+
"""
120+
if alpha is not None:
121+
if not isinstance(alpha, Tensor):
122+
alpha = torch.tensor(alpha)
123+
alpha = alpha.to(device=device, dtype=dtype)
124+
125+
fl = FocalLoss(
126+
alpha=alpha,
127+
gamma=gamma,
128+
reduction=reduction,
129+
ignore_index=ignore_index)
130+
return fl

setup.cfg

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[yapf]
2+
based_on_style = pep8
3+
DEDENT_CLOSING_BRACKETS = false
4+
SPLIT_COMPLEX_COMPREHENSION = true
5+
COALESCE_BRACKETS = true

0 commit comments

Comments
 (0)