Skip to content

Commit 5609040

Browse files
committed
Initial commit
Functional prototype for single layer inference Features : - functional basic inference on selected layer - improvements to inference code layout MISSING : - Proper button behaviour - Error if no layer selected - Full test of all functionalities
1 parent 226a47b commit 5609040

File tree

2 files changed

+40
-43
lines changed

2 files changed

+40
-43
lines changed

napari_cellseg3d/model_workers.py

Lines changed: 32 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from monai.transforms import AddChannel
2323
from monai.transforms import AsDiscrete
2424
from monai.transforms import Compose
25-
from monai.transforms import EnsureChannelFirst
2625
from monai.transforms import EnsureChannelFirstd
2726
from monai.transforms import EnsureType
2827
from monai.transforms import EnsureTyped
@@ -346,33 +345,32 @@ def load_folder(self):
346345
def load_layer(self):
347346

348347
volume = np.array(self.layer.data, dtype=np.int16)
348+
volume = np.swapaxes(
349+
volume, 0, 2
350+
) # for anisotropy to be monai-like, i.e. zyx
349351
print("Loading layer")
350-
print(volume.shape)
351-
352352
dims_check = volume.shape
353353
self.log("\nChecking dimensions...")
354354
pad = utils.get_padding_dim(dims_check)
355-
print(f"padding: {pad}")
355+
print(volume.shape)
356+
print(volume.dtype)
356357
load_transforms = Compose(
357358
[
358-
AddChannel(),
359-
AddChannel(),
360359
ToTensor(),
361-
EnsureType(),
362-
# EnsureChannelFirst(),
363-
# Orientationd(keys=["image"], axcodes="PLI"),
364360
# anisotropic_transform,
361+
AddChannel(),
365362
SpatialPad(spatial_size=pad),
366-
]
363+
AddChannel(),
364+
EnsureType(),
365+
],
366+
map_items=False,
367+
log_stats=True,
367368
)
368369

369370
self.log("\nLoading dataset...")
370-
inference_ds = Dataset(data=volume, transform=load_transforms)
371-
inference_loader = DataLoader(
372-
inference_ds, batch_size=1, num_workers=2
373-
)
371+
input_image = load_transforms(volume)
374372
self.log("Done")
375-
return inference_loader
373+
return input_image
376374

377375
def model_output(self, inputs, model, post_process_transforms):
378376

@@ -457,12 +455,7 @@ def method(image):
457455
self.log(os.path.split(instance_filepath)[1])
458456

459457
# print(self.stats_to_csv)
460-
if self.stats_to_csv:
461-
data_dict = volume_stats(
462-
instance_labels
463-
) # TODO test with area mesh function
464-
return data_dict
465-
return None
458+
return instance_labels
466459

467460
def inference_on_list(self, inf_data, i, model, post_process_transforms):
468461

@@ -638,16 +631,13 @@ def inference_on_list(self, inf_data, i, model, post_process_transforms):
638631
# print(result)
639632
return result
640633

641-
def inference_on_array(self, image, model, post_process_transforms):
634+
def inference_on_layer(self, image, model, post_process_transforms):
642635

643636
self.log("-" * 10)
644637
self.log(f"Inference started on layer...")
645638

646639
# print(inputs.shape)
647-
648-
image = ToTensor()(image)
649-
image = AddChannel()(image)
650-
image = AddChannel()(image)
640+
image = image.type(torch.FloatTensor)
651641
inputs = image.to("cpu")
652642

