Predict volume #280
base: main
Changes from 16 commits
@@ -0,0 +1,127 @@
# %%
import time
from pathlib import Path

import numpy as np
import torch
from iohub import open_ome_zarr
import napari

from viscy.api.inference import VS_inference_t2t

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Configuration dictionary (mirrors the CLI .yaml)
config = {
    "model": {
        "class_path": "viscy.translation.engine.VSUNet",
        "init_args": {
            "architecture": "fcmae",
            "model_config": {
                "in_channels": 1,
                "out_channels": 2,
                "in_stack_depth": 21,
                "encoder_blocks": [3, 3, 9, 3],
                "dims": [96, 192, 384, 768],
                "encoder_drop_path_rate": 0.0,
                "stem_kernel_size": [7, 4, 4],
                "decoder_conv_blocks": 2,
                "pretraining": False,
                "head_conv": True,
                "head_conv_expansion_ratio": 4,
                "head_conv_pool": False,
            },
            # TTA arguments belong inside init_args so that
            # VS_inference_t2t's model_class(**init_args) call receives them
            "test_time_augmentations": True,
            "tta_type": "median",
        },
    },
    "ckpt_path": "/path/to/checkpoint.ckpt",
}
# Load Phase3D input volume
path = Path("/path/to/your.zarr/0/1/000000")
with open_ome_zarr(path) as ds:
    vol_np = np.asarray(ds.data[0, 0])  # (Z, Y, X)

vol = torch.from_numpy(vol_np).unsqueeze(0).unsqueeze(0).float().to(DEVICE)  # (B=1, C=1, Z, Y, X)

# Run model
start = time.time()
pred = VS_inference_t2t(vol, config)
if torch.cuda.is_available():
    torch.cuda.synchronize()  # wait for GPU work so the timing is accurate
print(f"Inference time: {time.time() - start:.2f} sec")

# Visualize
pred_np = pred.cpu().numpy()
nuc, mem = pred_np[0, 0], pred_np[0, 1]

viewer = napari.Viewer()
viewer.add_image(vol_np, name="phase_input", colormap="gray")
viewer.add_image(nuc, name="virt_nuclei", colormap="magenta")
viewer.add_image(mem, name="virt_membrane", colormap="cyan")
napari.run()

# %%

# examples/inference_manual.py
import time
from pathlib import Path

import numpy as np
import torch
from iohub import open_ome_zarr
import napari

from viscy.translation.engine import VSUNet, AugmentedPredictionVSUNet

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Instantiate model manually
vs = (
    VSUNet(
        architecture="fcmae",
        model_config={
            "in_channels": 1,
            "out_channels": 2,
            "in_stack_depth": 21,
            "encoder_blocks": [3, 3, 9, 3],
            "dims": [96, 192, 384, 768],
            "decoder_conv_blocks": 2,
            "stem_kernel_size": [7, 4, 4],
            "pretraining": False,
            "head_conv": True,
            "head_conv_expansion_ratio": 4,
            "head_conv_pool": False,
        },
        ckpt_path="/path/to/checkpoint.ckpt",
        test_time_augmentations=True,
        tta_type="median",
    )
    .to(DEVICE)
    .eval()
)

# Identity transforms: the wrapper itself applies no augmentation here
wrapper = (
    AugmentedPredictionVSUNet(
        model=vs.model,
        forward_transforms=[lambda t: t],
        inverse_transforms=[lambda t: t],
    )
    .to(DEVICE)
    .eval()
)
wrapper.on_predict_start()  # sets up _predict_pad required by inference_tiled
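# The identity lambdas above disable augmentation in the wrapper. As a
# hedged sketch (not part of this PR), a self-inverse flip along X could
# be added as an extra forward/inverse pair, e.g.:
#
#     flip_x = lambda t: torch.flip(t, dims=[-1])
#     wrapper = AugmentedPredictionVSUNet(
#         model=vs.model,
#         forward_transforms=[lambda t: t, flip_x],
#         inverse_transforms=[lambda t: t, flip_x],
#     ).to(DEVICE).eval()
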
# Load data
path = Path("/path/to/your.zarr/0/1/000000")
with open_ome_zarr(path) as ds:
    vol_np = np.asarray(ds.data[0, 0])  # (Z, Y, X)

vol = torch.from_numpy(vol_np).unsqueeze(0).unsqueeze(0).float().to(DEVICE)

# Run inference
with torch.no_grad():
    pred = wrapper.inference_tiled(vol)
if torch.cuda.is_available():
    torch.cuda.synchronize()  # wait for the GPU before reading results
# Visualize
pred_np = pred.cpu().numpy()
nuc, mem = pred_np[0, 0], pred_np[0, 1]

viewer = napari.Viewer()
viewer.add_image(vol_np, name="phase_input", colormap="gray")
viewer.add_image(nuc, name="virt_nuclei", colormap="magenta")
viewer.add_image(mem, name="virt_membrane", colormap="cyan")
napari.run()
@@ -0,0 +1,43 @@
import torch

from viscy.api.inference import VS_inference_t2t


def test_vs_inference_t2t():
    """
    Test the VS_inference_t2t function with a simple config and random input.
    """
    in_stack_depth = 21
    dims = [24, 48, 96, 192]  # dims[0] must be divisible by ratio (24/3=8)

    cfg = {
        "model": {
            "class_path": "viscy.translation.engine.VSUNet",
            "init_args": {
                "architecture": "fcmae",
                "model_config": {
                    "in_channels": 1,
                    "out_channels": 2,
                    "in_stack_depth": in_stack_depth,
                    "encoder_blocks": [1, 1, 1, 1],
                    "dims": dims,
                    "stem_kernel_size": [7, 4, 4],
                    "pretraining": False,
                    "decoder_conv_blocks": 1,
                    "head_conv": True,
                    "head_conv_expansion_ratio": 2,
                    "head_conv_pool": False,
                },
                "test_time_augmentations": False,
                "tta_type": "none",
                "ckpt_path": None,
            },
        }
    }

    x = torch.rand(1, 1, in_stack_depth, 64, 64)
    pred = VS_inference_t2t(x, cfg)

    assert isinstance(pred, torch.Tensor)
    assert pred.shape == (1, 2, in_stack_depth, 64, 64), (
        f"Unexpected shape: {pred.shape}"
    )
|
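Assuming the test file lands somewhere under the repository's tests/ directory (the path is not shown in this diff), it can be selected by name with pytest's keyword filter:

pytest -k test_vs_inference_t2t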
viscy/api/inference.py
@@ -0,0 +1,43 @@
import importlib

import torch

from viscy.translation.engine import AugmentedPredictionVSUNet


@torch.no_grad()
def VS_inference_t2t(x: torch.Tensor, cfg: dict) -> torch.Tensor:
    """
    Run virtual staining using a config dictionary and a 5D input tensor (B, C, Z, Y, X).

    Returns the predicted tensor of shape (B, C_out, Z, Y, X).
    """
    # Extract model info
    model_cfg = cfg["model"].copy()
    init_args = model_cfg["init_args"]
    class_path = model_cfg["class_path"]

    # Inject ckpt_path from the top-level config if present
    if "ckpt_path" in cfg:
        init_args["ckpt_path"] = cfg["ckpt_path"]

    # Import the model class dynamically from its dotted path
    module_path, class_name = class_path.rsplit(".", 1)
    model_class = getattr(importlib.import_module(module_path), class_name)

    # Instantiate the model on the input's device
    model = model_class(**init_args).to(x.device).eval()

    # Wrap with augmentation logic (identity transforms: no TTA here)
    wrapper = (
        AugmentedPredictionVSUNet(
            model=model.model,
            forward_transforms=[lambda t: t],
            inverse_transforms=[lambda t: t],
        )
        .to(x.device)
        .eval()
    )

    wrapper.on_predict_start()
    return wrapper.inference_tiled(x)
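Note that cfg["model"].copy() is a shallow copy, so writing ckpt_path into init_args above also mutates the caller's config dict. A hedged one-line alternative (not part of this PR) that copies init_args before mutating it:

    init_args = dict(model_cfg["init_args"])  # copy the mapping before writing into it
    if "ckpt_path" in cfg:
        init_args["ckpt_path"] = cfg["ckpt_path"]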
viscy/translation/engine.py
@@ -530,6 +530,59 @@ def __init__(
    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

    @torch.no_grad()
    def inference_tiled(
        self, x: torch.Tensor, out_channel: int = 2, step: int = 1
    ) -> torch.Tensor:
        """
        Sliding-window inference over Z: run the model on overlapping slabs
        of depth ``in_stack_depth`` (stride ``step``) and average the
        overlapping predictions.

        Example:
            pred = wrapper.inference_tiled(input_tensor)
            # input_tensor: torch.Tensor of shape (B, 1, Z, Y, X)
            # pred: torch.Tensor of shape (B, 2, Z, Y, X)
        """
        self.eval()
        assert x.ndim == 5, f"Expected shape (B,C,Z,Y,X), got {x.shape}"

        B, _, Z, Y, X = x.shape
        # Slab depth taken from the wrapped model
        in_stack_depth = self.model.out_stack_depth

        # Accumulators for summed predictions and per-slice hit counts
        out_tensor = x.new_zeros((B, out_channel, Z, Y, X))
        weights = x.new_zeros((1, 1, Z, 1, 1))

        pad = getattr(self, "_predict_pad", None)
        if pad is None:
            raise RuntimeError(
                "Missing _predict_pad; call on_predict_start() before inference."
            )
        if in_stack_depth > Z:
            raise ValueError(
                f"Input stack depth {in_stack_depth} is larger than input Z dimension {Z}"
            )

        for start in range(0, Z, step):
            end = min(start + in_stack_depth, Z)
            slab = x[:, :, start:end]

            # Zero-pad if the last slab is shorter than the model depth
            if end - start < in_stack_depth:
                pad_z = in_stack_depth - (end - start)
                slab = torch.nn.functional.pad(slab, (0, 0, 0, 0, 0, pad_z))

            slab = pad(slab)
            pred = self(slab)
            pred = pad.inverse(pred)

            # Clip the prediction back if the slab was padded in Z
            pred = pred[:, :, : end - start]

            out_tensor[:, :, start:end] += pred
            weights[:, :, start:end] += 1.0

        # Normalize by how many slabs covered each Z slice
        blended = out_tensor / weights.clamp_min(1e-8)
        assert blended.shape[-3] == Z
        return blended

    def setup(self, stage: str) -> None:
        if stage != "predict":
            raise NotImplementedError(
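The accumulate-and-normalize pattern in inference_tiled can be sanity-checked in isolation. Below is a minimal standalone sketch (hypothetical, not part of this PR) that replaces the network with an identity function, so the blended output must reproduce the input exactly:

import torch


def blend_z(x: torch.Tensor, depth: int, step: int = 1) -> torch.Tensor:
    """Average overlapping Z-slabs, mirroring the loop in inference_tiled."""

    def model(slab):
        return slab  # identity stand-in for the network

    B, C, Z, Y, X = x.shape
    out = x.new_zeros((B, C, Z, Y, X))
    weights = x.new_zeros((1, 1, Z, 1, 1))
    for start in range(0, Z, step):
        end = min(start + depth, Z)
        slab = x[:, :, start:end]
        if end - start < depth:
            # zero-pad the final short slab up to the model depth
            slab = torch.nn.functional.pad(slab, (0, 0, 0, 0, 0, depth - (end - start)))
        pred = model(slab)[:, :, : end - start]  # clip the Z padding back off
        out[:, :, start:end] += pred
        weights[:, :, start:end] += 1.0
    return out / weights.clamp_min(1e-8)


x = torch.rand(1, 1, 8, 4, 4)
assert torch.allclose(blend_z(x, depth=3), x)  # identity model round-trips the input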