11import torch .nn as nn
22
3- from .base ._multitask_unet import MultiTaskUnet
43from .cellpose .cellpose import (
54 CellPoseUnet ,
65 cellpose_base ,
5352
5453
5554def get_model (
56- name : str , type : str , ntypes : int = None , ntissues : int = None , ** kwargs
55+ name : str ,
56+ type : str ,
57+ n_type_classes : int = None ,
58+ n_sem_classes : int = None ,
59+ ** kwargs ,
5760) -> nn .Module :
5861 """Get the corect model at hand given name and type.
5962
60- Parameters
61- ----------
62- name : str
63+ Parameters:
64+ name (str):
6365 Name of the model.
64- type : str
66+ type ( str):
6567 Type of the model. One of "base", "plus", "small", "small_plus".
66- ntypes : int
68+ n_type_classes ( int):
6769 Number of cell types to segment.
68- ntissues : int
70+ n_sem_classes ( int):
6971 Number of tissue types to segment.
70- **kwargs : dict
72+ **kwargs
7173 Additional keyword arguments.
7274
73- Returns
74- -------
75+ Returns:
7576 nn.Module: The specified model.
7677 """
7778 if name == "stardist" :
7879 if type == "base" :
7980 model = MODEL_LOOKUP ["stardist_base_multiclass" ](
80- type_classes = ntypes , ** kwargs
81+ n_type_classes = n_type_classes , ** kwargs
8182 )
8283 elif type == "plus" :
8384 model = MODEL_LOOKUP ["stardist_plus" ](
84- type_classes = ntypes , sem_classes = ntissues , ** kwargs
85+ n_type_classes = n_type_classes , n_sem_classes = n_sem_classes , ** kwargs
8586 )
8687 elif name == "cppnet" :
8788 if type == "base" :
8889 model = MODEL_LOOKUP ["cppnet_base_multiclass" ](
89- type_classes = ntypes , ** kwargs
90+ n_type_classes = n_type_classes , ** kwargs
9091 )
9192 elif type == "plus" :
9293 model = MODEL_LOOKUP ["cppnet_plus" ](
93- type_classes = ntypes , sem_classes = ntissues , ** kwargs
94+ n_type_classes = n_type_classes , n_sem_classes = n_sem_classes , ** kwargs
9495 )
9596 elif name == "cellpose" :
9697 if type == "base" :
97- model = MODEL_LOOKUP ["cellpose_base" ](type_classes = ntypes , ** kwargs )
98+ model = MODEL_LOOKUP ["cellpose_base" ](
99+ n_type_classes = n_type_classes , ** kwargs
100+ )
98101 elif type == "plus" :
99102 model = MODEL_LOOKUP ["cellpose_plus" ](
100- type_classes = ntypes , sem_classes = ntissues , ** kwargs
103+ n_type_classes = n_type_classes , n_sem_classes = n_sem_classes , ** kwargs
101104 )
102105 elif name == "omnipose" :
103106 if type == "base" :
104- model = MODEL_LOOKUP ["omnipose_base" ](type_classes = ntypes , ** kwargs )
107+ model = MODEL_LOOKUP ["omnipose_base" ](
108+ n_type_classes = n_type_classes , ** kwargs
109+ )
105110 elif type == "plus" :
106111 model = MODEL_LOOKUP ["omnipose_plus" ](
107- type_classes = ntypes , sem_classes = ntissues , ** kwargs
112+ n_type_classes = n_type_classes , n_sem_classes = n_sem_classes , ** kwargs
108113 )
109114 elif name == "hovernet" :
110115 if type == "base" :
111- model = MODEL_LOOKUP ["hovernet_base" ](type_classes = ntypes , ** kwargs )
116+ model = MODEL_LOOKUP ["hovernet_base" ](
117+ n_type_classes = n_type_classes , ** kwargs
118+ )
112119 elif type == "small" :
113- model = MODEL_LOOKUP ["hovernet_small" ](type_classes = ntypes , ** kwargs )
120+ model = MODEL_LOOKUP ["hovernet_small" ](
121+ n_type_classes = n_type_classes , ** kwargs
122+ )
114123 elif type == "plus" :
115124 model = MODEL_LOOKUP ["hovernet_plus" ](
116- type_classes = ntypes , sem_classes = ntissues , ** kwargs
125+ n_type_classes = n_type_classes , n_sem_classes = n_sem_classes , ** kwargs
117126 )
118127 elif type == "small_plus" :
119128 model = MODEL_LOOKUP ["hovernet_small_plus" ](
120- type_classes = ntypes , sem_classes = ntissues , ** kwargs
129+ n_type_classes = n_type_classes , n_sem_classes = n_sem_classes , ** kwargs
121130 )
122131 elif name == "cellvit" :
123132 if type == "base" :
124- model = MODEL_LOOKUP ["cellvit_sam_base" ](type_classes = ntypes , ** kwargs )
133+ model = MODEL_LOOKUP ["cellvit_sam_base" ](
134+ n_type_classes = n_type_classes , ** kwargs
135+ )
125136 elif type == "small" :
126- model = MODEL_LOOKUP ["cellvit_sam_small" ](type_classes = ntypes , ** kwargs )
137+ model = MODEL_LOOKUP ["cellvit_sam_small" ](
138+ n_type_classes = n_type_classes , ** kwargs
139+ )
127140 elif type == "plus" :
128141 model = MODEL_LOOKUP ["cellvit_sam_plus" ](
129- type_classes = ntypes , sem_classes = ntissues , ** kwargs
142+ n_type_classes = n_type_classes , n_sem_classes = n_sem_classes , ** kwargs
130143 )
131144 elif type == "small_plus" :
132145 model = MODEL_LOOKUP ["cellvit_sam_small_plus" ](
133- type_classes = ntypes , sem_classes = ntissues , ** kwargs
146+ n_type_classes = n_type_classes , n_sem_classes = n_sem_classes , ** kwargs
134147 )
135148 else :
136149 raise ValueError ("Unknown model type or name." )
@@ -139,7 +152,6 @@ def get_model(
139152
140153
141154__all__ = [
142- "MultiTaskUnet" ,
143155 "HoverNet" ,
144156 "hovernet_base" ,
145157 "hovernet_plus" ,
0 commit comments