Skip to content

Commit f0d5355

Browse files
authored
Merge pull request #2 from AdeelH/ignore_index
Fix ignore_index bug
2 parents de74657 + 22a239b commit f0d5355

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

focal_loss.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def __init__(self,
3535
self.nll_loss = nn.NLLLoss(
3636
weight=alpha, reduction='none', ignore_index=ignore_index)
3737

38+
self.ignore_index = ignore_index
39+
3840
if reduction in ('mean', 'sum', 'none'):
3941
self.reduction = reduction
4042
else:
@@ -49,6 +51,12 @@ def forward(self, x: Tensor, y: Tensor) -> Tensor:
4951
# (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,)
5052
y = y.view(-1)
5153

54+
unignored_mask = y != self.ignore_index
55+
y = y[unignored_mask]
56+
if len(y) == 0:
57+
return 0.
58+
x = x[unignored_mask]
59+
5260
# compute weighted cross entropy term: -alpha * log(pt)
5361
log_p = F.log_softmax(x, dim=-1)
5462
ce = self.nll_loss(log_p, y)

0 commit comments

Comments
 (0)