|
1 | 1 | """SwinUNetR wrapper for napari_cellseg3d.""" |
| 2 | +import inspect |
2 | 3 |
|
3 | 4 | from monai.networks.nets import SwinUNETR |
4 | 5 |
|
@@ -30,29 +31,26 @@ def __init__( |
30 | 31 | use_checkpoint (bool): whether to use checkpointing during training. |
31 | 32 | **kwargs: additional arguments to SwinUNETR. |
32 | 33 | """ |
| 34 | + parent_init = super().__init__ |
| 35 | + sig = inspect.signature(parent_init) |
| 36 | + init_kwargs = dict( |
| 37 | + in_channels=in_channels, |
| 38 | + out_channels=out_channels, |
| 39 | + use_checkpoint=use_checkpoint, |
| 40 | + drop_rate=0.5, |
| 41 | + attn_drop_rate=0.5, |
| 42 | + use_v2=True, |
| 43 | + **kwargs, |
| 44 | + ) |
| 45 | + if "img_size" in sig.parameters: |
| 46 | + # since MONAI API changes depending on py3.8 or py3.9 |
| 47 | + init_kwargs["img_size"] = input_img_size |
33 | 48 | try: |
34 | | - super().__init__( |
35 | | - img_size=input_img_size, |
36 | | - in_channels=in_channels, |
37 | | - out_channels=out_channels, |
38 | | - feature_size=48, |
39 | | - use_checkpoint=use_checkpoint, |
40 | | - drop_rate=0.5, |
41 | | - attn_drop_rate=0.5, |
42 | | - use_v2=True, |
43 | | - **kwargs, |
44 | | - ) |
| 49 | + parent_init(**init_kwargs) |
45 | 50 | except TypeError as e: |
46 | 51 | logger.warning(f"Caught TypeError: {e}") |
47 | | - super().__init__( |
48 | | - in_channels=1, |
49 | | - out_channels=1, |
50 | | - feature_size=48, |
51 | | - use_checkpoint=use_checkpoint, |
52 | | - drop_rate=0.5, |
53 | | - attn_drop_rate=0.5, |
54 | | - use_v2=True, |
55 | | - ) |
| 52 | + init_kwargs["in_channels"] = 1 |
| 53 | + parent_init(**init_kwargs) |
56 | 54 |
|
57 | 55 | # def forward(self, x_in): |
58 | 56 | # y = super().forward(x_in) |
|
0 commit comments