diff --git a/examples/virtual_staining/VS_model_inference/demo_api.py b/examples/virtual_staining/VS_model_inference/demo_api.py
new file mode 100644
index 000000000..9de9bdb03
--- /dev/null
+++ b/examples/virtual_staining/VS_model_inference/demo_api.py
@@ -0,0 +1,56 @@
+#%%
+from pathlib import Path
+import numpy as np
+import torch
+from iohub import open_ome_zarr
+import napari
+
+from viscy.translation.engine import VSUNet, AugmentedPredictionVSUNet
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Instantiate the model manually and load checkpoint weights
+model = VSUNet(
+    architecture="fcmae",
+    model_config={
+        "in_channels": 1,
+        "out_channels": 2,
+        "in_stack_depth": 21,
+        "encoder_blocks": [3, 3, 9, 3],
+        "dims": [96, 192, 384, 768],
+        "decoder_conv_blocks": 2,
+        "stem_kernel_size": [7, 4, 4],
+        "pretraining": False,
+        "head_conv": True,
+        "head_conv_expansion_ratio": 4,
+        "head_conv_pool": False,
+    },
+    ckpt_path="/path/to/checkpoint.ckpt",
+).to(DEVICE).eval()
+
+vs = AugmentedPredictionVSUNet(
+    model=model.model,
+    forward_transforms=[lambda t: t],
+    inverse_transforms=[lambda t: t],
+).to(DEVICE).eval()
+
+# Load one (T, C) volume from an OME-Zarr position
+path = Path("/path/to/your.zarr/0/1/000000")
+with open_ome_zarr(path) as ds:
+    vol_np = np.asarray(ds.data[0:1, 0:1])  # (1, 1, Z, Y, X)
+
+vol = torch.from_numpy(vol_np).float().to(DEVICE)
+
+# Run sliding-window inference along Z
+with torch.inference_mode():
+    pred = vs.predict_sliding_windows(vol)
+
+# Visualize the input and both predicted channels in napari
+pred_np = pred.cpu().numpy()
+nuc, mem = pred_np[0, 0], pred_np[0, 1]
+
+viewer = napari.Viewer()
+viewer.add_image(vol_np[0, 0], name="phase_input", colormap="gray")
+viewer.add_image(nuc, name="virt_nuclei", colormap="magenta")
+viewer.add_image(mem, name="virt_membrane", colormap="cyan")
+napari.run()
diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py
index 56af9b985..6eb550395 100644
--- a/viscy/translation/engine.py
+++ b/viscy/translation/engine.py
@@ -487,10 +487,12 @@ class AugmentedPredictionVSUNet(LightningModule):
     forward_transforms : list[Callable[[Tensor], Tensor]]
         A collection of transforms to apply to the input image before passing
         it to the model. Each one is applied independently.
+        If None, fall back to the identity transform (no augmentation).
         For example, resizing the input to match the expected voxel size of the model.
     inverse_transforms : list[Callable[[Tensor], Tensor]]
        Inverse transforms to apply to the model output before reduction.
        They should be the inverse of each forward transform.
+        If None, fall back to the identity transform (no augmentation).
        For example, resizing the output to match the original input shape for storage.
reduction : Literal["mean", "median"], optional The reduction method to apply to the predictions, by default "mean" @@ -515,21 +517,84 @@ class AugmentedPredictionVSUNet(LightningModule): def __init__( self, model: nn.Module, - forward_transforms: list[Callable[[Tensor], Tensor]], - inverse_transforms: list[Callable[[Tensor], Tensor]], + forward_transforms: list[Callable[[Tensor], Tensor]] | None = None, + inverse_transforms: list[Callable[[Tensor], Tensor]] | None = None, reduction: Literal["mean", "median"] = "mean", ) -> None: super().__init__() down_factor = 2**model.num_blocks self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) self.model = model - self._forward_transforms = forward_transforms - self._inverse_transforms = inverse_transforms + self._forward_transforms = forward_transforms or [lambda x: x] + self._inverse_transforms = inverse_transforms or [lambda x: x] self._reduction = reduction def forward(self, x: Tensor) -> Tensor: return self.model(x) + def predict_sliding_windows( + self, x: torch.Tensor, out_channel: int = 2, step: int = 1 + ) -> torch.Tensor: + """ + Run inference on a 5D input tensor (B, C, Z, Y, X) using sliding windows + along the Z dimension with overlap and average blending. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (B, C, Z, Y, X). + out_channel : int, optional + Number of output channels, by default 2. + step : int, optional + Step size for sliding window along Z, by default 1. + + Returns + ------- + torch.Tensor + Output tensor of shape (B, out_channel, Z, Y, X). + """ + + + if x.ndim != 5: + raise ValueError( + f"Expected input with 5 dimensions (B, C, Z, Y, X), got {x.shape}" + ) + + batch_size, _, depth, height, width = x.shape + in_stack_depth = self.model.out_stack_depth + + if not hasattr(self, "_predict_pad"): + raise RuntimeError( + "Missing _predict_pad; make sure to call `on_predict_start()` before inference." + ) + if in_stack_depth > depth: + raise ValueError(f"in_stack_depth {in_stack_depth} > input depth {depth}") + + out_tensor = x.new_zeros((batch_size, out_channel, depth, height, width)) + weights = x.new_zeros((1, 1, depth, 1, 1)) + + for start in range(0, depth, step): + end = min(start + in_stack_depth, depth) + slab = x[:, :, start:end] + + if end - start < in_stack_depth: + pad_z = in_stack_depth - (end - start) + slab = F.pad(slab, (0, 0, 0, 0, 0, pad_z)) + + pred = self._predict_with_tta(slab) + + pred = pred[:, :, : end - start] # Trim if Z was padded + out_tensor[:, :, start:end] += pred + weights[:, :, start:end] += 1.0 + + if (weights == 0).any(): + raise RuntimeError( + "Some Z slices were not covered during sliding window inference." 
+            )
+
+        blended = out_tensor / weights
+        return blended
+
     def setup(self, stage: str) -> None:
         if stage != "predict":
             raise NotImplementedError(
@@ -544,25 +607,41 @@ def _reduce_predictions(self, preds: list[Tensor]) -> Tensor:
         prediction = prediction.median(dim=0).values
         return prediction
 
+    def _predict_with_tta(self, source: torch.Tensor) -> torch.Tensor:
+        preds = []
+        for fwd_t, inv_t in zip(
+            self._forward_transforms,
+            self._inverse_transforms,
+        ):
+            src = fwd_t(source)
+            src = self._predict_pad(src)
+            y = self.forward(src)
+            y = self._predict_pad.inverse(y)
+            preds.append(inv_t(y))
+        return preds[0] if len(preds) == 1 else self._reduce_predictions(preds)
+
     def predict_step(
         self, batch: Sample, batch_idx: int, dataloader_idx: int = 0
-    ) -> Tensor:
+    ) -> torch.Tensor:
+        """
+        Predict one batch, blending test-time augmentations if configured.
+
+        Parameters
+        ----------
+        batch : dict[str, Tensor]
+            A dictionary containing a "source" tensor of shape (B, C, Z, Y, X).
+        batch_idx : int
+            Index of the batch.
+        dataloader_idx : int
+            Index of the dataloader if multiple dataloaders are used.
+
+        Returns
+        -------
+        torch.Tensor
+            Prediction tensor of shape (B, out_channels, Z, Y, X).
+        """
         source = batch["source"]
-        preds = []
-        for forward_t, inverse_t in zip(
-            self._forward_transforms, self._inverse_transforms
-        ):
-            source = forward_t(source)
-            source = self._predict_pad(source)
-            pred = self.forward(source)
-            pred = self._predict_pad.inverse(pred)
-            pred = inverse_t(pred)
-            preds.append(pred)
-        if len(preds) == 1:
-            prediction = preds[0]
-        else:
-            prediction = self._reduce_predictions(preds)
-        return prediction
+        return self._predict_with_tta(source)
 
 
 class FcmaeUNet(VSUNet):
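Note for reviewers: `demo_api.py` passes identity transforms, so `_predict_with_tta` reduces to a single forward pass. Below is a minimal sketch of actual test-time augmentation using self-inverse flips; the flip choice and variable names are illustrative only, and `model`, `vol`, and `DEVICE` are assumed to be set up as in `demo_api.py`:

```python
import torch

from viscy.translation.engine import AugmentedPredictionVSUNet

# Flips are self-inverse, so the same callable serves as forward and inverse.
identity = lambda t: t
flip_y = lambda t: torch.flip(t, dims=[-2])
flip_x = lambda t: torch.flip(t, dims=[-1])

vs_tta = AugmentedPredictionVSUNet(
    model=model.model,  # the inner network, as in demo_api.py
    forward_transforms=[identity, flip_y, flip_x],
    inverse_transforms=[identity, flip_y, flip_x],
    reduction="mean",  # average the three de-augmented predictions
).to(DEVICE).eval()

with torch.inference_mode():
    pred = vs_tta.predict_sliding_windows(vol)  # same Z-blending as the demo
```

Because `_predict_with_tta` pairs the two lists with `zip`, the i-th inverse must undo the i-th forward transform, so both lists must be given in the same order.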