Skip to content

Commit 5c1f9c2

Browse files
committed
Fixes
- Fixed missing filename for saving layer inference - Fixed duplicate model init - Fixed error in model init - lint
1 parent 4babcff commit 5c1f9c2

File tree

2 files changed

+49
-53
lines changed

2 files changed

+49
-53
lines changed

napari_cellseg3d/model_workers.py

Lines changed: 46 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -319,27 +319,27 @@ def load_folder(self):
319319
self.log("\nChecking dimensions...")
320320
pad = utils.get_padding_dim(check)
321321

322-
dims = self.model_dict["model_input_size"]
323-
324-
if self.model_dict["name"] == "SegResNet":
325-
model = self.model_dict["class"].get_net(
326-
input_image_size=[
327-
dims,
328-
dims,
329-
dims,
330-
]
331-
)
332-
elif self.model_dict["name"] == "SwinUNetR":
333-
model = self.model_dict["class"].get_net(
334-
img_size=[dims, dims, dims],
335-
use_checkpoint=False,
336-
)
337-
else:
338-
model = self.model_dict["class"].get_net()
339-
340-
self.log_parameters()
341-
342-
model.to(self.device)
322+
# dims = self.model_dict["model_input_size"]
323+
#
324+
# if self.model_dict["name"] == "SegResNet":
325+
# model = self.model_dict["class"].get_net(
326+
# input_image_size=[
327+
# dims,
328+
# dims,
329+
# dims,
330+
# ]
331+
# )
332+
# elif self.model_dict["name"] == "SwinUNetR":
333+
# model = self.model_dict["class"].get_net(
334+
# img_size=[dims, dims, dims],
335+
# use_checkpoint=False,
336+
# )
337+
# else:
338+
# model = self.model_dict["class"].get_net()
339+
#
340+
# self.log_parameters()
341+
#
342+
# model.to(self.device)
343343

344344
# print("FILEPATHS PRINT")
345345
# print(self.images_filepaths)
@@ -399,18 +399,18 @@ def load_layer(self):
399399
# print(volume.shape)
400400
# print(volume.dtype)
401401
if self.use_window:
402-
load_transforms = Compose(
403-
[
404-
ToTensor(),
405-
# anisotropic_transform,
406-
AddChannel(),
407-
# SpatialPad(spatial_size=pad),
408-
AddChannel(),
409-
EnsureType(),
410-
],
411-
map_items=False,
412-
log_stats=True,
413-
)
402+
load_transforms = Compose(
403+
[
404+
ToTensor(),
405+
# anisotropic_transform,
406+
AddChannel(),
407+
# SpatialPad(spatial_size=pad),
408+
AddChannel(),
409+
EnsureType(),
410+
],
411+
map_items=False,
412+
log_stats=True,
413+
)
414414
else:
415415
load_transforms = Compose(
416416
[
@@ -558,13 +558,12 @@ def save_image(
558558
)
559559

560560
imwrite(file_path, image)
561+
filename = os.path.split(file_path)[1]
561562

562563
if from_layer:
563-
self.log(f"\nLayer prediction saved as :")
564+
self.log(f"\nLayer prediction saved as : {filename}")
564565
else:
565-
self.log(f"\nFile n°{i+1} saved as :")
566-
filename = os.path.split(file_path)[1]
567-
self.log(filename)
566+
self.log(f"\nFile n°{i+1} saved as : {filename}")
568567

569568
def aniso_transform(self, image):
570569
zoom = self.transforms["zoom"][1]
@@ -680,9 +679,13 @@ def inference_on_layer(self, image, model, post_process_transforms):
680679

681680
self.save_image(out, from_layer=True)
682681

683-
instance_labels, data_dict = self.get_instance_result(out,from_layer=True)
682+
instance_labels, data_dict = self.get_instance_result(
683+
out, from_layer=True
684+
)
684685

685-
return self.create_result_dict(out, instance_labels, from_layer=True, data_dict=data_dict)
686+
return self.create_result_dict(
687+
out, instance_labels, from_layer=True, data_dict=data_dict
688+
)
686689

687690
def inference(self):
688691
"""
@@ -724,20 +727,18 @@ def inference(self):
724727
torch.set_num_threads(1) # required for threading on macOS ?
725728
self.log("Number of threads has been set to 1 for macOS")
726729

727-
728730
try:
729731
dims = self.model_dict["model_input_size"]
730-
732+
self.log(f"MODEL DIMS : {dims}")
733+
self.log(self.model_dict["name"])
731734

732735
if self.model_dict["name"] == "SegResNet":
733-
model = self.model_dict["class"].get_net()(
736+
model = self.model_dict["class"].get_net(
734737
input_image_size=[
735738
dims,
736739
dims,
737740
dims,
738741
], # TODO FIX ! find a better way & remove model-specific code
739-
out_channels=1,
740-
# dropout_prob=0.3,
741742
)
742743
elif self.model_dict["name"] == "SwinUNetR":
743744
model = self.model_dict["class"].get_net(
@@ -772,10 +773,7 @@ def inference(self):
772773
AsDiscrete(threshold=t), EnsureType()
773774
)
774775

775-
776-
self.log(
777-
"\nLoading weights..."
778-
)
776+
self.log("\nLoading weights...")
779777

780778
if self.weights_dict["custom"]:
781779
weights = self.weights_dict["path"]
@@ -1091,9 +1089,6 @@ def train(self):
10911089
model = model_class.get_net() # get an instance of the model
10921090
model = model.to(self.device)
10931091

1094-
1095-
1096-
10971092
epoch_loss_values = []
10981093
val_metric_values = []
10991094

napari_cellseg3d/plugin_model_inference.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
100100
######################
101101
######################
102102
# TODO : better way to handle SegResNet size reqs ?
103-
self.model_input_size = ui.IntIncrementCounter(min=1, max=1024, default=128)
103+
self.model_input_size = ui.IntIncrementCounter(
104+
min=1, max=1024, default=128
105+
)
104106
self.model_choice.currentIndexChanged.connect(
105107
self.toggle_display_model_input_size
106108
)
@@ -771,7 +773,6 @@ def on_yield(data, widget):
771773

772774
zoom = widget.zoom
773775

774-
775776
viewer.dims.ndisplay = 3
776777
viewer.scale_bar.visible = True
777778

0 commit comments

Comments
 (0)