Skip to content

Commit b28d3fd

Browse files
committed
Fix Swin args + try to get tifffile installed in CI
1 parent c377e29 commit b28d3fd

File tree

2 files changed

+19
-21
lines changed

2 files changed

+19
-21
lines changed

.github/workflows/test_and_deploy.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ jobs:
5151
run: |
5252
python -m pip install --upgrade pip
5353
python -m pip install setuptools tox tox-gh-actions
54-
# python -m pip install tifffile
54+
python -m pip install tifffile
5555
python -m pip install monai[nibabel,einops,tifffile]
5656
# pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf
5757

napari_cellseg3d/code_models/models/model_SwinUNetR.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""SwinUNetR wrapper for napari_cellseg3d."""
2+
import inspect
23

34
from monai.networks.nets import SwinUNETR
45

@@ -30,29 +31,26 @@ def __init__(
3031
use_checkpoint (bool): whether to use checkpointing during training.
3132
**kwargs: additional arguments to SwinUNETR.
3233
"""
34+
parent_init = super().__init__
35+
sig = inspect.signature(parent_init)
36+
init_kwargs = dict(
37+
in_channels=in_channels,
38+
out_channels=out_channels,
39+
use_checkpoint=use_checkpoint,
40+
drop_rate=0.5,
41+
attn_drop_rate=0.5,
42+
use_v2=True,
43+
**kwargs,
44+
)
45+
if "img_size" in sig.parameters:
46+
# since MONAI API changes depending on py3.8 or py3.9
47+
init_kwargs["img_size"] = input_img_size
3348
try:
34-
super().__init__(
35-
img_size=input_img_size,
36-
in_channels=in_channels,
37-
out_channels=out_channels,
38-
feature_size=48,
39-
use_checkpoint=use_checkpoint,
40-
drop_rate=0.5,
41-
attn_drop_rate=0.5,
42-
use_v2=True,
43-
**kwargs,
44-
)
49+
parent_init(**init_kwargs)
4550
except TypeError as e:
4651
logger.warning(f"Caught TypeError: {e}")
47-
super().__init__(
48-
in_channels=1,
49-
out_channels=1,
50-
feature_size=48,
51-
use_checkpoint=use_checkpoint,
52-
drop_rate=0.5,
53-
attn_drop_rate=0.5,
54-
use_v2=True,
55-
)
52+
init_kwargs["in_channels"] = 1
53+
parent_init(**init_kwargs)
5654

5755
# def forward(self, x_in):
5856
# y = super().forward(x_in)

0 commit comments

Comments
 (0)