Skip to content

Commit 8cf755c

Browse files
committed
DRY
- Refactored inference on folder/layer to not have horrible duplicates
1 parent 16c72da commit 8cf755c

File tree

1 file changed

+133
-121
lines changed

1 file changed

+133
-121
lines changed

napari_cellseg3d/model_workers.py

Lines changed: 133 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -311,9 +311,8 @@ def load_folder(self):
311311

312312
# TODO : better solution than loading first image always ?
313313
data_check = LoadImaged(keys=["image"])(images_dict[0])
314-
# print(data)
314+
315315
check = data_check["image"].shape
316-
# print(check)
317316
# TODO remove
318317
# z_aniso = 5 / 1.5
319318
# if zoom is not None :
@@ -384,7 +383,14 @@ def load_layer(self):
384383
self.log("Done")
385384
return input_image
386385

387-
def model_output(self, inputs, model, post_process_transforms):
386+
def model_output(
387+
self,
388+
inputs,
389+
model,
390+
post_process_transforms,
391+
post_process=True,
392+
aniso_transform=None,
393+
):
388394

389395
inputs = inputs.to("cpu")
390396

@@ -412,7 +418,103 @@ def model_output(self, inputs, model, post_process_transforms):
412418
)
413419

414420
out = outputs.detach().cpu()
415-
return out
421+
422+
if aniso_transform is not None:
423+
out = aniso_transform(out)
424+
425+
if post_process:
426+
out = np.array(out).astype(np.float32)
427+
out = np.squeeze(out)
428+
return out
429+
else:
430+
return out
431+
432+
def create_result_dict(
433+
self,
434+
semantic_labels,
435+
instance_labels,
436+
from_layer: bool,
437+
original=None,
438+
data_dict=None,
439+
i=0,
440+
):
441+
442+
if not from_layer and original is None:
443+
raise ValueError(
444+
"If the image is not from a layer, an original should always be available"
445+
)
446+
447+
if from_layer:
448+
if i != 0:
449+
raise ValueError(
450+
"A layer's ID should always be 0 (default value)"
451+
)
452+
semantic_labels = np.swapaxes(semantic_labels, 0, 2)
453+
454+
return {
455+
"image_id": i + 1,
456+
"original": original,
457+
"instance_labels": instance_labels,
458+
"object stats": data_dict,
459+
"result": semantic_labels,
460+
"model_name": self.model_dict["name"],
461+
}
462+
463+
def get_original_filename(self, i):
464+
return os.path.basename(self.images_filepaths[i]).split(".")[0]
465+
466+
def get_instance_result(self, semantic_labels, from_layer=False, i=-1):
467+
468+
if not from_layer and i == -1:
469+
raise ValueError(
470+
"An ID should be provided when running from a file"
471+
)
472+
473+
if self.instance_params["do_instance"]:
474+
instance_labels = self.instance_seg(
475+
semantic_labels,
476+
i + 1,
477+
)
478+
if from_layer:
479+
instance_labels = np.swapaxes(instance_labels, 0, 2)
480+
data_dict = self.stats_csv(instance_labels)
481+
else:
482+
instance_labels = None
483+
data_dict = None
484+
return instance_labels, data_dict
485+
486+
def save_image(
487+
self,
488+
image,
489+
from_layer=False,
490+
i=0,
491+
):
492+
493+
if not from_layer:
494+
original_filename = "_" + self.get_original_filename(i) + "_"
495+
else:
496+
original_filename = "_"
497+
498+
time = utils.get_date_time()
499+
500+
file_path = (
501+
self.results_path
502+
+ "/"
503+
+ f"Prediction_{i+1}"
504+
+ original_filename
505+
+ self.model_dict["name"]
506+
+ f"_{time}_"
507+
+ self.filetype
508+
)
509+
510+
imwrite(file_path, image)
511+
512+
if from_layer:
513+
self.log(f"\nLayer prediction saved as :")
514+
else:
515+
self.log(f"\nFile n°{i+1} saved as :")
516+
filename = os.path.split(file_path)[1]
517+
self.log(filename)
416518

417519
def aniso_transform(self, image):
418520
zoom = self.transforms["zoom"][1]
@@ -468,71 +570,35 @@ def method(image):
468570
return instance_labels
469571
# print(self.stats_to_csv)
470572

471-
def inference_on_list(self, inf_data, i, model, post_process_transforms):
573+
def inference_on_folder(self, inf_data, i, model, post_process_transforms):
472574

473575
self.log("-" * 10)
474576
self.log(f"Inference started on image n°{i + 1}...")
475577

476578
inputs = inf_data["image"]
477-
# print(inputs.shape)
478-
479-
out = self.model_output(inputs, model, post_process_transforms)
480-
out = self.aniso_transform(out)
481579

