|
| 1 | +import torch.nn as nn |
| 2 | + |
1 | 3 | from .base._multitask_unet import MultiTaskUnet |
2 | 4 | from .cellpose.cellpose import ( |
3 | 5 | CellPoseUnet, |
|
26 | 28 | "omnipose_base": omnipose_base, |
27 | 29 | "omnipose_plus": omnipose_plus, |
28 | 30 | "hovernet_base": hovernet_base, |
| 31 | + "hovernet_plus": hovernet_plus, |
29 | 32 | "hovernet_small": hovernet_small, |
30 | 33 | "hovernet_small_plus": hovernet_small_plus, |
31 | 34 | "stardist_base": stardist_base, |
|
34 | 37 | } |
35 | 38 |
|
36 | 39 |
|
37 | | -def get_model(name: str, type: str, ntypes: int = None, ntissues: int = None): |
38 | | - """Get the corect model at hand given name and type.""" |
| 40 | +def get_model( |
| 41 | + name: str, type: str, ntypes: int = None, ntissues: int = None, **kwargs |
| 42 | +) -> nn.Module: |
| 43 | + """Get the corect model at hand given name and type. |
| 44 | +
|
| 45 | + Parameters |
| 46 | + ---------- |
| 47 | + name : str |
| 48 | + Name of the model. |
| 49 | + type : str |
| 50 | + Type of the model. One of "base", "plus", "small", "small_plus". |
| 51 | + ntypes : int |
| 52 | + Number of cell types to segment. |
| 53 | + ntissues : int |
| 54 | + Number of tissue types to segment. |
| 55 | + **kwargs : dict |
| 56 | + Additional keyword arguments. |
| 57 | +
|
| 58 | + Returns |
| 59 | + ------- |
| 60 | + nn.Module: The specified model. |
| 61 | + """ |
39 | 62 | if name == "stardist": |
40 | 63 | if type == "base": |
41 | 64 | model = MODEL_LOOKUP["stardist_base_multiclass"]( |
42 | | - n_rays=32, type_classes=ntypes |
| 65 | + n_rays=32, type_classes=ntypes, **kwargs |
43 | 66 | ) |
44 | 67 | elif type == "plus": |
45 | 68 | model = MODEL_LOOKUP["stardist_plus"]( |
46 | | - n_rays=32, type_classes=ntypes, sem_classes=ntissues |
| 69 | + n_rays=32, type_classes=ntypes, sem_classes=ntissues, **kwargs |
47 | 70 | ) |
48 | 71 | elif name == "cellpose": |
49 | 72 | if type == "base": |
50 | | - model = MODEL_LOOKUP["cellpose_base"](type_classes=ntypes) |
| 73 | + model = MODEL_LOOKUP["cellpose_base"](type_classes=ntypes, **kwargs) |
51 | 74 | elif type == "plus": |
52 | 75 | model = MODEL_LOOKUP["cellpose_plus"]( |
53 | | - type_classes=ntypes, sem_classes=ntissues |
| 76 | + type_classes=ntypes, sem_classes=ntissues, **kwargs |
54 | 77 | ) |
55 | 78 | elif name == "omnipose": |
56 | 79 | if type == "base": |
57 | | - model = MODEL_LOOKUP["omnipose_base"](type_classes=ntypes) |
| 80 | + model = MODEL_LOOKUP["omnipose_base"](type_classes=ntypes, **kwargs) |
58 | 81 | elif type == "plus": |
59 | 82 | model = MODEL_LOOKUP["omnipose_plus"]( |
60 | | - type_classes=ntypes, sem_classes=ntissues |
| 83 | + type_classes=ntypes, sem_classes=ntissues, **kwargs |
61 | 84 | ) |
62 | 85 | elif name == "hovernet": |
63 | 86 | if type == "base": |
64 | | - model = MODEL_LOOKUP["hovernet_base"](type_classes=ntypes) |
| 87 | + model = MODEL_LOOKUP["hovernet_base"](type_classes=ntypes, **kwargs) |
65 | 88 | elif type == "small": |
66 | | - model = MODEL_LOOKUP["hovernet_small"](type_classes=ntypes) |
| 89 | + model = MODEL_LOOKUP["hovernet_small"](type_classes=ntypes, **kwargs) |
67 | 90 | elif type == "plus": |
68 | 91 | model = MODEL_LOOKUP["hovernet_plus"]( |
69 | | - type_classes=ntypes, sem_classes=ntissues |
| 92 | + type_classes=ntypes, sem_classes=ntissues, **kwargs |
70 | 93 | ) |
71 | 94 | elif type == "small_plus": |
72 | 95 | model = MODEL_LOOKUP["hovernet_small_plus"]( |
73 | | - type_classes=ntypes, sem_classes=ntissues |
| 96 | + type_classes=ntypes, sem_classes=ntissues, **kwargs |
74 | 97 | ) |
75 | 98 | else: |
76 | 99 | raise ValueError("Unknown model type or name.") |
|
0 commit comments