11import torch
22import torch .nn as nn
33
4+ from cellseg_models_pytorch .modules import ConvBlock
5+
46__all__ = ["SegHead" ]
57
68
@@ -11,24 +13,43 @@ def __init__(
1113 out_channels : int ,
1214 kernel_size : int = 1 ,
1315 bias : bool = False ,
16+ excitation_channels : int = None ,
1417 ) -> None :
1518 """Segmentation head at the end of decoder branches.
1619
17- Parameters
18- ----------
19- in_channels : int
20+ Parameters:
21+ in_channels (int):
2022 Number of channels in the input tensor.
21- out_channels : int
23+ out_channels ( int):
2224 Number of channels in the output tensor.
23- kernel_size : int, default=1
25+ kernel_size ( int, default=1):
2426 Kernel size for the conv operation.
25- bias : bool, default=False
27+ bias ( bool, default=False):
2628 If True, add a bias term to the conv operation.
29+ excitation_channels (int, default=None):
30+ Number of channels in an optional excitation conv layer before the
31+ output head.
2732
2833 """
2934 super ().__init__ ()
3035 self .n_classes = out_channels
3136
37+ self .excite = None
38+ if excitation_channels is not None :
39+ self .excite = ConvBlock (
40+ name = "basic" ,
41+ in_channels = in_channels ,
42+ out_channels = excitation_channels ,
43+ short_skip = "basic" ,
44+ kernel_size = 3 ,
45+ normalization = None ,
46+ activation = "relu" ,
47+ convolution = "conv" ,
48+ preactivate = False ,
49+ bias = False ,
50+ )
51+ in_channels = self .excite .out_channels
52+
3253 if kernel_size != 1 :
3354 self .head = nn .Conv2d (
3455 in_channels ,
@@ -43,4 +64,6 @@ def __init__(
4364
4465 def forward (self , x : torch .Tensor ) -> torch .Tensor :
4566 """Forward pass of the segmentation head."""
67+ if self .excite is not None :
68+ x = self .excite (x )
4669 return self .head (x )
0 commit comments