Skip to content

Commit 098cfae

Browse files
committed
Initial commit
Functional prototype for single layer inference Features : - functional basic inference on selected layer - improvements to inference code layout MISSING : - Proper button behaviour - Error if no layer selected - Full test of all functionalities
1 parent 3a62fee commit 098cfae

File tree

2 files changed

+40
-43
lines changed

2 files changed

+40
-43
lines changed

napari_cellseg3d/model_workers.py

Lines changed: 32 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from monai.transforms import AddChannel
1818
from monai.transforms import AsDiscrete
1919
from monai.transforms import Compose
20-
from monai.transforms import EnsureChannelFirst
2120
from monai.transforms import EnsureChannelFirstd
2221
from monai.transforms import EnsureType
2322
from monai.transforms import EnsureTyped
@@ -246,33 +245,32 @@ def load_folder(self):
246245
def load_layer(self):
247246

248247
volume = np.array(self.layer.data, dtype=np.int16)
248+
volume = np.swapaxes(
249+
volume, 0, 2
250+
) # for anisotropy to be monai-like, i.e. zyx
249251
print("Loading layer")
250-
print(volume.shape)
251-
252252
dims_check = volume.shape
253253
self.log("\nChecking dimensions...")
254254
pad = utils.get_padding_dim(dims_check)
255-
print(f"padding: {pad}")
255+
print(volume.shape)
256+
print(volume.dtype)
256257
load_transforms = Compose(
257258
[
258-
AddChannel(),
259-
AddChannel(),
260259
ToTensor(),
261-
EnsureType(),
262-
# EnsureChannelFirst(),
263-
# Orientationd(keys=["image"], axcodes="PLI"),
264260
# anisotropic_transform,
261+
AddChannel(),
265262
SpatialPad(spatial_size=pad),
266-
]
263+
AddChannel(),
264+
EnsureType(),
265+
],
266+
map_items=False,
267+
log_stats=True,
267268
)
268269

269270
self.log("\nLoading dataset...")
270-
inference_ds = Dataset(data=volume, transform=load_transforms)
271-
inference_loader = DataLoader(
272-
inference_ds, batch_size=1, num_workers=2
273-
)
271+
input_image = load_transforms(volume)
274272
self.log("Done")
275-
return inference_loader
273+
return input_image
276274

277275
def model_output(self, inputs, model, post_process_transforms):
278276

@@ -357,12 +355,7 @@ def method(image):
357355
self.log(os.path.split(instance_filepath)[1])
358356

359357
# print(self.stats_to_csv)
360-
if self.stats_to_csv:
361-
data_dict = volume_stats(
362-
instance_labels
363-
) # TODO test with area mesh function
364-
return data_dict
365-
return None
358+
return instance_labels
366359

367360
def inference_on_list(self, inf_data, i, model, post_process_transforms):
368361

@@ -538,16 +531,13 @@ def inference_on_list(self, inf_data, i, model, post_process_transforms):
538531
# print(result)
539532
return result
540533

541-
def inference_on_array(self, image, model, post_process_transforms):
534+
def inference_on_layer(self, image, model, post_process_transforms):
542535

543536
self.log("-" * 10)
544537
self.log(f"Inference started on layer...")
545538

546539
# print(inputs.shape)
547-
548-
image = ToTensor()(image)
549-
image = AddChannel()(image)
550-
image = AddChannel()(image)
540+
image = image.type(torch.FloatTensor)
551541
inputs = image.to("cpu")
552542

