Skip to content

Commit 95089c2

Browse files
committed
fix: uodate the criterions and losses to use F.one_hot
1 parent a5d681b commit 95089c2

File tree

14 files changed

+317
-363
lines changed

14 files changed

+317
-363
lines changed

cellseg_models_pytorch/losses/criterions/bce.py

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,24 @@ def __init__(
1515
apply_mask: bool = False,
1616
edge_weight: float = None,
1717
class_weights: torch.Tensor = None,
18-
**kwargs
18+
**kwargs,
1919
) -> None:
2020
"""Binary cross entropy loss with weighting and other tricks.
2121
2222
Parameters
23-
----------
24-
apply_sd : bool, default=False
25-
If True, Spectral decoupling regularization will be applied to the
26-
loss matrix.
27-
apply_ls : bool, default=False
28-
If True, Label smoothing will be applied to the target.
29-
apply_svls : bool, default=False
30-
If True, spatially varying label smoothing will be applied to the target
31-
apply_mask : bool, default=False
32-
If True, a mask will be applied to the loss matrix. Mask shape: (B, H, W)
33-
edge_weight : float, default=None
34-
Weight that is added to object borders.
35-
class_weights : torch.Tensor, default=None
36-
Class weights. A tensor of shape (n_classes,).
23+
apply_sd (bool, default=False):
24+
If True, Spectral decoupling regularization will be applied to the
25+
loss matrix.
26+
apply_ls (bool, default=False):
27+
If True, Label smoothing will be applied to the target.
28+
apply_svls (bool, default=False):
29+
If True, spatially varying label smoothing will be applied to the target
30+
apply_mask (bool, default=False):
31+
If True, a mask will be applied to the loss matrix. Mask shape: (B, H, W)
32+
edge_weight (float, default=None):
33+
Weight that is added to object borders.
34+
class_weights (torch.Tensor, default=None):
35+
Class weights. A tensor of shape (n_classes,).
3736
"""
3837
super().__init__(
3938
apply_sd, apply_ls, apply_svls, apply_mask, class_weights, edge_weight
@@ -46,23 +45,21 @@ def forward(
4645
target: torch.Tensor,
4746
target_weight: torch.Tensor = None,
4847
mask: torch.Tensor = None,
49-
**kwargs
48+
**kwargs,
5049
) -> torch.Tensor:
5150
"""Compute binary cross entropy loss.
5251
53-
Parameters
54-
----------
55-
yhat : torch.Tensor
52+
Parameters:
53+
yhat (torch.Tensor):
5654
The prediction map. Shape (B, C, H, W).
57-
target : torch.Tensor
55+
target (torch.Tensor):
5856
the ground truth annotations. Shape (B, H, W).
59-
target_weight : torch.Tensor, default=None
57+
target_weight (torch.Tensor, default=None):
6058
The edge weight map. Shape (B, H, W).
61-
mask : torch.Tensor, default=None
59+
mask (torch.Tensor, default=None):
6260
The mask map. Shape (B, H, W).
6361
64-
Returns
65-
-------
62+
Returns:
6663
torch.Tensor:
6764
Computed BCE loss (scalar).
6865
"""

cellseg_models_pytorch/losses/criterions/ce.py

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import torch
22
import torch.nn.functional as F
33

4-
from ...utils import tensor_one_hot
54
from ..weighted_base_loss import WeightedBaseLoss
65

76
__all__ = ["CELoss"]
@@ -20,21 +19,20 @@ def __init__(
2019
) -> None:
2120
"""Cross-Entropy loss with weighting.
2221
23-
Parameters
24-
----------
25-
apply_sd : bool, default=False
26-
If True, Spectral decoupling regularization will be applied to the
27-
loss matrix.
28-
apply_ls : bool, default=False
29-
If True, Label smoothing will be applied to the target.
30-
apply_svls : bool, default=False
31-
If True, spatially varying label smoothing will be applied to the target
32-
apply_mask : bool, default=False
33-
If True, a mask will be applied to the loss matrix. Mask shape: (B, H, W)
34-
edge_weight : float, default=None
35-
Weight that is added to object borders.
36-
class_weights : torch.Tensor, default=None
37-
Class weights. A tensor of shape (n_classes,).
22+
Parameters:
23+
apply_sd (bool, default=False):
24+
If True, Spectral decoupling regularization will be applied to the
25+
loss matrix.
26+
apply_ls (bool, default=False):
27+
If True, Label smoothing will be applied to the target.
28+
apply_svls (bool, default=False):
29+
If True, spatially varying label smoothing will be applied to the target
30+
apply_mask (bool, default=False):
31+
If True, a mask will be applied to the loss matrix. Mask shape: (B, H, W)
32+
edge_weight (float, default=None):
33+
Weight that is added to object borders.
34+
class_weights (torch.Tensor, default=None):
35+
Class weights. A tensor of shape (n_classes,).
3836
"""
3937
super().__init__(
4038
apply_sd, apply_ls, apply_svls, apply_mask, class_weights, edge_weight
@@ -51,35 +49,33 @@ def forward(
5149
) -> torch.Tensor:
5250
"""Compute the cross entropy loss.
5351
54-
Parameters
55-
----------
56-
yhat : torch.Tensor
52+
Parameters:
53+
yhat (torch.Tensor):
5754
The prediction map. Shape (B, C, H, W).
58-
target : torch.Tensor
55+
target (torch.Tensor):
5956
the ground truth annotations. Shape (B, H, W).
60-
target_weight : torch.Tensor, default=None
57+
target_weight (torch.Tensor, default=None):
6158
The edge weight map. Shape (B, H, W).
62-
mask : torch.Tensor, default=None
59+
mask (torch.Tensor, default=None):
6360
The mask map. Shape (B, H, W).
6461
65-
Returns
66-
-------
62+
Returns:
6763
torch.Tensor:
6864
Computed CE loss (scalar).
6965
"""
7066
input_soft = F.softmax(yhat, dim=1) + self.eps # (B, C, H, W)
71-
num_classes = yhat.shape[1]
72-
target_one_hot = tensor_one_hot(target, num_classes) # (B, C, H, W)
67+
n_classes = yhat.shape[1]
68+
target_one_hot = F.one_hot(target.long(), n_classes).permute(0, 3, 1, 2)
7369
assert target_one_hot.shape == yhat.shape
7470

7571
if self.apply_svls:
7672
target_one_hot = self.apply_svls_to_target(
77-
target_one_hot, num_classes, **kwargs
73+
target_one_hot, n_classes, **kwargs
7874
)
7975

8076
if self.apply_ls:
8177
target_one_hot = self.apply_ls_to_target(
82-
target_one_hot, num_classes, **kwargs
78+
target_one_hot, n_classes, **kwargs
8379
)
8480

8581
loss = -torch.sum(target_one_hot * torch.log(input_soft), dim=1) # (B, H, W)

cellseg_models_pytorch/losses/criterions/dice.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import torch
22
import torch.nn.functional as F
33

4-
from cellseg_models_pytorch.utils import tensor_one_hot
5-
64
from ..weighted_base_loss import WeightedBaseLoss
75

86
__all__ = ["DiceLoss"]
@@ -55,34 +53,32 @@ def forward(
5553
"""Compute the DICE coefficient.
5654
5755
Parameters
58-
----------
59-
yhat : torch.Tensor
56+
yhat (torch.Tensor):
6057
The prediction map. Shape (B, C, H, W).
61-
target : torch.Tensor
58+
target (torch.Tensor):
6259
the ground truth annotations. Shape (B, H, W).
63-
target_weight : torch.Tensor, default=None
60+
target_weight (torch.Tensor, default=None):
6461
The edge weight map. Shape (B, H, W).
65-
mask : torch.Tensor, default=None
62+
mask (torch.Tensor, default=None):
6663
The mask map. Shape (B, H, W).
6764
68-
Returns
69-
-------
65+
Returns:
7066
torch.Tensor:
7167
Computed DICE loss (scalar).
7268
"""
7369
yhat_soft = F.softmax(yhat, dim=1)
74-
num_classes = yhat.shape[1]
75-
target_one_hot = tensor_one_hot(target, n_classes=num_classes)
70+
n_classes = yhat.shape[1]
71+
target_one_hot = F.one_hot(target.long(), n_classes).permute(0, 3, 1, 2)
7672
assert target_one_hot.shape == yhat.shape
7773

7874
if self.apply_svls:
7975
target_one_hot = self.apply_svls_to_target(
80-
target_one_hot, num_classes, **kwargs
76+
target_one_hot, n_classes, **kwargs
8177
)
8278

8379
if self.apply_ls:
8480
target_one_hot = self.apply_ls_to_target(
85-
target_one_hot, num_classes, **kwargs
81+
target_one_hot, n_classes, **kwargs
8682
)
8783

8884
intersection = torch.sum(yhat_soft * target_one_hot, 1)

cellseg_models_pytorch/losses/criterions/focal.py

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import torch
22
import torch.nn.functional as F
33

4-
from ...utils import tensor_one_hot
54
from ..weighted_base_loss import WeightedBaseLoss
65

6+
__all__ = ["FocalLoss"]
7+
78

89
class FocalLoss(WeightedBaseLoss):
910
def __init__(
@@ -16,7 +17,7 @@ def __init__(
1617
apply_mask: bool = False,
1718
edge_weight: float = None,
1819
class_weights: torch.Tensor = None,
19-
**kwargs
20+
**kwargs,
2021
) -> None:
2122
"""Focal loss.
2223
@@ -25,25 +26,24 @@ def __init__(
2526
Optionally applies, label smoothing, spatially varying label smoothing or
2627
weights at the object edges or class weights to the loss.
2728
28-
Parameters
29-
----------
30-
alpha : float, default=0.5
31-
Weight factor b/w [0,1].
32-
gamma : float, default=2.0
33-
Focusing factor.
34-
apply_sd : bool, default=False
35-
If True, Spectral decoupling regularization will be applied to the
36-
loss matrix.
37-
apply_ls : bool, default=False
38-
If True, Label smoothing will be applied to the target.
39-
apply_svls : bool, default=False
40-
If True, spatially varying label smoothing will be applied to the target
41-
apply_mask : bool, default=False
42-
If True, a mask will be applied to the loss matrix. Mask shape: (B, H, W)
43-
edge_weight : float, default=none
44-
Weight that is added to object borders.
45-
class_weights : torch.Tensor, default=None
46-
Class weights. A tensor of shape (n_classes,).
29+
Parameters:
30+
alpha (float, default=0.5):
31+
Weight factor b/w [0,1].
32+
gamma (float, default=2.0):
33+
Focusing factor.
34+
apply_sd (bool, default=False):
35+
If True, Spectral decoupling regularization will be applied to the
36+
loss matrix.
37+
apply_ls (bool, default=False):
38+
If True, Label smoothing will be applied to the target.
39+
apply_svls (bool, default=False):
40+
If True, spatially varying label smoothing will be applied to the target
41+
apply_mask (bool, default=False):
42+
If True, a mask will be applied to the loss matrix. Mask shape: (B, H, W)
43+
edge_weight (float, default=none):
44+
Weight that is added to object borders.
45+
class_weights (torch.Tensor, default=None):
46+
Class weights. A tensor of shape (n_classes,).
4747
"""
4848
super().__init__(
4949
apply_sd, apply_ls, apply_svls, apply_mask, class_weights, edge_weight
@@ -58,39 +58,37 @@ def forward(
5858
target: torch.Tensor,
5959
target_weight: torch.Tensor = None,
6060
mask: torch.Tensor = None,
61-
**kwargs
61+
**kwargs,
6262
) -> torch.Tensor:
6363
"""Compute the focal loss.
6464
65-
Parameters
66-
----------
67-
yhat : torch.Tensor
65+
Parameters:
66+
yhat (torch.Tensor):
6867
The prediction map. Shape (B, C, H, W).
69-
target : torch.Tensor
68+
target (torch.Tensor):
7069
the ground truth annotations. Shape (B, H, W).
71-
target_weight : torch.Tensor, default=None
70+
target_weight (torch.Tensor, default=None):
7271
The edge weight map. Shape (B, H, W).
73-
mask : torch.Tensor, default=None
72+
mask (torch.Tensor, default=None):
7473
The mask map. Shape (B, H, W).
7574
76-
Returns
77-
-------
75+
Returns:
7876
torch.Tensor:
7977
Computed Focal loss (scalar).
8078
"""
8179
input_soft = F.softmax(yhat, dim=1) + self.eps # (B, C, H, W)
82-
num_classes = yhat.shape[1]
83-
target_one_hot = tensor_one_hot(target, num_classes) # (B, C, H, W)
80+
n_classes = yhat.shape[1]
81+
target_one_hot = F.one_hot(target.long(), n_classes).permute(0, 3, 1, 2)
8482
assert target_one_hot.shape == yhat.shape
8583

8684
if self.apply_svls:
8785
target_one_hot = self.apply_svls_to_target(
88-
target_one_hot, num_classes, **kwargs
86+
target_one_hot, n_classes, **kwargs
8987
)
9088

9189
if self.apply_ls:
9290
target_one_hot = self.apply_ls_to_target(
93-
target_one_hot, num_classes, **kwargs
91+
target_one_hot, n_classes, **kwargs
9492
)
9593

9694
weight = (1.0 - input_soft) ** self.gamma

0 commit comments

Comments
 (0)