482-
out = np.array(out).astype(np.float32)
483-
out = np.squeeze(out)
484-
to_instance = out # avoid post processing since thresholding is done there anyway
485-
image_id = i + 1
486-
time = utils.get_date_time()
487-
# print(time)
488-
489-
original_filename = os.path.basename(self.images_filepaths[i]).split(
490-
"."
491-
)[0]
492-
493-
# File output save name : original-name_model_date+time_number.filetype
494-
file_path = (
495-
self.results_path
496-
+ "/"
497-
+ f"Prediction_{image_id}_"
498-
+ original_filename
499-
+ "_"
500-
+ self.model_dict["name"]
501-
+ f"_{time}_"
502-
+ self.filetype
580+
out = self.model_output(
581+
inputs,
582+
model,
583+
post_process_transforms,
584+
aniso_transform=self.aniso_transform,
503585
)
504586

505-
# print(filename)
506-
imwrite(file_path, out)
507-
508-
self.log(f"\nFile n°{image_id} saved as :")
509-
filename = os.path.split(file_path)[1]
510-
self.log(filename)
511-
512-
#################
513-
#################
514-
#################
515-
if self.instance_params["do_instance"]:
516-
instance_labels = self.instance_seg(
517-
to_instance, image_id, original_filename
518-
)
519-
data_dict = self.stats_csv(instance_labels)
520-
else:
521-
instance_labels = None
522-
data_dict = None
587+
self.save_image(out, i=i)
588+
instance_labels, data_dict = self.get_instance_result(out, i=i)
523589

524590
original = np.array(inf_data["image"]).astype(np.float32)
525591

526592
self.log(f"Inference completed on layer")
527-
result = {
528-
"image_id": i + 1,
529-
"original": original,
530-
"instance_labels": instance_labels,
531-
"object stats": data_dict,
532-
"result": out,
533-
"model_name": self.model_dict["name"],
534-
}
535-
return result
593+
594+
return self.create_result_dict(
595+
out,
596+
instance_labels,
597+
from_layer=False,
598+
original=original,
599+
data_dict=data_dict,
600+
i=i,
601+
)
536602

537603
def stats_csv(self, instance_labels):
538604
if self.stats_to_csv:
@@ -554,74 +620,20 @@ def inference_on_layer(self, image, model, post_process_transforms):
554620
self.log("-" * 10)
555621
self.log(f"Inference started on layer...")
556622

557-
# print(inputs.shape)
558623
image = image.type(torch.FloatTensor)
559-
out = self.model_output(image, model, post_process_transforms)
560-
out = self.aniso_transform(out)
561-
# if self.transforms["zoom"][0]:
562-
# zoom = self.transforms["zoom"][1]
563-
# anisotropic_transform = Zoom(
564-
# zoom=zoom,
565-
# keep_size=False,
566-
# padding_mode="empty",
567-
# )
568-
# out = anisotropic_transform(out[0])
569-
##################
570-
##################
571-
##################
572-
out = post_process_transforms(out)
573-
out = np.array(out).astype(np.float32)
574-
out = np.squeeze(out)
575-
to_instance = out # avoid post processing since thresholding is done there anyway
576-
577-
# batch_len = out.shape[1]
578-
# print("trying to check len")
579-
# print(batch_len)
580-
# if batch_len != 1 :
581-
# sum = np.sum(out, axis=1)
582-
# print(sum.shape)
583-
# out = sum
584-
# print(out.shape)
585624

586-
time = utils.get_date_time()
587-
# print(time)
588-
# File output save name : original-name_model_date+time_number.filetype
589-
file_path = os.path.join(
590-
self.results_path,
591-
f"Prediction_layer"
592-
+ "_"
593-
+ self.model_dict["name"]
594-
+ f"_{time}_"
595-
+ self.filetype,
625+
out = self.model_output(
626+
image,
627+
model,
628+
post_process_transforms,
629+
aniso_transform=self.aniso_transform,
596630
)
597631

598-
# print(filename)
599-
imwrite(file_path, out)
600-
601-
self.log(f"\nLayer prediction saved as :")
602-
filename = os.path.split(file_path)[1]
603-
self.log(filename)
632+
self.save_image(out, from_layer=True)
604633

605-
#################
606-
#################
607-
#################
608-
if self.instance_params["do_instance"]:
609-
instance_labels = self.instance_seg(to_instance)
610-
instance_labels = np.swapaxes(instance_labels, 0, 2)
611-
data_dict = self.stats_csv(instance_labels)
612-
else:
613-
instance_labels = None
614-
data_dict = None
634+
instance_labels, data_dict = self.get_instance_result(out,from_layer=True)
615635

616-
result = {
617-
"image_id": 0,
618-
"original": None,
619-
"instance_labels": instance_labels,
620-
"object stats": data_dict,
621-
"result": np.swapaxes(out, 0, 2),
622-
"model_name": self.model_dict["name"],
623-
}
624-
return result
636+
return self.create_result_dict(out, instance_labels, from_layer=True, data_dict=data_dict)
625637

626638
def inference(self):
627639
"""
@@ -789,7 +801,7 @@ def inference(self):
789801
################################
790802
if is_folder:
791803
for i, inf_data in enumerate(inference_loader):
792-
yield self.inference_on_list(
804+
yield self.inference_on_folder(
793805
inf_data, i, model, post_process_transforms
794806
)
795807
elif is_layer:
@@ -1328,7 +1340,7 @@ def train(self):
13281340
f"at epoch: {best_metric_epoch}"
13291341
)
13301342
model.to("cpu")
1331-
1343+
13321344
except Exception as e:
13331345
self.log(f"Error : {e}")
13341346
self.quit()

0 commit comments

Comments
 (0)