Skip to content

Commit a6d3cce

Browse files
committed
Fixes + reqs update
- Fixed model instantiation - Fix window overlap argument type error - Updated reqs.txt to have einops (MONAI optional dep.)
1 parent 9d0190c commit a6d3cce

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

napari_cellseg3d/model_workers.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def __init__(
227227
self.instance_params = instance
228228
self.use_window = use_window
229229
self.window_infer_size = window_infer_size
230-
self.window_overlap_percentage = (0.8,)
230+
self.window_overlap_percentage = 0.8
231231
self.keep_on_cpu = keep_on_cpu
232232
self.stats_to_csv = stats_csv
233233
"""These attributes are all arguments of :py:func:~inference, please see that for reference"""
@@ -346,7 +346,7 @@ def inference(self):
346346

347347
dims = self.model_dict["model_input_size"]
348348

349-
model = self.model_dict["class"].get_net()
349+
350350
if self.model_dict["name"] == "SegResNet":
351351
model = self.model_dict["class"].get_net(
352352
input_image_size=[
@@ -360,6 +360,8 @@ def inference(self):
360360
img_size=[dims, dims, dims],
361361
use_checkpoint=False,
362362
)
363+
else:
364+
model = self.model_dict["class"].get_net()
363365

364366
self.log_parameters()
365367

@@ -445,6 +447,7 @@ def inference(self):
445447
inputs = inputs.to("cpu")
446448
print(inputs.shape)
447449

450+
# self.log("output")
448451
model_output = lambda inputs: post_process_transforms(
449452
self.model_dict["class"].get_output(model, inputs)
450453
)
@@ -460,6 +463,8 @@ def inference(self):
460463
else:
461464
window_size = None
462465
window_overlap = 0.25
466+
467+
# self.log("window")
463468
outputs = sliding_window_inference(
464469
inputs,
465470
roi_size=window_size,

napari_cellseg3d/models/model_SwinUNetR.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44

55
def get_weights_file():
6-
return ""
6+
return "Swin64_best_metric.pth"
77

88

99
def get_net(img_size, use_checkpoint=True):

requirements.txt

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,10 @@ napari[all]>=0.4.14
1313
QtPy
1414
opencv-python>=4.5.5
1515
dask-image>=0.6.0
16-
scikit-image>=0.19.2
1716
matplotlib>=3.4.1
1817
tifffile>=2022.2.9
1918
imageio-ffmpeg>=0.4.5
2019
torch>=1.11
21-
monai>=0.9.0
22-
nibabel
20+
monai[nibabel,scikit-image,itk,einops]>=0.9.0
2321
pillow
24-
itk>=5.2.0
2522
vispy>=0.9.6

0 commit comments

Comments
 (0)