653643
model_output = lambda inputs: post_process_transforms(
@@ -732,7 +722,13 @@ def inference_on_array(self, image, model, post_process_transforms):
732722
#################
733723
#################
734724
if self.instance_params["do_instance"]:
735-
data_dict = self.instance_seg(to_instance)
725+
instance_labels = self.instance_seg(to_instance)
726+
if self.stats_to_csv:
727+
data_dict = volume_stats(
728+
instance_labels
729+
) # TODO test with area mesh function
730+
else:
731+
data_dict = None
736732
# self.log(
737733
# f"\nRunning instance segmentation for image n°{image_id}"
738734
# )
@@ -791,16 +787,14 @@ def inference_on_array(self, image, model, post_process_transforms):
791787
instance_labels = None
792788
data_dict = None
793789

794-
# logging(f"Inference completed on image {i+1}")
795790
result = {
796791
"image_id": 0,
797792
"original": None,
798-
"instance_labels": instance_labels,
793+
"instance_labels": np.swapaxes(instance_labels, 0,2),
799794
"object stats": data_dict,
800-
"result": out,
795+
"result": np.swapaxes(out, 0,2),
801796
"model_name": self.model_dict["name"],
802797
}
803-
# print(result)
804798
return result
805799

806800
def inference(self):
@@ -945,8 +939,6 @@ def inference(self):
945939
)
946940
elif is_folder:
947941
inference_loader = self.load_folder()
948-
elif is_layer:
949-
inference_loader = self.load_layer()
950942
##################
951943
##################
952944
# DEBUG
@@ -957,6 +949,10 @@ def inference(self):
957949
print(image.shape)
958950
##################
959951
##################
952+
elif is_layer:
953+
input_image = self.load_layer()
954+
print(input_image.shape)
955+
960956
else:
961957
raise ValueError("No data has been provided. Aborting.")
962958

@@ -968,13 +964,11 @@ def inference(self):
968964
if is_folder:
969965
for i, inf_data in enumerate(inference_loader):
970966
yield self.inference_on_list(
971-
inf_data,i, model, post_process_transforms
967+
inf_data, i, model, post_process_transforms
972968
)
973969
elif is_layer:
974-
image = self.layer.data
975-
print(image.shape)
976-
yield self.inference_on_array(
977-
image, model, post_process_transforms
970+
yield self.inference_on_layer(
971+
input_image, model, post_process_transforms
978972
)
979973
# for i, inf_data in enumerate(inference_loader):
980974
#

napari_cellseg3d/plugin_model_inference.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
from napari_cellseg3d.model_workers import InferenceWorker
1616

1717

18+
# TODO for layer inference : button behaviour/visibility, error if no layer selected, test all funcs
19+
20+
1821
class Inferer(ModelFramework):
1922
"""A plugin to run already trained models in evaluation mode to preform inference and output a label on all
2023
given volumes."""
@@ -605,7 +608,7 @@ def start(self, on_layer=False):
605608
keep_on_cpu=self.keep_on_cpu,
606609
stats_csv=self.stats_to_csv,
607610
)
608-
elif on_layer:
611+
else:
609612
layer = self._viewer.layers.selection.active
610613
self.worker = InferenceWorker(
611614
device=device,
@@ -666,7 +669,7 @@ def on_error(self):
666669
"""Catches errors and tries to clean up. TODO : upgrade"""
667670
self.log.print_and_log("Worker errored...")
668671
self.log.print_and_log("Trying to clean up...")
669-
self.btn_start.setText("Start")
672+
self.btn_start.setText("Start on folder")
670673
self.btn_close.setVisible(True)
671674

672675
self.worker = None
@@ -676,7 +679,7 @@ def on_finish(self):
676679
"""Catches finished signal from worker, resets workspace for next run."""
677680
self.log.print_and_log(f"\nWorker finished at {utils.get_time()}")
678681
self.log.print_and_log("*" * 20)
679-
self.btn_start.setText("Start")
682+
self.btn_start.setText("Start on folder")
680683
self.btn_close.setVisible(True)
681684

682685
self.worker = None
@@ -707,13 +710,13 @@ def on_yield(data, widget):
707710

708711
zoom = widget.zoom
709712

710-
print(data["original"].shape)
713+
# print(data["original"].shape)
711714
print(data["result"].shape)
712715

713716
viewer.dims.ndisplay = 3
714717
viewer.scale_bar.visible = True
715718

716-
if widget.show_original:
719+
if widget.show_original and data["original"] is not None:
717720
original_layer = viewer.add_image(
718721
data["original"],
719722
colormap="inferno",

0 commit comments

Comments
 (0)