Skip to content

Commit 4374d17

Browse files
committed
Fix tests issue with MONAI 1.3.0
1 parent ccd422a commit 4374d17

File tree

3 files changed

+27
-18
lines changed

3 files changed

+27
-18
lines changed

napari_cellseg3d/code_models/worker_inference.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ def load_layer(self):
244244
if self.config.model_info.name != "WNet"
245245
else lambda x: x
246246
)
247+
volume = np.reshape(volume, newshape=(1, *volume.shape))
247248
if self.config.sliding_window_config.is_enabled():
248249
load_transforms = Compose(
249250
[
@@ -278,6 +279,7 @@ def load_layer(self):
278279
)
279280

280281
input_image = load_transforms(volume)
282+
input_image = input_image.unsqueeze(0)
281283
logger.debug(f"INPUT IMAGE SHAPE : {input_image.shape}")
282284
logger.debug(f"INPUT IMAGE TYPE : {input_image.dtype}")
283285
self.log("Done")
@@ -579,7 +581,6 @@ def stats_csv(self, instance_labels):
579581
def inference_on_layer(self, image, model, post_process_transforms):
580582
self.log("-" * 10)
581583
self.log("Inference started on layer...")
582-
image = image.view((1, 1, *image.shape))
583584
logger.debug(f"Layer shape @ inference input: {image.shape}")
584585
out = self.model_output(
585586
image,

napari_cellseg3d/code_models/worker_training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1246,7 +1246,7 @@ def get_patch_loader_func(num_samples):
12461246
logger.debug("Cache dataset : train")
12471247
train_dataset = CacheDataset(
12481248
data=self.train_files,
1249-
transform=Compose(load_whole_images, train_transforms),
1249+
transform=Compose([load_whole_images, train_transforms]),
12501250
)
12511251
logger.debug("Cache dataset : val")
12521252
validation_dataset = CacheDataset(

napari_cellseg3d/code_plugins/plugin_model_inference.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -609,22 +609,30 @@ def _display_results(self, result: InferenceResult):
609609
# stats = result.stats
610610

611611
if self.worker_config.compute_stats and stats is not None:
612-
stats_dict = stats.get_dict()
613-
stats_df = pd.DataFrame(stats_dict)
614-
615-
self.log.print_and_log(
616-
f"Number of instances in channel {i} : {stats.number_objects[0]}"
617-
)
618-
619-
csv_name = f"/{method_name}_seg_results_{image_id}_channel_{i}_{utils.get_date_time()}.csv"
620-
stats_df.to_csv(
621-
self.worker_config.results_path + csv_name,
622-
index=False,
623-
)
624-
625-
# self.log.print_and_log(
626-
# f"OBJECTS DETECTED : {number_cells}\n"
627-
# )
612+
try:
613+
stats_dict = stats.get_dict()
614+
stats_df = pd.DataFrame(stats_dict)
615+
616+
self.log.print_and_log(
617+
f"Number of instances in channel {i} : {stats.number_objects[0]}"
618+
)
619+
620+
csv_name = f"/{method_name}_seg_results_{image_id}_channel_{i}_{utils.get_date_time()}.csv"
621+
stats_df.to_csv(
622+
self.worker_config.results_path + csv_name,
623+
index=False,
624+
)
625+
626+
stats_df.to_csv(
627+
self.worker_config.results_path + csv_name,
628+
index=False,
629+
)
630+
except ValueError as e:
631+
logger.warning(f"Error saving stats to csv : {e}")
632+
logger.debug(
633+
f"Length of stats array : {[len(s) for s in stats.get_dict().values()]}"
634+
)
635+
# logger.debug(f"Stats dict : {stats.get_dict()}")
628636

629637
def _setup_worker(self):
630638
if self.folder_choice.isChecked():

0 commit comments

Comments
 (0)