Changes from 16 commits (33 commits total on the branch):

c698999 psudocode (ieivanov, Aug 6, 2025)
4f703b7 check right shape (ieivanov, Aug 6, 2025)
f542d2c first try (tayllatheodoro, Aug 25, 2025)
da31198 add channel (tayllatheodoro, Aug 26, 2025)
a3679f8 current stage (tayllatheodoro, Aug 26, 2025)
5b81b34 working stage (tayllatheodoro, Aug 26, 2025)
8ac3941 Merge branch 'main' into predict_volume (tayllatheodoro, Aug 26, 2025)
30b6fbf api wrapper (tayllatheodoro, Aug 26, 2025)
433d195 add examples (tayllatheodoro, Aug 27, 2025)
324cdf1 reorder inputs (tayllatheodoro, Aug 27, 2025)
8b4d243 rename input tensor (tayllatheodoro, Aug 27, 2025)
04cbbb1 remove comment (tayllatheodoro, Aug 27, 2025)
214bbeb test (tayllatheodoro, Aug 27, 2025)
22d5f4f style (tayllatheodoro, Aug 27, 2025)
f5eacca add test docstrig (tayllatheodoro, Aug 27, 2025)
9cf5870 style (tayllatheodoro, Aug 27, 2025)
612db64 first pass of review corrections (tayllatheodoro, Aug 28, 2025)
227ef0f bug fix (tayllatheodoro, Aug 28, 2025)
1cb4b74 move to translation tests (tayllatheodoro, Aug 28, 2025)
236dfc6 use predict_step (tayllatheodoro, Aug 28, 2025)
5117de2 update example (tayllatheodoro, Aug 28, 2025)
7b1be81 move to shrimpy (tayllatheodoro, Aug 28, 2025)
e489ef5 docstring (tayllatheodoro, Aug 28, 2025)
5da4c99 Update viscy/translation/engine.py (tayllatheodoro, Sep 5, 2025)
70dab29 Update viscy/translation/engine.py (tayllatheodoro, Sep 5, 2025)
ddbd83a Update examples/virtual_staining/VS_model_inference/demo_api.py (tayllatheodoro, Sep 5, 2025)
5b74bc2 Update examples/virtual_staining/VS_model_inference/demo_api.py (tayllatheodoro, Sep 5, 2025)
b93020b Update examples/virtual_staining/VS_model_inference/demo_api.py (tayllatheodoro, Sep 5, 2025)
1ba5bba Update viscy/translation/engine.py (tayllatheodoro, Sep 5, 2025)
ecd11b5 Update viscy/translation/engine.py (tayllatheodoro, Sep 5, 2025)
8f09e40 fallback logic to the init method (tayllatheodoro, Sep 5, 2025)
f05ace3 doc string (tayllatheodoro, Sep 5, 2025)
35b70d1 remove torch.cuda.synchronize() (tayllatheodoro, Sep 5, 2025)
127 changes: 127 additions & 0 deletions examples/virtual_staining/VS_model_inference/demo_api.py
@@ -0,0 +1,127 @@

# %%
import time
from pathlib import Path

import napari
import numpy as np
import torch
from iohub import open_ome_zarr

from viscy.api.inference import VS_inference_t2t

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Configuration dictionary (mirrors the CLI .yaml)
config = {
    "model": {
        "class_path": "viscy.translation.engine.VSUNet",
        "init_args": {
            "architecture": "fcmae",
            "model_config": {
                "in_channels": 1,
                "out_channels": 2,
                "in_stack_depth": 21,
                "encoder_blocks": [3, 3, 9, 3],
                "dims": [96, 192, 384, 768],
                "encoder_drop_path_rate": 0.0,
                "stem_kernel_size": [7, 4, 4],
                "decoder_conv_blocks": 2,
                "pretraining": False,
                "head_conv": True,
                "head_conv_expansion_ratio": 4,
                "head_conv_pool": False,
            },
        },
        "test_time_augmentations": True,
        "tta_type": "median",
    },
    "ckpt_path": "/path/to/checkpoint.ckpt",
}

# Load the Phase3D input volume
path = Path("/path/to/your.zarr/0/1/000000")
with open_ome_zarr(path) as ds:
    vol_np = np.asarray(ds.data[0, 0])  # (Z, Y, X)

# Add batch and channel dimensions: (B=1, C=1, Z, Y, X)
vol = torch.from_numpy(vol_np).unsqueeze(0).unsqueeze(0).float().to(DEVICE)

# Run the model
start = time.time()
pred = VS_inference_t2t(vol, config)
if torch.cuda.is_available():
    torch.cuda.synchronize()
print(f"Inference time: {time.time() - start:.2f} sec")

# Visualize
pred_np = pred.cpu().numpy()
nuc, mem = pred_np[0, 0], pred_np[0, 1]

viewer = napari.Viewer()
viewer.add_image(vol_np, name="phase_input", colormap="gray")
viewer.add_image(nuc, name="virt_nuclei", colormap="magenta")
viewer.add_image(mem, name="virt_membrane", colormap="cyan")
napari.run()


# %%
# examples/inference_manual.py
import time
from pathlib import Path

import napari
import numpy as np
import torch
from iohub import open_ome_zarr

from viscy.translation.engine import AugmentedPredictionVSUNet, VSUNet

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the model manually
vs = (
    VSUNet(
        architecture="fcmae",
        model_config={
            "in_channels": 1,
            "out_channels": 2,
            "in_stack_depth": 21,
            "encoder_blocks": [3, 3, 9, 3],
            "dims": [96, 192, 384, 768],
            "decoder_conv_blocks": 2,
            "stem_kernel_size": [7, 4, 4],
            "pretraining": False,
            "head_conv": True,
            "head_conv_expansion_ratio": 4,
            "head_conv_pool": False,
        },
        ckpt_path="/path/to/checkpoint.ckpt",
        test_time_augmentations=True,
        tta_type="median",
    )
    .to(DEVICE)
    .eval()
)

# Wrap the inner network with identity forward/inverse transforms
wrapper = (
    AugmentedPredictionVSUNet(
        model=vs.model,
        forward_transforms=[lambda t: t],
        inverse_transforms=[lambda t: t],
    )
    .to(DEVICE)
    .eval()
)
wrapper.on_predict_start()

# Load data
path = Path("/path/to/your.zarr/0/1/000000")
with open_ome_zarr(path) as ds:
    vol_np = np.asarray(ds.data[0, 0])  # (Z, Y, X)

vol = torch.from_numpy(vol_np).unsqueeze(0).unsqueeze(0).float().to(DEVICE)

# Run inference
with torch.no_grad():
    pred = wrapper.inference_tiled(vol)
if torch.cuda.is_available():
    torch.cuda.synchronize()

# Visualize
pred_np = pred.cpu().numpy()
nuc, mem = pred_np[0, 0], pred_np[0, 1]

viewer = napari.Viewer()
viewer.add_image(vol_np, name="phase_input", colormap="gray")
viewer.add_image(nuc, name="virt_nuclei", colormap="magenta")
viewer.add_image(mem, name="virt_membrane", colormap="cyan")
napari.run()
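The demo above wires identity transforms into the wrapper. As a hedged sketch (not part of the PR), here is what flip-based test-time augmentation could look like with the same API, assuming AugmentedPredictionVSUNet applies each forward transform before the model and the matching inverse afterwards; the flip lambdas are illustrative and reuse `vs.model` and `DEVICE` from the script above.

# %%
import torch

from viscy.translation.engine import AugmentedPredictionVSUNet

# Two augmented views: identity, and a flip along X (the last axis).
tta_wrapper = (
    AugmentedPredictionVSUNet(
        model=vs.model,
        forward_transforms=[lambda t: t, lambda t: torch.flip(t, dims=[-1])],
        inverse_transforms=[lambda t: t, lambda t: torch.flip(t, dims=[-1])],
    )
    .to(DEVICE)
    .eval()
)
tta_wrapper.on_predict_start()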
Empty file added tests/api/__init__.py
43 changes: 43 additions & 0 deletions tests/api/test_inference.py
Reviewer comment (Collaborator): Test that the output is numerically identical to the out-of-core prediction writer. (A hedged consistency sketch follows the test below.)

@@ -0,0 +1,43 @@
import torch

from viscy.api.inference import VS_inference_t2t


def test_vs_inference_t2t():
    """
    Test the VS_inference_t2t function with a simple config and random input.
    """
    in_stack_depth = 21
    dims = [24, 48, 96, 192]  # dims[0] must be divisible by the ratio (24 / 3 = 8)

    cfg = {
        "model": {
            "class_path": "viscy.translation.engine.VSUNet",
            "init_args": {
                "architecture": "fcmae",
                "model_config": {
                    "in_channels": 1,
                    "out_channels": 2,
                    "in_stack_depth": in_stack_depth,
                    "encoder_blocks": [1, 1, 1, 1],
                    "dims": dims,
                    "stem_kernel_size": [7, 4, 4],
                    "pretraining": False,
                    "decoder_conv_blocks": 1,
                    "head_conv": True,
                    "head_conv_expansion_ratio": 2,
                    "head_conv_pool": False,
                },
                "test_time_augmentations": False,
                "tta_type": "none",
                "ckpt_path": None,
            },
        }
    }

    x = torch.rand(1, 1, in_stack_depth, 64, 64)
    pred = VS_inference_t2t(x, cfg)

    assert isinstance(pred, torch.Tensor)
    assert pred.shape == (1, 2, in_stack_depth, 64, 64), (
        f"Unexpected shape: {pred.shape}"
    )
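Toward the reviewer's suggestion above, a minimal consistency sketch (an assumption, not in the PR): seeding the RNG identically before each instantiation should give both paths the same random weights, so the high-level API should match driving the wrapper by hand. Comparing against the actual out-of-core prediction writer would still require a Lightning run, so this is only a partial check.

import torch

from viscy.api.inference import VS_inference_t2t
from viscy.translation.engine import AugmentedPredictionVSUNet, VSUNet


def test_api_matches_manual_wrapper():
    """Sketch: VS_inference_t2t should agree with the manual wrapper path."""
    in_stack_depth = 21
    model_config = {
        "in_channels": 1,
        "out_channels": 2,
        "in_stack_depth": in_stack_depth,
        "encoder_blocks": [1, 1, 1, 1],
        "dims": [24, 48, 96, 192],
        "stem_kernel_size": [7, 4, 4],
        "pretraining": False,
        "decoder_conv_blocks": 1,
        "head_conv": True,
        "head_conv_expansion_ratio": 2,
        "head_conv_pool": False,
    }
    x = torch.rand(1, 1, in_stack_depth, 64, 64)

    # Manual path; seed so the random initialization matches the API path.
    torch.manual_seed(0)
    vs = VSUNet(architecture="fcmae", model_config=model_config, ckpt_path=None).eval()
    wrapper = AugmentedPredictionVSUNet(
        model=vs.model,
        forward_transforms=[lambda t: t],
        inverse_transforms=[lambda t: t],
    ).eval()
    wrapper.on_predict_start()
    with torch.no_grad():
        manual = wrapper.inference_tiled(x)

    # API path, re-seeded so the dynamically imported model initializes identically.
    torch.manual_seed(0)
    cfg = {
        "model": {
            "class_path": "viscy.translation.engine.VSUNet",
            "init_args": {
                "architecture": "fcmae",
                "model_config": model_config,
                "ckpt_path": None,
            },
        }
    }
    api = VS_inference_t2t(x, cfg)

    torch.testing.assert_close(api, manual)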
Empty file added viscy/api/__init__.py
43 changes: 43 additions & 0 deletions viscy/api/inference.py
@@ -0,0 +1,43 @@
import importlib

import torch

from viscy.translation.engine import AugmentedPredictionVSUNet


@torch.no_grad()
def VS_inference_t2t(x: torch.Tensor, cfg: dict) -> torch.Tensor:
    """
    Run virtual staining using a config dictionary and a 5D input tensor (B, C, Z, Y, X).
    Returns the predicted tensor of shape (B, C_out, Z, Y, X).
    """
    # Extract model info
    model_cfg = cfg["model"].copy()
    init_args = model_cfg["init_args"]
    class_path = model_cfg["class_path"]

    # Inject ckpt_path from the top-level config if present
    if "ckpt_path" in cfg:
        init_args["ckpt_path"] = cfg["ckpt_path"]

    # Import the model class dynamically from its dotted path
    module_path, class_name = class_path.rsplit(".", 1)
    model_class = getattr(importlib.import_module(module_path), class_name)

    # Instantiate the model on the input tensor's device
    model = model_class(**init_args).to(x.device).eval()

Reviewer comment (Member): Move the model to the device and move the tensors to the device.

    # Wrap with augmentation logic (identity forward/inverse transforms, i.e. no TTA)
    wrapper = (
        AugmentedPredictionVSUNet(
            model=model.model,
            forward_transforms=[lambda t: t],
            inverse_transforms=[lambda t: t],
        )
        .to(x.device)
        .eval()
    )

    wrapper.on_predict_start()
    return wrapper.inference_tiled(x)
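As a standalone illustration of the `class_path` resolution used above (the stand-in class is arbitrary, not VisCy-specific):

import importlib

# Resolve a dotted "class_path" string into a class object, as VS_inference_t2t does.
class_path = "collections.OrderedDict"  # any importable class works as a stand-in
module_path, class_name = class_path.rsplit(".", 1)
cls = getattr(importlib.import_module(module_path), class_name)
assert cls.__name__ == "OrderedDict"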
53 changes: 53 additions & 0 deletions viscy/translation/engine.py
@@ -530,6 +530,59 @@ def __init__(
    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

    @torch.no_grad()
    def inference_tiled(
        self, x: torch.Tensor, out_channel: int = 2, step: int = 1
    ) -> torch.Tensor:
        """
        Sliding-window inference over the Z axis, averaging overlapping slab
        predictions.

        Example:
            pred = wrapper.inference_tiled(input_tensor)
            # input_tensor: torch.Tensor of shape (B, 1, Z, Y, X)
            # pred: torch.Tensor of shape (B, 2, Z, Y, X)
        """
        self.eval()
        assert x.ndim == 5, f"Expected shape (B,C,Z,Y,X), got {x.shape}"

        B, _, Z, Y, X = x.shape
        # Slab depth is taken from the model's output stack depth.
        in_stack_depth = self.model.out_stack_depth

        out_tensor = x.new_zeros((B, out_channel, Z, Y, X))
        weights = x.new_zeros((1, 1, Z, 1, 1))

        pad = getattr(self, "_predict_pad", None)
        if pad is None:

Reviewer comment (Member): I'm not sure if this is needed?

            raise RuntimeError(
                "Missing _predict_pad; call on_predict_start() before inference."
            )
        if in_stack_depth > Z:
            raise ValueError(
                f"Input stack depth {in_stack_depth} is larger than input Z dimension {Z}"
            )

        for start in range(0, Z, step):
            end = min(start + in_stack_depth, Z)
            slab = x[:, :, start:end]

            # Pad if the last slab is shorter than the slab depth
            if end - start < in_stack_depth:
                pad_z = in_stack_depth - (end - start)
                slab = torch.nn.functional.pad(slab, (0, 0, 0, 0, 0, pad_z))

            slab = pad(slab)
            pred = self(slab)
            pred = pad.inverse(pred)

            # Clip the prediction if it was padded in Z
            pred = pred[:, :, : end - start]

            out_tensor[:, :, start:end] += pred
            weights[:, :, start:end] += 1.0

        blended = out_tensor / weights.clamp_min(1e-8)
        assert blended.shape[-3] == Z
        return blended

    def setup(self, stage: str) -> None:
        if stage != "predict":
            raise NotImplementedError(
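To make the overlap blending in inference_tiled concrete, a self-contained sketch (illustrative, independent of VisCy): with step=1 and slab depth d, each z plane is covered by up to d slabs, and dividing the accumulated sum by the per-plane count recovers the mean prediction.

import torch

# Mimic the accumulation loop of inference_tiled with an identity "model".
Z, depth, step = 8, 4, 1
x = torch.arange(Z, dtype=torch.float32)  # stand-in prediction value per z plane
out = torch.zeros(Z)
weights = torch.zeros(Z)
for start in range(0, Z, step):
    end = min(start + depth, Z)
    out[start:end] += x[start:end]  # the real loop adds the model's slab prediction
    weights[start:end] += 1.0
blended = out / weights.clamp_min(1e-8)
# Averaging identical overlapping predictions reproduces the per-plane values.
assert torch.allclose(blended, x)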