Skip to content

Commit 281036a

Browse files
committed
fix: adjust the model classes to use MultiTaskDecoder
1 parent cd12f87 commit 281036a

File tree

9 files changed

+749
-1251
lines changed

9 files changed

+749
-1251
lines changed
Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +0,0 @@
1-
from .unet_decoder import UnetDecoder
2-
from .unet_decoder_stage import UnetDecoderStage
3-
4-
__all__ = ["UnetDecoderStage", "UnetDecoder"]

cellseg_models_pytorch/models/__init__.py

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import torch.nn as nn
22

3-
from .base._multitask_unet import MultiTaskUnet
43
from .cellpose.cellpose import (
54
CellPoseUnet,
65
cellpose_base,
@@ -53,84 +52,98 @@
5352

5453

5554
def get_model(
56-
name: str, type: str, ntypes: int = None, ntissues: int = None, **kwargs
55+
name: str,
56+
type: str,
57+
n_type_classes: int = None,
58+
n_sem_classes: int = None,
59+
**kwargs,
5760
) -> nn.Module:
5861
"""Get the corect model at hand given name and type.
5962
60-
Parameters
61-
----------
62-
name : str
63+
Parameters:
64+
name (str):
6365
Name of the model.
64-
type : str
66+
type (str):
6567
Type of the model. One of "base", "plus", "small", "small_plus".
66-
ntypes : int
68+
n_type_classes (int):
6769
Number of cell types to segment.
68-
ntissues : int
70+
n_sem_classes (int):
6971
Number of tissue types to segment.
70-
**kwargs : dict
72+
**kwargs
7173
Additional keyword arguments.
7274
73-
Returns
74-
-------
75+
Returns:
7576
nn.Module: The specified model.
7677
"""
7778
if name == "stardist":
7879
if type == "base":
7980
model = MODEL_LOOKUP["stardist_base_multiclass"](
80-
type_classes=ntypes, **kwargs
81+
n_type_classes=n_type_classes, **kwargs
8182
)
8283
elif type == "plus":
8384
model = MODEL_LOOKUP["stardist_plus"](
84-
type_classes=ntypes, sem_classes=ntissues, **kwargs
85+
n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
8586
)
8687
elif name == "cppnet":
8788
if type == "base":
8889
model = MODEL_LOOKUP["cppnet_base_multiclass"](
89-
type_classes=ntypes, **kwargs
90+
n_type_classes=n_type_classes, **kwargs
9091
)
9192
elif type == "plus":
9293
model = MODEL_LOOKUP["cppnet_plus"](
93-
type_classes=ntypes, sem_classes=ntissues, **kwargs
94+
n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
9495
)
9596
elif name == "cellpose":
9697
if type == "base":
97-
model = MODEL_LOOKUP["cellpose_base"](type_classes=ntypes, **kwargs)
98+
model = MODEL_LOOKUP["cellpose_base"](
99+
n_type_classes=n_type_classes, **kwargs
100+
)
98101
elif type == "plus":
99102
model = MODEL_LOOKUP["cellpose_plus"](
100-
type_classes=ntypes, sem_classes=ntissues, **kwargs
103+
n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
101104
)
102105
elif name == "omnipose":
103106
if type == "base":
104-
model = MODEL_LOOKUP["omnipose_base"](type_classes=ntypes, **kwargs)
107+
model = MODEL_LOOKUP["omnipose_base"](
108+
n_type_classes=n_type_classes, **kwargs
109+
)
105110
elif type == "plus":
106111
model = MODEL_LOOKUP["omnipose_plus"](
107-
type_classes=ntypes, sem_classes=ntissues, **kwargs
112+
n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
108113
)
109114
elif name == "hovernet":
110115
if type == "base":
111-
model = MODEL_LOOKUP["hovernet_base"](type_classes=ntypes, **kwargs)
116+
model = MODEL_LOOKUP["hovernet_base"](
117+
n_type_classes=n_type_classes, **kwargs
118+
)
112119
elif type == "small":
113-
model = MODEL_LOOKUP["hovernet_small"](type_classes=ntypes, **kwargs)
120+
model = MODEL_LOOKUP["hovernet_small"](
121+
n_type_classes=n_type_classes, **kwargs
122+
)
114123
elif type == "plus":
115124
model = MODEL_LOOKUP["hovernet_plus"](
116-
type_classes=ntypes, sem_classes=ntissues, **kwargs
125+
n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
117126
)
118127
elif type == "small_plus":
119128
model = MODEL_LOOKUP["hovernet_small_plus"](
120-
type_classes=ntypes, sem_classes=ntissues, **kwargs
129+
n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
121130
)
122131
elif name == "cellvit":
123132
if type == "base":
124-
model = MODEL_LOOKUP["cellvit_sam_base"](type_classes=ntypes, **kwargs)
133+
model = MODEL_LOOKUP["cellvit_sam_base"](
134+
n_type_classes=n_type_classes, **kwargs
135+
)
125136
elif type == "small":
126-
model = MODEL_LOOKUP["cellvit_sam_small"](type_classes=ntypes, **kwargs)
137+
model = MODEL_LOOKUP["cellvit_sam_small"](
138+
n_type_classes=n_type_classes, **kwargs
139+
)
127140
elif type == "plus":
128141
model = MODEL_LOOKUP["cellvit_sam_plus"](
129-
type_classes=ntypes, sem_classes=ntissues, **kwargs
142+
n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
130143
)
131144
elif type == "small_plus":
132145
model = MODEL_LOOKUP["cellvit_sam_small_plus"](
133-
type_classes=ntypes, sem_classes=ntissues, **kwargs
146+
n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
134147
)
135148
else:
136149
raise ValueError("Unknown model type or name.")
@@ -139,7 +152,6 @@ def get_model(
139152

140153

141154
__all__ = [
142-
"MultiTaskUnet",
143155
"HoverNet",
144156
"hovernet_base",
145157
"hovernet_plus",

0 commit comments

Comments
 (0)