Skip to content

Commit 8307c48

Browse files
committed
fix: fix the imports in base loss
1 parent d32277b commit 8307c48

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

cellseg_models_pytorch/losses/weighted_base_loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
import torch.nn as nn
33

4-
from ..utils import filter2D, gaussian_kernel2d
4+
from cellseg_models_pytorch.utils.convolve import filter2D, gaussian_kernel2d
55

66

77
class WeightedBaseLoss(nn.Module):
@@ -13,7 +13,7 @@ def __init__(
1313
apply_mask: bool = False,
1414
class_weights: torch.Tensor = None,
1515
edge_weight: float = None,
16-
**kwargs
16+
**kwargs,
1717
) -> None:
1818
"""Init a base class for weighted cross entropy based losses.
1919
@@ -99,7 +99,7 @@ def apply_svls_to_target(
9999
num_classes: int,
100100
kernel_size: int = 5,
101101
sigma: int = 3,
102-
**kwargs
102+
**kwargs,
103103
) -> torch.Tensor:
104104
"""Apply spatially varying label smoothihng to target map.
105105

0 commit comments

Comments
 (0)