Skip to content

Commit cd12f87

Browse files
committed
fix: move the excitation convs from stardist to SegHead
1 parent 1cda9ca commit cd12f87

File tree

1 file changed

+29
-6
lines changed

1 file changed

+29
-6
lines changed

cellseg_models_pytorch/models/base/_seg_head.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import torch
22
import torch.nn as nn
33

4+
from cellseg_models_pytorch.modules import ConvBlock
5+
46
__all__ = ["SegHead"]
57

68

@@ -11,24 +13,43 @@ def __init__(
1113
out_channels: int,
1214
kernel_size: int = 1,
1315
bias: bool = False,
16+
excitation_channels: int = None,
1417
) -> None:
1518
"""Segmentation head at the end of decoder branches.
1619
17-
Parameters
18-
----------
19-
in_channels : int
20+
Parameters:
21+
in_channels (int):
2022
Number of channels in the input tensor.
21-
out_channels : int
23+
out_channels (int):
2224
Number of channels in the output tensor.
23-
kernel_size : int, default=1
25+
kernel_size (int, default=1):
2426
Kernel size for the conv operation.
25-
bias : bool, default=False
27+
bias (bool, default=False):
2628
If True, add a bias term to the conv operation.
29+
excitation_channels (int, default=None):
30+
Number of channels in an optional excitation conv layer before the
31+
output head.
2732
2833
"""
2934
super().__init__()
3035
self.n_classes = out_channels
3136

37+
self.excite = None
38+
if excitation_channels is not None:
39+
self.excite = ConvBlock(
40+
name="basic",
41+
in_channels=in_channels,
42+
out_channels=excitation_channels,
43+
short_skip="basic",
44+
kernel_size=3,
45+
normalization=None,
46+
activation="relu",
47+
convolution="conv",
48+
preactivate=False,
49+
bias=False,
50+
)
51+
in_channels = self.excite.out_channels
52+
3253
if kernel_size != 1:
3354
self.head = nn.Conv2d(
3455
in_channels,
@@ -43,4 +64,6 @@ def __init__(
4364

4465
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the optional excitation conv, then the output head.

    Parameters:
        x (torch.Tensor):
            Input feature map from the decoder branch.

    Returns:
        torch.Tensor: The segmentation head output.
    """
    feats = x if self.excite is None else self.excite(x)
    return self.head(feats)

0 commit comments

Comments
 (0)