-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcriterion.py
More file actions
20 lines (17 loc) · 805 Bytes
/
criterion.py
File metadata and controls
20 lines (17 loc) · 805 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch.nn as nn
import torch.nn.functional as F
class CELoss(nn.Module):
    """Cross-entropy criterion that selects its loss from the task configuration.

    Dispatch performed by ``forward``:

    * ``num_class == 2`` and not ``multi_label``: binary cross-entropy with
      logits over every entry.
    * ``num_class > 2`` and not ``multi_label``: categorical cross-entropy
      with targets cast to integer class indices.
    * otherwise (``multi_label`` is truthy, or ``num_class < 2``): binary
      cross-entropy with logits restricted to entries whose target compares
      equal to itself.
    """

    def __init__(self, num_class, multi_label):
        """Store the task configuration and print the multi-label setting.

        Args:
            num_class: Number of classes; compared against 2 in ``forward``.
            multi_label: Truthy to always use the masked binary loss.
        """
        super().__init__()
        self.num_class = num_class
        self.multi_label = multi_label
        print(f'[INFO] Using multi_label: {self.multi_label}')

    def forward(self, logits, targets):
        """Return the scalar loss for ``logits`` against ``targets``.

        Args:
            logits: Raw, unnormalised model outputs.
            targets: Ground-truth labels. For the binary and masked branches
                they are cast to float and must match the shape of
                ``logits``; for the multi-class branch they are cast to long.
                NOTE(review): unlabeled entries are presumably encoded as
                NaN -- confirm against the dataset loader.

        Returns:
            The mean-reduced loss. In the masked branch, a zero loss that is
            still connected to ``logits`` when no entry is labeled.
        """
        if not self.multi_label:
            if self.num_class == 2:
                return F.binary_cross_entropy_with_logits(logits, targets.float())
            if self.num_class > 2:
                return F.cross_entropy(logits, targets.long())

        # NaN is the only value that is not equal to itself, so this mask is
        # False exactly where the target is NaN and True everywhere else.
        is_labeled = targets == targets
        labeled_logits = logits[is_labeled]

        if labeled_logits.numel() == 0:
            # Mean-reducing over an empty selection would produce NaN. The sum
            # of an empty tensor is 0 and keeps the autograd graph intact, so
            # a fully unlabeled batch contributes nothing instead of NaN.
            return labeled_logits.sum()

        return F.binary_cross_entropy_with_logits(
            labeled_logits, targets[is_labeled].float()
        )