diff --git a/tests/translation/test_predict_writer.py b/tests/translation/test_predict_writer.py index c757774ae..0a71dbcea 100644 --- a/tests/translation/test_predict_writer.py +++ b/tests/translation/test_predict_writer.py @@ -1,4 +1,9 @@ -from viscy.translation.predict_writer import _pad_shape +from iohub import open_ome_zarr + +from viscy.data.hcs import HCSDataModule +from viscy.trainer import VisCyTrainer +from viscy.translation.engine import VSUNet +from viscy.translation.predict_writer import HCSPredictionWriter, _pad_shape def test_pad_shape(): @@ -6,3 +11,52 @@ def test_pad_shape(): assert _pad_shape((4, 5), 4) == (1, 1, 4, 5) full_shape = tuple(range(1, 6)) assert _pad_shape(full_shape, 5) == full_shape + + +def test_predict_writer(preprocessed_hcs_dataset, tmp_path): + z_window_size = 5 + data_path = preprocessed_hcs_dataset + channel_split = 2 + with open_ome_zarr(data_path) as dataset: + channel_names = dataset.channel_names + dm = HCSDataModule( + data_path=data_path, + source_channel=channel_names[:channel_split], + target_channel=channel_names[channel_split:], + z_window_size=z_window_size, + target_2d=bool(z_window_size == 1), + batch_size=2, + num_workers=0, + ) + + model = VSUNet( + architecture="fcmae", + model_config=dict( + in_channels=channel_split, + out_channels=len(channel_names) - channel_split, + encoder_blocks=[2, 2, 2, 2], + dims=[4, 8, 16, 32], + decoder_conv_blocks=2, + stem_kernel_size=[z_window_size, 4, 4], + in_stack_depth=z_window_size, + pretraining=False, + ), + ) + + output_path = tmp_path / "predictions.zarr" + prediction_writer = HCSPredictionWriter( + output_store=str(output_path), write_input=False + ) + + trainer = VisCyTrainer( + logger=False, + callbacks=[prediction_writer], + fast_dev_run=False, + default_root_dir=tmp_path, + ) + + trainer.predict(model, datamodule=dm) + assert output_path.exists() + with open_ome_zarr(output_path) as result: + for _, pos in result.positions(): + assert pos["0"][:].any() diff --git a/viscy/translation/predict_writer.py b/viscy/translation/predict_writer.py index 75d6d4152..eb593e4f4 100644 --- a/viscy/translation/predict_writer.py +++ b/viscy/translation/predict_writer.py @@ -33,10 +33,12 @@ def _resize_image(image: ImageArray, t_index: int, z_slice: slice) -> None: f"T={t_index}, Z-sclice={z_slice}." ) image.resize( - max(t_index + 1, image.shape[0]), - image.channels, - max(z_slice.stop, image.shape[2]), - *image.shape[-2:], + ( + max(t_index + 1, image.shape[0]), + image.channels, + max(z_slice.stop, image.shape[2]), + *image.shape[-2:], + ) ) @@ -142,7 +144,7 @@ def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None self.plate = open_ome_zarr( self.output_store, layout="hcs", mode="a", channel_names=channel_names ) - _logger.info(f"Writing prediction to: '{self.plate.zgroup.store.path}'.") + _logger.info(f"Writing prediction to: '{self.plate.zgroup.store.root}'.") if self.write_input: self.source_index = self._get_channel_indices(source_channel) self.target_index = self._get_channel_indices(target_channel)