Skip to content

Commit cb5116d

Browse files
committed
:wrench move inherent model params where they are instantiated and fixed padding only for sliding window inference
1 parent 21fee9b commit cb5116d

File tree

4 files changed

+101
-137
lines changed

4 files changed

+101
-137
lines changed

napari_cellseg3d/model_workers.py

Lines changed: 87 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
# Qt
4444
from qtpy.QtCore import Signal
4545

46-
4746
from napari_cellseg3d import utils
4847
from napari_cellseg3d import log_utility
4948

@@ -168,20 +167,19 @@ class InferenceWorker(GeneratorWorker):
168167
Inherits from :py:class:`napari.qt.threading.GeneratorWorker`"""
169168

170169
def __init__(
171-
self,
172-
device,
173-
model_dict,
174-
weights_dict,
175-
images_filepaths,
176-
results_path,
177-
filetype,
178-
transforms,
179-
instance,
180-
use_window,
181-
window_infer_size,
182-
window_overlap_percentage,
183-
keep_on_cpu,
184-
stats_csv,
170+
self,
171+
device,
172+
model_dict,
173+
weights_dict,
174+
images_filepaths,
175+
results_path,
176+
filetype,
177+
transforms,
178+
instance,
179+
use_window,
180+
window_infer_size,
181+
keep_on_cpu,
182+
stats_csv,
185183
):
186184
"""Initializes a worker for inference with the arguments needed by the :py:func:`~inference` function.
187185
@@ -206,8 +204,6 @@ def __init__(
206204
207205
* window_infer_size: size of window if use_window is True
208206
209-
* window_overlap_percentage: overlap of sliding windows if use_window is True
210-
211207
* keep_on_cpu: keep images on CPU or no
212208
213209
* stats_csv: compute stats on cells and save them to a csv file
@@ -231,7 +227,7 @@ def __init__(
231227
self.instance_params = instance
232228
self.use_window = use_window
233229
self.window_infer_size = window_infer_size
234-
self.window_overlap_percentage = window_overlap_percentage
230+
self.window_overlap_percentage = 0.8,
235231
self.keep_on_cpu = keep_on_cpu
236232
self.stats_to_csv = stats_csv
237233
"""These attributes are all arguments of :py:func:~inference, please see that for reference"""
@@ -343,36 +339,25 @@ def inference(self):
343339
# if self.device =="cuda": # TODO : fix mem alloc, this does not work it seems
344340
# torch.backends.cudnn.benchmark = False
345341

346-
# TODO : better solution than loading first image always ?
342+
self.log("\nChecking dimensions...")
347343
data_check = LoadImaged(keys=["image"])(images_dict[0])
348-
# print(data)
349344
check = data_check["image"].shape
350-
# print(check)
351-
# TODO remove
352-
# z_aniso = 5 / 1.5
353-
# if zoom is not None :
354-
# pad = utils.get_padding_dim(check, anisotropy_factor=zoom)
355-
# else:
356-
self.log("\nChecking dimensions...")
357-
dims = self.model_dict["segres_size"]
345+
pad = utils.get_padding_dim(check)
346+
347+
dims = self.model_dict["model_input_size"]
358348

359349
model = self.model_dict["class"].get_net()
360350
if self.model_dict["name"] == "SegResNet":
361-
model = self.model_dict["class"].get_net()(
351+
model = self.model_dict["class"].get_net(
362352
input_image_size=[
363353
dims,
364354
dims,
365355
dims,
366-
], # TODO FIX ! find a better way & remove model-specific code
367-
out_channels=1,
368-
# dropout_prob=0.3,
356+
]
369357
)
370358
elif self.model_dict["name"] == "SwinUNetR":
371-
model = self.model_dict["class"].get_net()(
359+
model = self.model_dict["class"].get_net(
372360
img_size=[dims, dims, dims],
373-
in_channels=1,
374-
out_channels=1,
375-
feature_size=48,
376361
use_checkpoint=False,
377362
)
378363

@@ -382,17 +367,29 @@ def inference(self):
382367

383368
# print("FILEPATHS PRINT")
384369
# print(self.images_filepaths)
385-
386-
load_transforms = Compose(
387-
[
388-
LoadImaged(keys=["image"]),
389-
# AddChanneld(keys=["image"]), #already done
390-
EnsureChannelFirstd(keys=["image"]),
391-
# Orientationd(keys=["image"], axcodes="PLI"),
392-
# anisotropic_transform,
393-
EnsureTyped(keys=["image"]),
394-
]
395-
)
370+
if self.use_window:
371+
load_transforms = Compose(
372+
[
373+
LoadImaged(keys=["image"]),
374+
# AddChanneld(keys=["image"]), #already done
375+
EnsureChannelFirstd(keys=["image"]),
376+
# Orientationd(keys=["image"], axcodes="PLI"),
377+
# anisotropic_transform,
378+
EnsureTyped(keys=["image"]),
379+
]
380+
)
381+
else:
382+
load_transforms = Compose(
383+
[
384+
LoadImaged(keys=["image"]),
385+
# AddChanneld(keys=["image"]), #already done
386+
EnsureChannelFirstd(keys=["image"]),
387+
# Orientationd(keys=["image"], axcodes="PLI"),
388+
# anisotropic_transform,
389+
SpatialPadd(keys=["image"], spatial_size=pad),
390+
EnsureTyped(keys=["image"]),
391+
]
392+
)
396393

397394
if not self.transforms["thresh"][0]:
398395
post_process_transforms = EnsureType()
@@ -448,16 +445,9 @@ def inference(self):
448445
inputs = inputs.to("cpu")
449446
print(inputs.shape)
450447

451-
if self.model_dict["name"] == "SwinUNetR":
452-
model_output = lambda inputs: post_process_transforms(
453-
torch.sigmoid(
454-
self.model_dict["class"].get_output(model, inputs)
455-
)
456-
)
457-
else:
458-
model_output = lambda inputs: post_process_transforms(
459-
self.model_dict["class"].get_output(model, inputs)
460-
)
448+
model_output = lambda inputs: post_process_transforms(
449+
self.model_dict["class"].get_output(model, inputs)
450+
)
461451

462452
if self.keep_on_cpu:
463453
dataset_device = "cpu"
@@ -479,7 +469,6 @@ def inference(self):
479469
device=dataset_device,
480470
overlap=window_overlap,
481471
)
482-
print("done window infernce")
483472
out = outputs.detach().cpu()
484473
# del outputs # TODO fix memory ?
485474
# outputs = None
@@ -519,14 +508,14 @@ def inference(self):
519508

520509
# File output save name : original-name_model_date+time_number.filetype
521510
file_path = (
522-
self.results_path
523-
+ "/"
524-
+ f"Prediction_{image_id}_"
525-
+ original_filename
526-
+ "_"
527-
+ self.model_dict["name"]
528-
+ f"_{time}_"
529-
+ self.filetype
511+
self.results_path
512+
+ "/"
513+
+ f"Prediction_{image_id}_"
514+
+ original_filename
515+
+ "_"
516+
+ self.model_dict["name"]
517+
+ f"_{time}_"
518+
+ self.filetype
530519
)
531520

532521
# print(filename)
@@ -567,14 +556,14 @@ def method(image):
567556
instance_labels = method(to_instance)
568557

569558
instance_filepath = (
570-
self.results_path
571-
+ "/"
572-
+ f"Instance_seg_labels_{image_id}_"
573-
+ original_filename
574-
+ "_"
575-
+ self.model_dict["name"]
576-
+ f"_{time}_"
577-
+ self.filetype
559+
self.results_path
560+
+ "/"
561+
+ f"Instance_seg_labels_{image_id}_"
562+
+ original_filename
563+
+ "_"
564+
+ self.model_dict["name"]
565+
+ f"_{time}_"
566+
+ self.filetype
578567
)
579568

580569
imwrite(instance_filepath, instance_labels)
@@ -617,23 +606,23 @@ class TrainingWorker(GeneratorWorker):
617606
Inherits from :py:class:`napari.qt.threading.GeneratorWorker`"""
618607

619608
def __init__(
620-
self,
621-
device,
622-
model_dict,
623-
weights_path,
624-
data_dicts,
625-
validation_percent,
626-
max_epochs,
627-
loss_function,
628-
learning_rate,
629-
val_interval,
630-
batch_size,
631-
results_path,
632-
sampling,
633-
num_samples,
634-
sample_size,
635-
do_augmentation,
636-
deterministic,
609+
self,
610+
device,
611+
model_dict,
612+
weights_path,
613+
data_dicts,
614+
validation_percent,
615+
max_epochs,
616+
loss_function,
617+
learning_rate,
618+
val_interval,
619+
batch_size,
620+
results_path,
621+
sampling,
622+
num_samples,
623+
sample_size,
624+
do_augmentation,
625+
deterministic,
637626
):
638627
"""Initializes a worker for inference with the arguments needed by the :py:func:`~train` function. Note: See :py:func:`~train`
639628
@@ -841,9 +830,8 @@ def train(self):
841830
else:
842831
size = check
843832
print(f"Size of image : {size}")
844-
model = model_class.get_net()(
833+
model = model_class.get_net(
845834
input_image_size=utils.get_padding_dim(size),
846-
out_channels=1,
847835
dropout_prob=0.3,
848836
)
849837
elif model_name == "SwinUNetR":
@@ -852,11 +840,8 @@ def train(self):
852840
else:
853841
size = check
854842
print(f"Size of image : {size}")
855-
model = model_class.get_net()(
843+
model = model_class.get_net(
856844
img_size=utils.get_padding_dim(size),
857-
in_channels=1,
858-
out_channels=1,
859-
feature_size=48,
860845
use_checkpoint=True,
861846
)
862847
else:
@@ -868,10 +853,10 @@ def train(self):
868853

869854
self.train_files, self.val_files = (
870855
self.data_dicts[
871-
0 : int(len(self.data_dicts) * self.validation_percent)
856+
0: int(len(self.data_dicts) * self.validation_percent)
872857
],
873858
self.data_dicts[
874-
int(len(self.data_dicts) * self.validation_percent) :
859+
int(len(self.data_dicts) * self.validation_percent):
875860
],
876861
)
877862

@@ -1032,10 +1017,10 @@ def train(self):
10321017
if self.device.type == "cuda":
10331018
self.log("Memory Usage:")
10341019
alloc_mem = round(
1035-
torch.cuda.memory_allocated(0) / 1024**3, 1
1020+
torch.cuda.memory_allocated(0) / 1024 ** 3, 1
10361021
)
10371022
reserved_mem = round(
1038-
torch.cuda.memory_reserved(0) / 1024**3, 1
1023+
torch.cuda.memory_reserved(0) / 1024 ** 3, 1
10391024
)
10401025
self.log(f"Allocated: {alloc_mem}GB")
10411026
self.log(f"Cached: {reserved_mem}GB")
@@ -1117,7 +1102,7 @@ def train(self):
11171102
yield train_report
11181103

11191104
weights_filename = (
1120-
f"{model_name}_best_metric" + f"_epoch_{epoch + 1}.pth"
1105+
f"{model_name}_best_metric" + f"_epoch_{epoch + 1}.pth"
11211106
)
11221107

11231108
if metric > best_metric:
@@ -1158,7 +1143,6 @@ def train(self):
11581143

11591144
# self.close()
11601145

1161-
11621146
# def this_is_fine(self):
11631147
# import numpy as np
11641148
#

napari_cellseg3d/models/model_SegResNet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from monai.networks.nets import SegResNetVAE
22

33

4-
def get_net():
5-
return SegResNetVAE
4+
def get_net(input_image_size, dropout_prob=None):
5+
return SegResNetVAE(input_image_size, out_channels=1, dropout_prob=dropout_prob)
66

77

88
def get_weights_file():

napari_cellseg3d/models/model_SwinUNetR.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
1+
import torch
12
from monai.networks.nets import SwinUNETR
23

34

45
def get_weights_file():
56
return ""
67

78

8-
def get_net():
9-
return SwinUNETR
9+
def get_net(img_size, use_checkpoint=True):
10+
return SwinUNETR(img_size, in_channels=1, out_channels=1, feature_size=48, use_checkpoint=use_checkpoint)
1011

1112

1213
def get_output(model, input):
1314
out = model(input)
14-
return out
15+
return torch.sigmoid(out)
1516

1617

1718
def get_validation(model, val_inputs):

0 commit comments

Comments
 (0)