|
10 | 10 | # import ohem_cpp |
11 | 11 | # class OhemCELoss(nn.Module): |
12 | 12 | # |
13 | | -# def __init__(self, thresh, ignore_lb=255): |
| 13 | +# def __init__(self, thresh, lb_ignore=255): |
14 | 14 | # super(OhemCELoss, self).__init__() |
15 | 15 | # self.score_thresh = thresh |
16 | | -# self.ignore_lb = ignore_lb |
17 | | -# self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='mean') |
| 16 | +# self.lb_ignore = lb_ignore |
| 17 | +# self.criteria = nn.CrossEntropyLoss(ignore_index=lb_ignore, reduction='mean') |
18 | 18 | # |
19 | 19 | # def forward(self, logits, labels): |
20 | | -# n_min = labels[labels != self.ignore_lb].numel() // 16 |
| 20 | +# n_min = labels[labels != self.lb_ignore].numel() // 16 |
21 | 21 | # labels = ohem_cpp.score_ohem_label( |
22 | | -# logits, labels, self.ignore_lb, self.score_thresh, n_min).detach() |
| 22 | +# logits, labels, self.lb_ignore, self.score_thresh, n_min).detach() |
23 | 23 | # loss = self.criteria(logits, labels) |
24 | 24 | # return loss |
25 | 25 |
|
26 | 26 |
|
27 | 27 | class OhemCELoss(nn.Module): |
28 | 28 |
|
29 | | - def __init__(self, thresh, ignore_lb=255): |
| 29 | + def __init__(self, thresh, lb_ignore=255): |
30 | 30 | super(OhemCELoss, self).__init__() |
31 | 31 | self.thresh = -torch.log(torch.tensor(thresh, requires_grad=False, dtype=torch.float)).cuda() |
32 | | - self.ignore_lb = ignore_lb |
33 | | - self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none') |
| 32 | + self.lb_ignore = lb_ignore |
| 33 | + self.criteria = nn.CrossEntropyLoss(ignore_index=lb_ignore, reduction='none') |
34 | 34 |
|
35 | 35 | def forward(self, logits, labels): |
36 | | - n_min = labels[labels != self.ignore_lb].numel() // 16 |
| 36 | + n_min = labels[labels != self.lb_ignore].numel() // 16 |
37 | 37 | loss = self.criteria(logits, labels).view(-1) |
38 | 38 | loss_hard = loss[loss > self.thresh] |
39 | 39 | if loss_hard.numel() < n_min: |
|
0 commit comments