553543
model_output = lambda inputs: post_process_transforms(
@@ -632,7 +622,13 @@ def inference_on_array(self, image, model, post_process_transforms):
632622
#################
633623
#################
634624
if self.instance_params["do_instance"]:
635-
data_dict = self.instance_seg(to_instance)
625+
instance_labels = self.instance_seg(to_instance)
626+
if self.stats_to_csv:
627+
data_dict = volume_stats(
628+
instance_labels
629+
) # TODO test with area mesh function
630+
else:
631+
data_dict = None
636632
# self.log(
637633
# f"\nRunning instance segmentation for image n°{image_id}"
638634
# )
@@ -691,16 +687,14 @@ def inference_on_array(self, image, model, post_process_transforms):
691687
instance_labels = None
692688
data_dict = None
693689

694-
# logging(f"Inference completed on image {i+1}")
695690
result = {
696691
"image_id": 0,
697692
"original": None,
698-
"instance_labels": instance_labels,
693+
"instance_labels": np.swapaxes(instance_labels, 0,2),
699694
"object stats": data_dict,
700-
"result": out,
695+
"result": np.swapaxes(out, 0,2),
701696
"model_name": self.model_dict["name"],
702697
}
703-
# print(result)
704698
return result
705699

706700
def inference(self):
@@ -838,8 +832,6 @@ def inference(self):
838832
)
839833
elif is_folder:
840834
inference_loader = self.load_folder()
841-
elif is_layer:
842-
inference_loader = self.load_layer()
843835
##################
844836
##################
845837
# DEBUG
@@ -850,6 +842,10 @@ def inference(self):
850842
print(image.shape)
851843
##################
852844
##################
845+
elif is_layer:
846+
input_image = self.load_layer()
847+
print(input_image.shape)
848+
853849
else:
854850
raise ValueError("No data has been provided. Aborting.")
855851

@@ -861,13 +857,11 @@ def inference(self):
861857
if is_folder:
862858
for i, inf_data in enumerate(inference_loader):
863859
yield self.inference_on_list(
864-
inf_data,i, model, post_process_transforms
860+
inf_data, i, model, post_process_transforms
865861
)
866862
elif is_layer:
867-
image = self.layer.data
868-
print(image.shape)
869-
yield self.inference_on_array(
870-
image, model, post_process_transforms
863+
yield self.inference_on_layer(
864+
input_image, model, post_process_transforms
871865
)
872866
# for i, inf_data in enumerate(inference_loader):
873867
#

napari_cellseg3d/plugin_model_inference.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
from napari_cellseg3d.model_workers import InferenceWorker
1616

1717

18+
# TODO for layer inference : button behaviour/visibility, error if no layer selected, test all funcs
19+
20+
1821
class Inferer(ModelFramework):
1922
"""A plugin to run already trained models in evaluation mode to preform inference and output a label on all
2023
given volumes."""
@@ -605,7 +608,7 @@ def start(self, on_layer=False):
605608
keep_on_cpu=self.keep_on_cpu,
606609
stats_csv=self.stats_to_csv,
607610
)
608-
elif on_layer:
611+
else:
609612
layer = self._viewer.layers.selection.active
610613
self.worker = InferenceWorker(
611614
device=device,
@@ -663,7 +666,7 @@ def on_error(self):
663666
"""Catches errors and tries to clean up. TODO : upgrade"""
664667
self.log.print_and_log("Worker errored...")
665668
self.log.print_and_log("Trying to clean up...")
666-
self.btn_start.setText("Start")
669+
self.btn_start.setText("Start on folder")
667670
self.btn_close.setVisible(True)
668671

669672
self.worker = None
@@ -673,7 +676,7 @@ def on_finish(self):
673676
"""Catches finished signal from worker, resets workspace for next run."""
674677
self.log.print_and_log(f"\nWorker finished at {utils.get_time()}")
675678
self.log.print_and_log("*" * 20)
676-
self.btn_start.setText("Start")
679+
self.btn_start.setText("Start on folder")
677680
self.btn_close.setVisible(True)
678681

679682
self.worker = None
@@ -704,13 +707,13 @@ def on_yield(data, widget):
704707

705708
zoom = widget.zoom
706709

707-
print(data["original"].shape)
710+
# print(data["original"].shape)
708711
print(data["result"].shape)
709712

710713
viewer.dims.ndisplay = 3
711714
viewer.scale_bar.visible = True
712715

713-
if widget.show_original:
716+
if widget.show_original and data["original"] is not None:
714717
original_layer = viewer.add_image(
715718
data["original"],
716719
colormap="inferno",

0 commit comments

Comments
 (0)