Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
c698999
psudocode
ieivanov Aug 6, 2025
4f703b7
check right shape
ieivanov Aug 6, 2025
f542d2c
first try
tayllatheodoro Aug 25, 2025
da31198
add channel
tayllatheodoro Aug 26, 2025
a3679f8
current stage
tayllatheodoro Aug 26, 2025
5b81b34
working stage
tayllatheodoro Aug 26, 2025
8ac3941
Merge branch 'main' into predict_volume
tayllatheodoro Aug 26, 2025
30b6fbf
api wrapper
tayllatheodoro Aug 26, 2025
433d195
add examples
tayllatheodoro Aug 27, 2025
324cdf1
reorder inputs
tayllatheodoro Aug 27, 2025
8b4d243
rename input tensor
tayllatheodoro Aug 27, 2025
04cbbb1
remove comment
tayllatheodoro Aug 27, 2025
214bbeb
test
tayllatheodoro Aug 27, 2025
22d5f4f
style
tayllatheodoro Aug 27, 2025
f5eacca
add test docstrig
tayllatheodoro Aug 27, 2025
9cf5870
style
tayllatheodoro Aug 27, 2025
612db64
first pass of review corrections
tayllatheodoro Aug 28, 2025
227ef0f
bug fix
tayllatheodoro Aug 28, 2025
1cb4b74
move to translation tests
tayllatheodoro Aug 28, 2025
236dfc6
use predict_step
tayllatheodoro Aug 28, 2025
5117de2
update example
tayllatheodoro Aug 28, 2025
7b1be81
move to shrimpy
tayllatheodoro Aug 28, 2025
e489ef5
docstring
tayllatheodoro Aug 28, 2025
5da4c99
Update viscy/translation/engine.py
tayllatheodoro Sep 5, 2025
70dab29
Update viscy/translation/engine.py
tayllatheodoro Sep 5, 2025
ddbd83a
Update examples/virtual_staining/VS_model_inference/demo_api.py
tayllatheodoro Sep 5, 2025
5b74bc2
Update examples/virtual_staining/VS_model_inference/demo_api.py
tayllatheodoro Sep 5, 2025
b93020b
Update examples/virtual_staining/VS_model_inference/demo_api.py
tayllatheodoro Sep 5, 2025
1ba5bba
Update viscy/translation/engine.py
tayllatheodoro Sep 5, 2025
ecd11b5
Update viscy/translation/engine.py
tayllatheodoro Sep 5, 2025
8f09e40
fallback logic to the init method
tayllatheodoro Sep 5, 2025
f05ace3
doc string
tayllatheodoro Sep 5, 2025
35b70d1
remove torch.cuda.synchronize()
tayllatheodoro Sep 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions examples/virtual_staining/VS_model_inference/demo_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#%%
from pathlib import Path
import numpy as np
import torch
from iohub import open_ome_zarr
import napari

from viscy.translation.engine import VSUNet, AugmentedPredictionVSUNet

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Architecture settings handed to VSUNet as ``model_config``.
fcmae_config = {
    "in_channels": 1,
    "out_channels": 2,
    "in_stack_depth": 21,
    "encoder_blocks": [3, 3, 9, 3],
    "dims": [96, 192, 384, 768],
    "decoder_conv_blocks": 2,
    "stem_kernel_size": [7, 4, 4],
    "pretraining": False,
    "head_conv": True,
    "head_conv_expansion_ratio": 4,
    "head_conv_pool": False,
}

# Instantiate model manually
model = VSUNet(
    architecture="fcmae",
    model_config=fcmae_config,
    ckpt_path="/path/to/checkpoint.ckpt",
)
model = model.to(DEVICE).eval()

# Wrap the inner network (``model.model``) for prediction. A single identity
# forward/inverse pair is passed, so the input reaches the network unchanged.
vs = AugmentedPredictionVSUNet(
    model=model.model,
    forward_transforms=[lambda t: t],
    inverse_transforms=[lambda t: t],
)
vs = vs.to(DEVICE).eval()

# Load data
path = Path("/path/to/your.zarr/0/1/000000")
with open_ome_zarr(path) as position:
    # Length-1 slices keep the two leading axes: (1, 1, Z, Y, X)
    vol_np = np.asarray(position.data[0:1, 0:1])

