
Commit 5fc95bb

feat: add new MultiTaskDecoder class to wrap the decoder into one module inside models
1 parent 608c61e commit 5fc95bb

3 files changed: 312 additions, 34 deletions

Lines changed: 278 additions & 0 deletions
@@ -0,0 +1,278 @@
from itertools import chain
from typing import Any, Dict, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from cellseg_models_pytorch.decoders.long_skips import StemSkip
from cellseg_models_pytorch.decoders.unet_decoder import UnetDecoder
from cellseg_models_pytorch.models.base._initialization import (
    initialize_decoder,
    initialize_head,
)
from cellseg_models_pytorch.models.base._seg_head import SegHead
from cellseg_models_pytorch.modules.misc_modules import StyleReshape

ALLOWED_HEADS = [
    "inst",
    "type",
    "sem",
    "cellpose",
    "omnipose",
    "stardist",
    "hovernet",
    "dist",
    "dcan",
    "dran",
]

__all__ = ["MultiTaskDecoder"]


class MultiTaskDecoder(nn.ModuleDict):
    def __init__(
        self,
        decoders: Tuple[str, ...],
        heads: Dict[str, Dict[str, int]],
        out_channels: Tuple[int, ...],
        enc_channels: Tuple[int, ...],
        enc_reductions: Tuple[int, ...],
        n_layers: Tuple[int, ...],
        n_blocks: Tuple[int, ...],
        stage_kws: Tuple[Dict[str, Any], ...],
        stem_skip_kws: Dict[str, Any] = None,
        long_skip: str = "unet",
        out_size: int = None,
        style_channels: int = None,
        head_excitation_channels: int = None,
    ) -> None:
        """Create a multi-task decoder.

        Parameters:
            decoders (Tuple[str, ...]):
                Tuple of decoder names. E.g. ("decoder1", "decoder2").
            heads (Dict[str, Dict[str, int]]):
                Dict containing the heads for each decoder. The inner dict contains the
                head name and the number of output channels. For example:
                {"decoder1": {"inst": 2, "sem": 5}, "decoder2": {"cellpose": 2}}.
            out_channels (Tuple[int, ...]):
                Tuple of output channels for each decoder stage. The length of the
                tuple should be equal to the number of enc_channels.
            enc_channels (Tuple[int, ...]):
                Tuple of encoder channels.
            enc_reductions (Tuple[int, ...]):
                Tuple of encoder reduction factors.
            n_layers (Tuple[int, ...]):
                Tuple of the number of conv layers in each decoder stage.
            n_blocks (Tuple[int, ...]):
                Tuple of the number of conv blocks in each decoder stage.
            stage_kws (Tuple[Dict[str, Any], ...]):
                Tuple of kwargs for each decoder stage. See UnetDecoderStage for info.
            stem_skip_kws (Dict[str, Any], default=None):
                Optional kwargs for the stem skip connection.
            long_skip (str, default="unet"):
                The long skip connection method to be used in the decoder.
            out_size (int, default=None):
                The output size of the model. If given, the output will be
                interpolated to this size.
            style_channels (int, default=None):
                The number of style channels for domain adaptation.
            head_excitation_channels (int, default=None):
                The number of excitation channels for the head. If None, no excitation
                is used. Excitation is a conv block before the head that widens the
                output channels to avoid a 'fight over features' (stardist).
        """
        super().__init__()
        self.out_size = out_size
        self._check_head_args(heads, decoders)
        self._check_decoder_args(decoders)
        self._check_depth(
            len(enc_channels),
            {
                "n_blocks": n_blocks,
                "n_layers": n_layers,
                "out_channels": out_channels,
                "enc_reductions": enc_reductions,
            },
        )

        # style
        self.make_style = None
        if style_channels is not None:
            self.make_style = StyleReshape(enc_channels[0], style_channels)

        # set decoders
        for decoder_name in decoders:
            decoder = UnetDecoder(
                enc_channels=enc_channels,
                enc_reductions=enc_reductions,
                out_channels=out_channels,
                style_channels=style_channels,
                long_skip=long_skip,
                n_conv_layers=n_layers,
                n_conv_blocks=n_blocks,
                stage_params=stage_kws,
            )
            self.add_module(f"{decoder_name}_decoder", decoder)

        # optional stem skip
        self.has_stem_skip = stem_skip_kws is not None
        if self.has_stem_skip:
            for decoder_name in decoders:
                stem_skip = StemSkip(out_channels=out_channels[-1], **stem_skip_kws)
                self.add_module(f"{decoder_name}_stem_skip", stem_skip)

        # set heads
        for decoder_name in heads.keys():
            for output_name, n_classes in heads[decoder_name].items():
                seg_head = SegHead(
                    in_channels=decoder.out_channels,
                    out_channels=n_classes,
                    kernel_size=1,
                    excitation_channels=head_excitation_channels,
                )
                self.add_module(f"{decoder_name}-{output_name}_head", seg_head)

    def forward_features(
        self, feats: List[torch.Tensor], style: torch.Tensor = None
    ) -> Dict[str, List[torch.Tensor]]:
        """Forward all the decoders and return multi-res feature lists per branch."""
        res = {}
        decoders = [k for k in self.keys() if "decoder" in k]

        for dec in decoders:
            featlist = self[dec](*feats, style=style)
            branch = "_".join(dec.split("_")[:-1])
            res[branch] = featlist

        return res

    def forward_heads(
        self, dec_feats: Dict[str, List[torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:
        """Forward pass all the seg heads."""
        res = {}
        heads = [k for k in self.keys() if "head" in k]
        for head in heads:
            branch_head = head.split("-")
            branch = branch_head[0]  # branch name
            head_name = "_".join(branch_head[1].split("_")[:-1])  # head name
            x = self[head](dec_feats[branch][-1])  # the last decoder stage feat map

            if self.out_size is not None:
                x = F.interpolate(
                    x, size=self.out_size, mode="bilinear", align_corners=False
                )

            res[f"{branch}-{head_name}"] = x

        return res

    def forward_style(self, feat: torch.Tensor) -> torch.Tensor:
        """Forward the style domain adaptation layer."""
        style = None
        if self.make_style is not None:
            style = self.make_style(feat)

        return style

    def forward_stem_skip(
        self, x: torch.Tensor, dec_feats: Dict[str, List[torch.Tensor]]
    ) -> Dict[str, List[torch.Tensor]]:
        """Forward the stem skip connection."""
        stems = [k for k in self.keys() if "stem_skip" in k]
        for stem in stems:
            branch = stem.split("_")[0]
            dec_feats[branch][-1] = self[stem](x, dec_feats[branch][-1])

        return dec_feats

    def forward(
        self, enc_feats: Tuple[torch.Tensor, ...], x_in: torch.Tensor = None
    ) -> Tuple[Dict[str, List[torch.Tensor]], Dict[str, torch.Tensor]]:
        """Forward pass the style, decoders and optional stem skip.

        Parameters:
            enc_feats (Tuple[torch.Tensor, ...]):
                Tuple containing the encoder feature tensors.
            x_in (torch.Tensor, default=None):
                Optional input image tensor for the stem skip connection.

        Returns:
            Tuple[Dict[str, List[torch.Tensor]], Dict[str, torch.Tensor]]:
                The decoder features per branch and the outputs of the seg heads.
        """
        style = self.forward_style(enc_feats[0])
        dec_feats = self.forward_features(enc_feats, style)

        # final input resolution skip connection
        if self.has_stem_skip and x_in is not None:
            dec_feats = self.forward_stem_skip(x_in, dec_feats)

        out = self.forward_heads(dec_feats)

        return dec_feats, out

    def initialize(self) -> None:
        """Initialize the decoders and segmentation heads."""
        for name, module in self.items():
            if "decoder" in name:
                initialize_decoder(module)
            if "head" in name:
                initialize_head(module)

    def _get_inner_keys(self, d: Dict[str, Dict[str, Any]]) -> List[str]:
        """Get the inner dict keys from a nested dict."""
        return list(chain.from_iterable(list(d[k].keys()) for k in d.keys()))

    def _flatten_inner_dicts(self, d: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
        """Get the inner dicts as one dict from a nested dict."""
        return dict(chain.from_iterable(list(d[k].items()) for k in d.keys()))

    def _check_string_arg(self, arg: str) -> None:
        """Check that the str arg does not contain '-', which is reserved for splitting."""
        if "-" in arg:
            raise ValueError(
                f"The dict key '{arg}' contains '-', which is not allowed. Use '_' instead."
            )

    def _check_decoder_args(self, decoders: Tuple[str, ...]) -> None:
        """Check the `decoders` arg."""
        if len(decoders) != len(set(decoders)):
            raise ValueError("The decoder names need to be unique.")

        for dec in decoders:
            self._check_string_arg(dec)

    def _check_head_args(
        self, heads: Dict[str, Dict[str, int]], decoders: Tuple[str, ...]
    ) -> None:
        """Check the `heads` arg."""
        for head in heads.keys():
            self._check_string_arg(head)

        for head in self._get_inner_keys(heads):
            if head not in ALLOWED_HEADS:
                raise ValueError(
                    f"Unknown head type: '{head}'. Allowed: {ALLOWED_HEADS}."
                )

        if not set(decoders) == set(heads.keys()):
            raise ValueError(
                "The decoder names need to match the keys of `heads` exactly. "
                f"Got decoders: {decoders} and heads: {list(heads.keys())}."
            )

    def _check_depth(self, depth: int, arrs: Dict[str, Tuple[Any, ...]]) -> None:
        """Check that the depth matches the lengths of the tuple args."""
        if not 3 <= depth <= 5:
            raise ValueError(
                f"max value for `depth` is 5, min value is 3. Got: {depth}"
            )

        for name, arr in arrs.items():
            if depth != len(arr):
                raise ValueError(
                    f"The length of `{name}` should be equal to arg `depth`: {depth}. "
                    f"For `{name}`, got: {arr}."
                )
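
For orientation, here is a minimal usage sketch of the new class; it is not part of the commit. The decoder and head names follow the docstring example, but the import path, channel counts, reduction factors, and tensor shapes are illustrative assumptions (the new file's path and a concrete encoder are not shown in this diff), so treat them as placeholders rather than working values.

import torch

# Assumed import location; the commit message only says the module lives "inside models".
from cellseg_models_pytorch.models import MultiTaskDecoder

# Hypothetical 4-stage encoder description (deepest stage first) for a 256x256 input.
enc_channels = (512, 256, 128, 64)
enc_reductions = (32, 16, 8, 4)

decoder = MultiTaskDecoder(
    decoders=("decoder1", "decoder2"),
    heads={"decoder1": {"inst": 2, "sem": 5}, "decoder2": {"cellpose": 2}},
    out_channels=(256, 128, 64, 32),  # one entry per decoder stage (== len(enc_channels))
    enc_channels=enc_channels,
    enc_reductions=enc_reductions,
    n_layers=(1, 1, 1, 1),
    n_blocks=(2, 2, 2, 2),
    stage_kws=None,  # assumption: fall back to UnetDecoder's stage defaults
    out_size=256,    # head outputs are bilinearly interpolated to 256x256
)
decoder.initialize()

# Dummy encoder features matching enc_channels / enc_reductions above.
enc_feats = tuple(
    torch.rand(1, c, 256 // r, 256 // r)
    for c, r in zip(enc_channels, enc_reductions)
)

dec_feats, out = decoder(enc_feats)
# `out` is keyed "decoder1-inst", "decoder1-sem", "decoder2-cellpose";
# `dec_feats` holds the multi-resolution decoder feature lists per branch.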

cellseg_models_pytorch/decoders/unet_decoder.py

Lines changed: 6 additions & 6 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -13,13 +13,13 @@ def __init__(
         self,
         enc_channels: Tuple[int, ...],
         enc_reductions: Tuple[int, ...],
-        out_channels: Tuple[int, ...] = (256, 128, 64, 32, 16),
+        out_channels: Tuple[int, ...],
         long_skip: Union[None, str, Tuple[str, ...]] = "unet",
         n_conv_layers: Union[None, int, Tuple[int, ...]] = 1,
-        n_transformers: Union[None, int, Tuple[int, ...]] = None,
         n_conv_blocks: Union[int, Tuple[Tuple[int, ...], ...]] = 2,
+        n_transformers: Union[None, int, Tuple[int, ...]] = None,
         n_transformer_blocks: Union[int, Tuple[Tuple[int], ...]] = 1,
-        stage_params: Optional[Tuple[Dict, ...]] = None,
+        stage_params: Tuple[Dict, ...] = None,
         style_channels: int = None,
         **kwargs,
     ) -> None:
@@ -41,7 +41,7 @@ def __init__(
             Number of channels at each encoder layer.
         enc_reductions : Tuple[int, ...]
             The reduction factor from the input image size at each encoder layer.
-        out_channels : Tuple[int, ...], default=(256, 128, 64, 32, 16)
+        out_channels : Tuple[int, ...]
             Number of channels at each decoder layer output.
         long_skip : Union[None, str, Tuple[str, ...]], default="unet"
             long skip method to be used. The argument can be given as a tuple, where
@@ -71,7 +71,7 @@ def __init__(
             value indicates the number of `SelfAttention`s inside a single
             `TranformerLayer` allowing different sized transformer blocks inside
             each transformer-layer in the decoder.
-        stage_params : Optional[Tuple[Dict, ...]], default=None
+        stage_params : Tuple[Dict, ...], default=None
             The keyword args for each of the distinct decoder stages. Incudes the
             parameters for the long skip connections, convolutional layers of the
             decoder and transformer layers itself. See the `DecoderStage`
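
The behavioural change here is that `out_channels` lost its `(256, 128, 64, 32, 16)` default and must now be passed explicitly (as `MultiTaskDecoder` does above). A minimal construction sketch follows; the channel and reduction values are illustrative assumptions, not taken from the commit.

from cellseg_models_pytorch.decoders.unet_decoder import UnetDecoder

# `out_channels` is now a required argument; all values below are placeholders.
decoder = UnetDecoder(
    enc_channels=(512, 256, 128, 64),
    enc_reductions=(32, 16, 8, 4),
    out_channels=(256, 128, 64, 32),  # one entry per decoder stage
    long_skip="unet",
    n_conv_layers=1,
    n_conv_blocks=2,
)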

cellseg_models_pytorch/decoders/unet_decoder_stage.py

Lines changed: 28 additions & 28 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -18,33 +18,33 @@ def __init__(
         skip_channels: Tuple[int, ...],
         long_skip: str = "unet",
         merge_policy: str = "sum",
-        skip_params: Optional[Dict[str, Any]] = None,
+        skip_params: Dict[str, Any] = None,
         upsampling: str = "fixed-unpool",
-        n_conv_layers: Optional[int] = 1,
-        style_channels: Optional[int] = None,
-        layer_residual: Optional[bool] = False,
-        n_conv_blocks: Optional[Tuple[int, ...]] = (2,),
-        short_skips: Optional[Tuple[str, ...]] = ("residual",),
-        expand_ratios: Optional[Tuple[float, float]] = ((1.0, 1.0),),
-        block_types: Optional[Tuple[Tuple[str, ...], ...]] = (("basic", "basic"),),
-        normalizations: Optional[Tuple[Tuple[str, ...], ...]] = (("bn", "bn"),),
-        activations: Optional[Tuple[Tuple[str, ...], ...]] = (("relu", "relu"),),
-        convolutions: Optional[Tuple[Tuple[str, ...], ...]] = (("conv", "conv"),),
-        attentions: Optional[Tuple[Tuple[str, ...], ...]] = ((None, "se"),),
-        preactivates: Optional[Tuple[Tuple[bool, ...], ...]] = ((False, False),),
-        preattends: Optional[Tuple[Tuple[bool, ...], ...]] = ((False, False),),
-        use_styles: Optional[Tuple[Tuple[bool, ...], ...]] = ((False, False),),
-        kernel_sizes: Optional[Tuple[Tuple[int, ...]]] = ((3, 3),),
-        groups: Optional[Tuple[Tuple[int, ...]]] = ((1, 1),),
-        biases: Optional[Tuple[Tuple[bool, ...]]] = ((False, False),),
-        n_transformers: Optional[int] = None,
-        n_transformer_blocks: Optional[Tuple[int, ...]] = (1,),
-        transformer_blocks: Optional[Tuple[Tuple[str, ...], ...]] = (("exact",),),
-        transformer_computations: Optional[Tuple[Tuple[str, ...], ...]] = (("basic",),),
-        transformer_biases: Optional[Tuple[Tuple[bool, ...], ...]] = ((False,),),
-        transformer_dropouts: Optional[Tuple[Tuple[float, ...], ...]] = ((0.0,),),
-        transformer_layer_scales: Optional[Tuple[Tuple[bool, ...], ...]] = ((False,),),
-        transformer_params: Optional[List[Dict[str, Any]]] = None,
+        n_conv_layers: int = 1,
+        style_channels: int = None,
+        layer_residual: bool = False,
+        n_conv_blocks: Tuple[int, ...] = (2,),
+        short_skips: Tuple[str, ...] = ("residual",),
+        expand_ratios: Tuple[float, float] = ((1.0, 1.0),),
+        block_types: Tuple[Tuple[str, ...], ...] = (("basic", "basic"),),
+        normalizations: Tuple[Tuple[str, ...], ...] = (("bn", "bn"),),
+        activations: Tuple[Tuple[str, ...], ...] = (("relu", "relu"),),
+        convolutions: Tuple[Tuple[str, ...], ...] = (("conv", "conv"),),
+        attentions: Tuple[Tuple[str, ...], ...] = ((None, "se"),),
+        preactivates: Tuple[Tuple[bool, ...], ...] = ((False, False),),
+        preattends: Tuple[Tuple[bool, ...], ...] = ((False, False),),
+        use_styles: Tuple[Tuple[bool, ...], ...] = ((False, False),),
+        kernel_sizes: Tuple[Tuple[int, ...]] = ((3, 3),),
+        groups: Tuple[Tuple[int, ...]] = ((1, 1),),
+        biases: Tuple[Tuple[bool, ...]] = ((False, False),),
+        n_transformers: int = None,
+        n_transformer_blocks: Tuple[int, ...] = (1,),
+        transformer_blocks: Tuple[Tuple[str, ...], ...] = (("exact",),),
+        transformer_computations: Tuple[Tuple[str, ...], ...] = (("basic",),),
+        transformer_biases: Tuple[Tuple[bool, ...], ...] = ((False,),),
+        transformer_dropouts: Tuple[Tuple[float, ...], ...] = ((0.0,),),
+        transformer_layer_scales: Tuple[Tuple[bool, ...], ...] = ((False,),),
+        transformer_params: List[Dict[str, Any]] = None,
         **kwargs,
     ) -> None:
         """Build a decoder stage.
@@ -73,7 +73,7 @@ def __init__(
             Allowed: "cross-attn", "unet", "unetpp", "unet3p", "unet3p-lite", None
         merge_policy : str, default="sum"
             The long skip merge policy. One of: "sum", "cat"
-        skip_params : Optional[Dict]
+        skip_params : Dict[str, Any], default=None
             Extra keyword arguments for the skip-connection module. These depend
             on the skip module. Refer to specific skip modules for more info.
         upsampling : str, default="fixed-unpool"