vol = torch.from_numpy(vol_np).float().to(DEVICE)

# Run inference
with torch.inference_mode():
    pred = vs.predict_sliding_windows(vol)

# Visualize: move the prediction to host memory and split its two channels.
pred_np = pred.cpu().numpy()
nuc = pred_np[0, 0]
mem = pred_np[0, 1]

viewer = napari.Viewer()
layers = (
    (vol_np, "phase_input", "gray"),
    (nuc, "virt_nuclei", "magenta"),
    (mem, "virt_membrane", "cyan"),
)
for image, layer_name, cmap in layers:
    viewer.add_image(image, name=layer_name, colormap=cmap)
napari.run()
119 changes: 99 additions & 20 deletions viscy/translation/engine.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
import logging
import os
import random
Expand Down Expand Up @@ -487,10 +487,12 @@
forward_transforms : list[Callable[[Tensor], Tensor]]
A collection of transforms to apply to the input image before passing it to the model.
Each one is applied independently.
If None, a single identity transform is used, i.e. the input is passed to the model unchanged.
For example, resizing the input to match the expected voxel size of the model.
inverse_transforms : list[Callable[[Tensor], Tensor]]
Inverse transforms to apply to the model output before reduction.
They should be the inverse of each forward transform.
If None, a single identity transform is used, i.e. the model output is returned unchanged.
For example, resizing the output to match the original input shape for storage.
reduction : Literal["mean", "median"], optional
The reduction method to apply to the predictions, by default "mean"
Expand All @@ -515,21 +517,84 @@
def __init__(
    self,
    model: nn.Module,
    forward_transforms: list[Callable[[Tensor], Tensor]] | None = None,
    inverse_transforms: list[Callable[[Tensor], Tensor]] | None = None,
    reduction: Literal["mean", "median"] = "mean",
) -> None:
    """
    Store the wrapped model and the paired forward/inverse transforms.

    Parameters
    ----------
    model : nn.Module
        Network to run; its ``num_blocks`` attribute is read to size the
        prediction padding.
    forward_transforms : list[Callable[[Tensor], Tensor]] | None, optional
        Transforms applied to the input before the model, by default None.
    inverse_transforms : list[Callable[[Tensor], Tensor]] | None, optional
        Transforms applied to the model output, paired one-to-one with
        ``forward_transforms``, by default None.
    reduction : Literal["mean", "median"], optional
        Name of the reduction stored for combining predictions,
        by default "mean".

    Raises
    ------
    ValueError
        If both transform lists are given and their lengths differ.
    """
    super().__init__()
    down_factor = 2**model.num_blocks
    # NOTE(review): presumably pads the trailing (Y, X) axes to a multiple of
    # ``down_factor`` -- confirm against DivisiblePad.
    self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor))
    self.model = model

    # None and an empty list are both treated as "not provided".
    forward_transforms = list(forward_transforms or [])
    inverse_transforms = list(inverse_transforms or [])
    # The transforms are consumed pairwise with ``zip``, which silently drops
    # unpaired entries. Fill the missing side with one identity per provided
    # transform so that every transform takes part in the prediction.
    n_pairs = max(len(forward_transforms), len(inverse_transforms), 1)
    if not forward_transforms:
        forward_transforms = [lambda x: x for _ in range(n_pairs)]
    if not inverse_transforms:
        inverse_transforms = [lambda x: x for _ in range(n_pairs)]
    if len(forward_transforms) != len(inverse_transforms):
        raise ValueError(
            "forward_transforms and inverse_transforms must have the same "
            f"length, got {len(forward_transforms)} and "
            f"{len(inverse_transforms)}"
        )

    self._forward_transforms = forward_transforms
    self._inverse_transforms = inverse_transforms
    self._reduction = reduction

def forward(self, x: Tensor) -> Tensor:
    """Call the wrapped model on ``x`` and return its output as-is."""
    network = self.model
    return network(x)

def predict_sliding_windows(
    self, x: torch.Tensor, out_channel: int = 2, step: int = 1
) -> torch.Tensor:
    """
    Run inference on a 5D input tensor (B, C, Z, Y, X) using sliding windows
    along the Z dimension with overlap and average blending.

    The window length is read from ``self.model.out_stack_depth``. Windows
    that run past the end of the volume are zero-padded along Z before
    prediction, and the padded slices are trimmed from the result.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (B, C, Z, Y, X).
    out_channel : int, optional
        Number of output channels, by default 2.
    step : int, optional
        Step size for sliding window along Z, by default 1.

    Returns
    -------
    torch.Tensor
        Output tensor of shape (B, out_channel, Z, Y, X).

    Raises
    ------
    ValueError
        If ``x`` is not 5D, if ``step`` is smaller than 1, if the window
        length exceeds the input depth, or if a prediction does not have
        ``out_channel`` channels.
    RuntimeError
        If the chosen ``step`` leaves some Z slices uncovered by any window.
    """
    if x.ndim != 5:
        raise ValueError(
            f"Expected input with 5 dimensions (B, C, Z, Y, X), got {x.shape}"
        )
    if step < 1:
        raise ValueError(f"step must be a positive integer, got {step}")

    batch_size, _, depth, height, width = x.shape
    in_stack_depth = self.model.out_stack_depth

    if in_stack_depth > depth:
        raise ValueError(f"in_stack_depth {in_stack_depth} > input depth {depth}")

    # (start, end) Z bounds of every window, clipped to the volume.
    windows = [
        (start, min(start + in_stack_depth, depth))
        for start in range(0, depth, step)
    ]

    # Per-slice visit counts depend only on the window layout, so coverage is
    # verified before any (expensive) model call is made.
    weights = x.new_zeros((1, 1, depth, 1, 1))
    for start, end in windows:
        weights[:, :, start:end] += 1.0
    if (weights == 0).any():
        raise RuntimeError(
            "Some Z slices were not covered during sliding window inference."
        )

    out_tensor = x.new_zeros((batch_size, out_channel, depth, height, width))
    for start, end in windows:
        valid_depth = end - start
        slab = x[:, :, start:end]

        if valid_depth < in_stack_depth:
            # Zero-pad the end of Z so the model always sees a full window.
            pad_z = in_stack_depth - valid_depth
            slab = F.pad(slab, (0, 0, 0, 0, 0, pad_z))

        pred = self._predict_with_tta(slab)

        # A channel mismatch would otherwise fail obscurely or, for a
        # single-channel prediction, be silently broadcast to all channels.
        if pred.shape[1] != out_channel:
            raise ValueError(
                f"Model predicted {pred.shape[1]} channels, "
                f"expected out_channel={out_channel}"
            )

        # Trim if Z was padded
        out_tensor[:, :, start:end] += pred[:, :, :valid_depth]

    # Average the overlapping window predictions per Z slice.
    return out_tensor / weights

def setup(self, stage: str) -> None:
if stage != "predict":
raise NotImplementedError(
Expand All @@ -544,25 +609,39 @@
prediction = prediction.median(dim=0).values
return prediction

def _predict_with_tta(self, source: torch.Tensor) -> torch.Tensor:
preds = []
for fwd_t, inv_t in zip(
self._forward_transforms,
self._inverse_transforms,
):
src = fwd_t(source)
src = self._predict_pad(src)
y = self.forward(src)
y = self._predict_pad.inverse(y)
preds.append(inv_t(y))
return preds[0] if len(preds) == 1 else self._reduce_predictions(preds)

def predict_step(
    self, batch: Sample, batch_idx: int, dataloader_idx: int = 0
) -> torch.Tensor:
    """
    Predict the "source" tensor of a batch.

    Parameters
    ----------
    batch : dict[str, Tensor]
        A dictionary containing a "source" tensor of shape (B, C, Z, Y, X).
    batch_idx : int
        Index of the batch. Not used by this method.
    dataloader_idx : int
        Index of the dataloader if multiple dataloaders are used.
        Not used by this method.

    Returns
    -------
    torch.Tensor
        Prediction tensor of shape (B, out_channels, Z, Y, X).
    """
    source = batch["source"]
    # Delegate to the shared helper so every transform pair starts from the
    # original source instead of the output of the previous transform.
    return self._predict_with_tta(source)


class FcmaeUNet(VSUNet):
Expand Down
Loading