From c69899949a9a0feb44e609849fed47a2b76db998 Mon Sep 17 00:00:00 2001
From: Ivan Ivanov
Date: Wed, 6 Aug 2025 10:33:19 -0700
Subject: [PATCH 01/32] pseudocode

---
 viscy/translation/engine.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py
index 56af9b985..8d6524ed1 100644
--- a/viscy/translation/engine.py
+++ b/viscy/translation/engine.py
@@ -530,6 +530,26 @@ def __init__(
     def forward(self, x: Tensor) -> Tensor:
         return self.model(x)
 
+    # TODO: come up with better name
+    def predict_volume(self, x: Tensor) -> Tensor:
+        # x.dype (Phase 3D) will be float32
+        # x.device should be CUDA
+
+        assert x.ndim == 5
+
+        input_shape = x.shape # BCZYX shape, Z is ~100 slices, B is 1 for real-time processing, C is 1 - phase
+        window_size = self.model.config.z_stack # TODO: check
+
+        slabs = []
+        for idx in range(0, input_shape[-3], window_size): # TODO: make sure this goes over the whole volume
+            slabs.append(self(x[:, :, idx:idx+window_size])) # Size of slabs is (B, C, window_size, Y, X), C will be 2 for VSCyto3D - nucleus and membrane
+
+        # TODO: add linear blending
+        blended_slab = 
+
+        return blended_slab
+
+
     def setup(self, stage: str) -> None:
         if stage != "predict":
             raise NotImplementedError(

From 4f703b791759b040f88a29f77a4366c803cbb17c Mon Sep 17 00:00:00 2001
From: Ivan Ivanov
Date: Wed, 6 Aug 2025 10:34:34 -0700
Subject: [PATCH 02/32] check right shape

---
 viscy/translation/engine.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py
index 8d6524ed1..223673c88 100644
--- a/viscy/translation/engine.py
+++ b/viscy/translation/engine.py
@@ -546,6 +546,7 @@ def predict_volume(self, x: Tensor) -> Tensor:
 
         # TODO: add linear blending
         blended_slab = 
+        assert blended_slab.shape[-3:] == input_shape[-3:]
 
         return blended_slab

From f542d2c73dbdc1f5609c23aea40725351b577cf5 Mon Sep 17 00:00:00 2001
From: Taylla Theodoro
Date: Mon, 25 Aug 2025 16:44:30 -0700
Subject: [PATCH 03/32] first try

---
 viscy/translation/engine.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py
index 223673c88..a02adc56c 100644
--- a/viscy/translation/engine.py
+++ b/viscy/translation/engine.py
@@ -531,21 +531,27 @@ def forward(self, x: Tensor) -> Tensor:
         return self.model(x)
 
     # TODO: come up with better name
-    def predict_volume(self, x: Tensor) -> Tensor:
+    @torch.no_grad()
+    def inference_tiled(self, x: Tensor) -> Tensor:
         # x.dype (Phase 3D) will be float32
         # x.device should be CUDA
 
         assert x.ndim == 5
 
         input_shape = x.shape # BCZYX shape, Z is ~100 slices, B is 1 for real-time processing, C is 1 - phase
-        window_size = self.model.config.z_stack # TODO: check
+        window_size = self.data.z_window_size # TODO: check
+
+        accum_tensor = torch.zeros_like(x, dtype=torch.float32, device=x.device)
+        weights = torch.zeros((x.shape[-3],), dtype=torch.float32, device=x.device) # Z dimension
+        step = 1 # TODO: verify
+        for idx in range(0, input_shape[-3], step): # TODO: make sure this goes over the whole volume
+            accum_tensor[:,:idx:idx+window_size] += self(x[:, :, idx:idx+window_size]) # Size of slabs is (B, C, window_size, Y, X), C will be 2 for VSCyto3D - nucleus and membrane
+            weights[idx:idx+window_size] += 1.0
 
-        slabs = []
-        for idx in range(0, input_shape[-3], window_size): # TODO: make sure this goes over the whole volume
-            slabs.append(self(x[:, :, idx:idx+window_size])) # Size of slabs is (B, C, window_size, Y, X),
C will be 2 for VSCyto3D - nucleus and membrane + blended_slab = accum_tensor / weights.view(1, 1, -1, 1, 1) # Shape is (B, C, Z, Y, X) # TODO: add linear blending - blended_slab = + #blended_slab = slabs / torch.from_numpy(weights).to(slabs.device).view(1, 1, -1, 1, 1) # Shape is (B, C, Z, Y, X) assert blended_slab.shape[-3:] == input_shape[-3:] return blended_slab From da3119840950674a499df776b31aeb668125837b Mon Sep 17 00:00:00 2001 From: Taylla Theodoro Date: Mon, 25 Aug 2025 17:25:08 -0700 Subject: [PATCH 04/32] add channel --- viscy/translation/engine.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index a02adc56c..227b7db7b 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -536,17 +536,20 @@ def inference_tiled(self, x: Tensor) -> Tensor: # x.dype (Phase 3D) will be float32 # x.device should be CUDA - assert x.ndim == 5 + assert x.ndim == 5, f"Expected (B,C,Z,Y,X), got {x.shape}" - input_shape = x.shape # BCZYX shape, Z is ~100 slices, B is 1 for real-time processing, C is 1 - phase - window_size = self.data.z_window_size # TODO: check + B,_,Z,Y,X = x.shape # BCZYX shape, Z is ~100 slices, B is 1 for real-time processing, C is 1 - phase + output_shape = B,self.model.out_channels,Z,Y,X # C is 2 - nucleus and membrane + input_shape = x.shape + in_stack_depth = self.model.in_stack_depth # TODO: check read from some place, - accum_tensor = torch.zeros_like(x, dtype=torch.float32, device=x.device) - weights = torch.zeros((x.shape[-3],), dtype=torch.float32, device=x.device) # Z dimension + + accum_tensor = torch.zeros(output_shape, dtype=torch.float32, device=x.device) + weights = torch.zeros((Z,), dtype=torch.float32, device=x.device) # Z dimension step = 1 # TODO: verify - for idx in range(0, input_shape[-3], step): # TODO: make sure this goes over the whole volume - accum_tensor[:,:idx:idx+window_size] += self(x[:, :, idx:idx+window_size]) # Size of slabs is (B, C, window_size, Y, X), C will be 2 for VSCyto3D - nucleus and membrane - weights[idx:idx+window_size] += 1.0 + for idx in range(0, Z, step): # TODO: make sure this goes over the whole volume + accum_tensor[:,:idx:idx+in_stack_depth] += self(x[:, :, idx:idx+in_stack_depth]) # Size of slabs is (B, C, window_size, Y, X), C will be 2 for VSCyto3D - nucleus and membrane + weights[idx:idx+in_stack_depth] += 1.0 blended_slab = accum_tensor / weights.view(1, 1, -1, 1, 1) # Shape is (B, C, Z, Y, X) From a3679f824b4383da92209a28bb7195d47358d5d3 Mon Sep 17 00:00:00 2001 From: Taylla Theodoro Date: Tue, 26 Aug 2025 13:51:25 -0700 Subject: [PATCH 05/32] current stage --- viscy/translation/engine.py | 101 ++++++++++++++++++++++++++++-------- 1 file changed, 80 insertions(+), 21 deletions(-) diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index 227b7db7b..c0f6192fb 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -530,34 +530,93 @@ def __init__( def forward(self, x: Tensor) -> Tensor: return self.model(x) - # TODO: come up with better name - @torch.no_grad() - def inference_tiled(self, x: Tensor) -> Tensor: - # x.dype (Phase 3D) will be float32 - # x.device should be CUDA + # @torch.no_grad() + # def inference_tiled(self, x: Tensor) -> Tensor: + # """ + # Perform tiled inference on a 3D volume. 
+ # Args: + # x (Tensor): Input tensor of shape (B, C, Z, Y, X) + # Returns: + # Tensor: Output tensor of shape (B, C_out, Z, Y, X) + # """ + # self.eval() + + # assert x.ndim == 5, f"Expected (B,C,Z,Y,X), got {x.shape}" + + # B,_,Z,Y,X = x.shape # BCZYX shape, Z is ~100 slices, B is 1 for real-time processing, C is 1 - phase + # in_stack_depth = self.model.out_stack_depth + + + # out_tensor = x.new_zeros((B,2, Z, Y, X)) + # print("out_tensor shape:", out_tensor.shape) + # print("in_stack_depth:",in_stack_depth) + # weights = x.new_zeros((1,1,Z,1,1)) # Z dimension + # step = 1 + # # use padding logic from on_predict_start() + # pad = self._predict_pad - assert x.ndim == 5, f"Expected (B,C,Z,Y,X), got {x.shape}" - B,_,Z,Y,X = x.shape # BCZYX shape, Z is ~100 slices, B is 1 for real-time processing, C is 1 - phase - output_shape = B,self.model.out_channels,Z,Y,X # C is 2 - nucleus and membrane - input_shape = x.shape - in_stack_depth = self.model.in_stack_depth # TODO: check read from some place, + # for start in range(0, Z, step): # TODO: make sure this goes over the whole volume + # end = min(start + in_stack_depth, Z) + # slab = x[:,:,start:end] # Shape is (B, C, window_size, Y, X), C will be 1 for Phase3D + # slab = pad(slab) + # pred = self(slab) # Shape is (B, C, window_size, Y, X), C will be 2 for VSCyto3D - nucleus and membrane + # pred = pad.inverse(pred) + # assert pred.shape[-3] == (end - start) + # out_tensor[:,:, start:end] += pred # Size of slabs is (B, C, window_size, Y, X), C will be 2 for VSCyto3D - nucleus and membrane + # weights[:,:, start:end] += 1.0 + # blended_slab = out_tensor / weights.clamp_min(1e-8) # Shape is (B, C, Z, Y, X) + # assert blended_slab.shape[-3] == Z - accum_tensor = torch.zeros(output_shape, dtype=torch.float32, device=x.device) - weights = torch.zeros((Z,), dtype=torch.float32, device=x.device) # Z dimension - step = 1 # TODO: verify - for idx in range(0, Z, step): # TODO: make sure this goes over the whole volume - accum_tensor[:,:idx:idx+in_stack_depth] += self(x[:, :, idx:idx+in_stack_depth]) # Size of slabs is (B, C, window_size, Y, X), C will be 2 for VSCyto3D - nucleus and membrane - weights[idx:idx+in_stack_depth] += 1.0 + # return blended_slab - blended_slab = accum_tensor / weights.view(1, 1, -1, 1, 1) # Shape is (B, C, Z, Y, X) - # TODO: add linear blending - #blended_slab = slabs / torch.from_numpy(weights).to(slabs.device).view(1, 1, -1, 1, 1) # Shape is (B, C, Z, Y, X) - assert blended_slab.shape[-3:] == input_shape[-3:] + @torch.no_grad() + def inference_tiled(self, x: Tensor) -> Tensor: + """ + Run tiled inference on a 3D volume with internal padding setup. 
+ Args: + x (Tensor): Input of shape (B, C, Z, Y, X) + Returns: + Tensor: Output of shape (B, C_out, Z, Y, X) + """ + self.eval() + assert x.ndim == 5, f"Expected (B,C,Z,Y,X), got {x.shape}" + + B, _, Z, Y, X = x.shape + z_window = self.model.out_stack_depth # e.g., 21 + C_out = 2 # from config + step = 1 + + out_tensor = x.new_zeros((B, C_out, Z, Y, X)) + weights = x.new_zeros((1, 1, Z, 1, 1)) + + # ⬇️ create padding transform here (was on_predict_start) + down_factor = 32 # stem=4 × encoder=2×2×2 + pad = DivisiblePad((0, 0, down_factor, down_factor)) + + for start in range(0, Z, step): + print("Processing Z slice:", start) + end = min(start + z_window, Z) + slab = x[:, :, start:end] + print("slab shape:", slab.shape) # (B, 1, Zs, Y, X) + slab = pad(slab) + print("padded slab shape:", slab.shape) # (B, 1, Zs, Y_pad, X_pad) + pred = self(slab) + print("pred shape:", pred.shape) # (B, 2, Zs, Y_pad, X_pad) + pred = pad.inverse(pred) + print("pred shape:", pred.shape) # (B, 2, Zs, Y_pad, X_pad) + # crop back to original + + assert pred.shape[-3] == (end - start), f"Z mismatch: got {pred.shape[-3]}, expected {end - start}" + out_tensor[:, :, start:end] += pred + weights[:, :, start:end] += 1.0 + + blended = out_tensor / weights.clamp_min(1e-8) + assert blended.shape[-3] == Z + return blended - return blended_slab def setup(self, stage: str) -> None: From 5b81b34d4336d6762011d16f5ae768ae93dad87e Mon Sep 17 00:00:00 2001 From: Taylla Theodoro Date: Tue, 26 Aug 2025 14:04:02 -0700 Subject: [PATCH 06/32] working stage --- viscy/translation/engine.py | 89 +++++++++++-------------------------- 1 file changed, 25 insertions(+), 64 deletions(-) diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index c0f6192fb..10864ccd7 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -530,86 +530,46 @@ def __init__( def forward(self, x: Tensor) -> Tensor: return self.model(x) - # @torch.no_grad() - # def inference_tiled(self, x: Tensor) -> Tensor: - # """ - # Perform tiled inference on a 3D volume. 
- # Args: - # x (Tensor): Input tensor of shape (B, C, Z, Y, X) - # Returns: - # Tensor: Output tensor of shape (B, C_out, Z, Y, X) - # """ - # self.eval() - - # assert x.ndim == 5, f"Expected (B,C,Z,Y,X), got {x.shape}" - - # B,_,Z,Y,X = x.shape # BCZYX shape, Z is ~100 slices, B is 1 for real-time processing, C is 1 - phase - # in_stack_depth = self.model.out_stack_depth - - - # out_tensor = x.new_zeros((B,2, Z, Y, X)) - # print("out_tensor shape:", out_tensor.shape) - # print("in_stack_depth:",in_stack_depth) - # weights = x.new_zeros((1,1,Z,1,1)) # Z dimension - # step = 1 - # # use padding logic from on_predict_start() - # pad = self._predict_pad - - - # for start in range(0, Z, step): # TODO: make sure this goes over the whole volume - # end = min(start + in_stack_depth, Z) - # slab = x[:,:,start:end] # Shape is (B, C, window_size, Y, X), C will be 1 for Phase3D - # slab = pad(slab) - # pred = self(slab) # Shape is (B, C, window_size, Y, X), C will be 2 for VSCyto3D - nucleus and membrane - # pred = pad.inverse(pred) - # assert pred.shape[-3] == (end - start) - # out_tensor[:,:, start:end] += pred # Size of slabs is (B, C, window_size, Y, X), C will be 2 for VSCyto3D - nucleus and membrane - # weights[:,:, start:end] += 1.0 - - # blended_slab = out_tensor / weights.clamp_min(1e-8) # Shape is (B, C, Z, Y, X) - # assert blended_slab.shape[-3] == Z - - # return blended_slab - - @torch.no_grad() - def inference_tiled(self, x: Tensor) -> Tensor: + def inference_tiled(self, x: torch.Tensor) -> torch.Tensor: """ - Run tiled inference on a 3D volume with internal padding setup. + Run tiled inference over a 3D volume. Args: - x (Tensor): Input of shape (B, C, Z, Y, X) + x: Input tensor of shape (B, C, Z, Y, X) Returns: - Tensor: Output of shape (B, C_out, Z, Y, X) + Tensor of shape (B, C_out, Z, Y, X) """ self.eval() - assert x.ndim == 5, f"Expected (B,C,Z,Y,X), got {x.shape}" + assert x.ndim == 5, f"Expected shape (B,C,Z,Y,X), got {x.shape}" B, _, Z, Y, X = x.shape - z_window = self.model.out_stack_depth # e.g., 21 - C_out = 2 # from config + in_stack_depth = self.model.out_stack_depth + C_out = 2 # model was trained with out_channels=2 step = 1 out_tensor = x.new_zeros((B, C_out, Z, Y, X)) weights = x.new_zeros((1, 1, Z, 1, 1)) - # ⬇️ create padding transform here (was on_predict_start) - down_factor = 32 # stem=4 × encoder=2×2×2 - pad = DivisiblePad((0, 0, down_factor, down_factor)) + pad = getattr(self, "_predict_pad", None) + if pad is None: + raise RuntimeError("Missing _predict_pad; call on_predict_start() before inference.") for start in range(0, Z, step): - print("Processing Z slice:", start) - end = min(start + z_window, Z) - slab = x[:, :, start:end] - print("slab shape:", slab.shape) # (B, 1, Zs, Y, X) + end = min(start + in_stack_depth, Z) + slab = x[:, :, start:end] + + # pad if last slab is shorter + if end - start < in_stack_depth: + pad_z = in_stack_depth - (end - start) + slab = torch.nn.functional.pad(slab, (0, 0, 0, 0, 0, pad_z)) + slab = pad(slab) - print("padded slab shape:", slab.shape) # (B, 1, Zs, Y_pad, X_pad) - pred = self(slab) - print("pred shape:", pred.shape) # (B, 2, Zs, Y_pad, X_pad) - pred = pad.inverse(pred) - print("pred shape:", pred.shape) # (B, 2, Zs, Y_pad, X_pad) - # crop back to original - - assert pred.shape[-3] == (end - start), f"Z mismatch: got {pred.shape[-3]}, expected {end - start}" + pred = self(slab) + pred = pad.inverse(pred) + + # clip prediction if padded in Z + pred = pred[:, :, : end - start] + out_tensor[:, :, start:end] += pred 
weights[:, :, start:end] += 1.0 @@ -619,6 +579,7 @@ def inference_tiled(self, x: Tensor) -> Tensor: + def setup(self, stage: str) -> None: if stage != "predict": raise NotImplementedError( From 30b6fbf229cc7e11c985361b5cb8b24b5ff2a36a Mon Sep 17 00:00:00 2001 From: Taylla Theodoro Date: Tue, 26 Aug 2025 16:57:25 -0700 Subject: [PATCH 07/32] api wrapper --- viscy/api/__init__.py | 0 viscy/api/inference.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) create mode 100644 viscy/api/__init__.py create mode 100644 viscy/api/inference.py diff --git a/viscy/api/__init__.py b/viscy/api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/viscy/api/inference.py b/viscy/api/inference.py new file mode 100644 index 000000000..9832bd9a5 --- /dev/null +++ b/viscy/api/inference.py @@ -0,0 +1,36 @@ +import torch +import importlib +from viscy.translation.engine import AugmentedPredictionVSUNet + +@torch.no_grad() +def VS_inference_t2t(cfg: dict, input_tensor: torch.Tensor) -> torch.Tensor: + """ + Run virtual staining using a config dictionary and 5D input tensor (B, C, Z, Y, X). + Returns predicted tensor of shape (B, C_out, Z, Y, X). + """ + + # Extract model info + model_cfg = cfg["model"].copy() + init_args = model_cfg["init_args"] + class_path = model_cfg["class_path"] + + # Inject ckpt_path from top-level config if needed + if "ckpt_path" in cfg: + init_args["ckpt_path"] = cfg["ckpt_path"] + + # Import model class dynamically + module_path, class_name = class_path.rsplit(".", 1) + model_class = getattr(importlib.import_module(module_path), class_name) + + # Instantiate model + model = model_class(**init_args).to(input_tensor.device).eval() + + # Wrap with augmentation logic + wrapper = AugmentedPredictionVSUNet( + model=model.model, + forward_transforms=[lambda t: t], + inverse_transforms=[lambda t: t], + ).to(input_tensor.device).eval() + + wrapper.on_predict_start() + return wrapper.inference_tiled(input_tensor) From 433d19555cc704cfc3e68b8d0252f904d7c06b7a Mon Sep 17 00:00:00 2001 From: Taylla Theodoro Date: Tue, 26 Aug 2025 17:01:04 -0700 Subject: [PATCH 08/32] add examples --- .../VS_model_inference/demo_api.py | 127 ++++++++++++++++++ 1 file changed, 127 insertions(+) create mode 100644 examples/virtual_staining/VS_model_inference/demo_api.py diff --git a/examples/virtual_staining/VS_model_inference/demo_api.py b/examples/virtual_staining/VS_model_inference/demo_api.py new file mode 100644 index 000000000..bb000fe82 --- /dev/null +++ b/examples/virtual_staining/VS_model_inference/demo_api.py @@ -0,0 +1,127 @@ + +# examples/inference_from_config.py +import time +from pathlib import Path +import numpy as np +import torch +from iohub import open_ome_zarr +import napari + +from viscy.api.inference import VS_inference_t2t + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# Configuration dictionary (from CLI .yaml) +config = { + "model": { + "class_path": "viscy.translation.engine.VSUNet", + "init_args": { + "architecture": "fcmae", + "model_config": { + "in_channels": 1, + "out_channels": 2, + "in_stack_depth": 21, + "encoder_blocks": [3, 3, 9, 3], + "dims": [96, 192, 384, 768], + "encoder_drop_path_rate": 0.0, + "stem_kernel_size": [7, 4, 4], + "decoder_conv_blocks": 2, + "pretraining": False, + "head_conv": True, + "head_conv_expansion_ratio": 4, + "head_conv_pool": False, + }, + }, + "test_time_augmentations": True, + "tta_type": "median", + }, + "ckpt_path": "/path/to/checkpoint.ckpt" +} + +# Load Phase3D input volume 
+path = Path("/path/to/your.zarr/0/1/000000") +with open_ome_zarr(path) as ds: + vol_np = np.asarray(ds.data[0, 0]) # (Z, Y, X) + +vol = torch.from_numpy(vol_np).unsqueeze(0).unsqueeze(0).float().to(DEVICE) # (B=1, C=1, Z, Y, X) + +# Run model +start = time.time() +pred = VS_inference_t2t(vol, config) +torch.cuda.synchronize() +print(f"Inference time: {time.time() - start:.2f} sec") + +# Visualize +pred_np = pred.cpu().numpy() +nuc, mem = pred_np[0, 0], pred_np[0, 1] + +viewer = napari.Viewer() +viewer.add_image(vol_np, name="phase_input", colormap="gray") +viewer.add_image(nuc, name="virt_nuclei", colormap="magenta") +viewer.add_image(mem, name="virt_membrane", colormap="cyan") +napari.run() + + +#%% + +# examples/inference_manual.py +import time +from pathlib import Path +import numpy as np +import torch +from iohub import open_ome_zarr +import napari + +from viscy.translation.engine import VSUNet, AugmentedPredictionVSUNet + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# Instantiate model manually +vs = VSUNet( + architecture="fcmae", + model_config={ + "in_channels": 1, + "out_channels": 2, + "in_stack_depth": 21, + "encoder_blocks": [3, 3, 9, 3], + "dims": [96, 192, 384, 768], + "decoder_conv_blocks": 2, + "stem_kernel_size": [7, 4, 4], + "pretraining": False, + "head_conv": True, + "head_conv_expansion_ratio": 4, + "head_conv_pool": False, + }, + ckpt_path="/path/to/checkpoint.ckpt", + test_time_augmentations=True, + tta_type="median", +).to(DEVICE).eval() + +wrapper = AugmentedPredictionVSUNet( + model=vs.model, + forward_transforms=[lambda t: t], + inverse_transforms=[lambda t: t], +).to(DEVICE).eval() +wrapper.on_predict_start() + +# Load data +path = Path("/path/to/your.zarr/0/1/000000") +with open_ome_zarr(path) as ds: + vol_np = np.asarray(ds.data[0, 0]) # (Z, Y, X) + +vol = torch.from_numpy(vol_np).unsqueeze(0).unsqueeze(0).float().to(DEVICE) + +# Run inference +with torch.no_grad(): + pred = wrapper.inference_tiled(vol) +torch.cuda.synchronize() + +# Visualize +pred_np = pred.cpu().numpy() +nuc, mem = pred_np[0, 0], pred_np[0, 1] + +viewer = napari.Viewer() +viewer.add_image(vol_np, name="phase_input", colormap="gray") +viewer.add_image(nuc, name="virt_nuclei", colormap="magenta") +viewer.add_image(mem, name="virt_membrane", colormap="cyan") +napari.run() From 324cdf1d10c01bd3067f17a46ebb23ee63435144 Mon Sep 17 00:00:00 2001 From: Taylla Theodoro Date: Tue, 26 Aug 2025 17:01:20 -0700 Subject: [PATCH 09/32] reorder inputs --- viscy/api/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/api/inference.py b/viscy/api/inference.py index 9832bd9a5..734e336f3 100644 --- a/viscy/api/inference.py +++ b/viscy/api/inference.py @@ -3,7 +3,7 @@ from viscy.translation.engine import AugmentedPredictionVSUNet @torch.no_grad() -def VS_inference_t2t(cfg: dict, input_tensor: torch.Tensor) -> torch.Tensor: +def VS_inference_t2t(input_tensor: torch.Tensor, cfg: dict) -> torch.Tensor: """ Run virtual staining using a config dictionary and 5D input tensor (B, C, Z, Y, X). Returns predicted tensor of shape (B, C_out, Z, Y, X). 
From 8b4d24368eaf123e4bb7c126d75b3d322d2dd475 Mon Sep 17 00:00:00 2001 From: Taylla Theodoro Date: Tue, 26 Aug 2025 17:01:48 -0700 Subject: [PATCH 10/32] rename input tensor --- viscy/api/inference.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/viscy/api/inference.py b/viscy/api/inference.py index 734e336f3..2587afbfe 100644 --- a/viscy/api/inference.py +++ b/viscy/api/inference.py @@ -3,7 +3,7 @@ from viscy.translation.engine import AugmentedPredictionVSUNet @torch.no_grad() -def VS_inference_t2t(input_tensor: torch.Tensor, cfg: dict) -> torch.Tensor: +def VS_inference_t2t(x: torch.Tensor, cfg: dict) -> torch.Tensor: """ Run virtual staining using a config dictionary and 5D input tensor (B, C, Z, Y, X). Returns predicted tensor of shape (B, C_out, Z, Y, X). @@ -23,14 +23,14 @@ def VS_inference_t2t(input_tensor: torch.Tensor, cfg: dict) -> torch.Tensor: model_class = getattr(importlib.import_module(module_path), class_name) # Instantiate model - model = model_class(**init_args).to(input_tensor.device).eval() + model = model_class(**init_args).to(x.device).eval() # Wrap with augmentation logic wrapper = AugmentedPredictionVSUNet( model=model.model, forward_transforms=[lambda t: t], inverse_transforms=[lambda t: t], - ).to(input_tensor.device).eval() + ).to(x.device).eval() wrapper.on_predict_start() - return wrapper.inference_tiled(input_tensor) + return wrapper.inference_tiled(x) From 04cbbb1795b44c9f66b1a247db805ce0c4b1f660 Mon Sep 17 00:00:00 2001 From: Taylla Theodoro Date: Tue, 26 Aug 2025 17:03:03 -0700 Subject: [PATCH 11/32] remove comment --- examples/virtual_staining/VS_model_inference/demo_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/virtual_staining/VS_model_inference/demo_api.py b/examples/virtual_staining/VS_model_inference/demo_api.py index bb000fe82..09ac5ebb2 100644 --- a/examples/virtual_staining/VS_model_inference/demo_api.py +++ b/examples/virtual_staining/VS_model_inference/demo_api.py @@ -1,5 +1,5 @@ -# examples/inference_from_config.py +# %% import time from pathlib import Path import numpy as np From 214bbeb168c0296ea05203f86d6687d5283e1828 Mon Sep 17 00:00:00 2001 From: Taylla Theodoro Date: Tue, 26 Aug 2025 17:45:04 -0700 Subject: [PATCH 12/32] test --- tests/api/__init__.py | 0 tests/api/test_inference.py | 37 +++++++++++++++++++++++++++++++++++++ viscy/translation/engine.py | 20 +++++++++++--------- 3 files changed, 48 insertions(+), 9 deletions(-) create mode 100644 tests/api/__init__.py create mode 100644 tests/api/test_inference.py diff --git a/tests/api/__init__.py b/tests/api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/api/test_inference.py b/tests/api/test_inference.py new file mode 100644 index 000000000..55529116d --- /dev/null +++ b/tests/api/test_inference.py @@ -0,0 +1,37 @@ +import torch +from viscy.api.inference import VS_inference_t2t + +def test_vs_inference_t2t(): + in_stack_depth = 21 + dims = [24, 48, 96, 192] # dims[0] must be divisible by ratio (24/3=8) + + cfg = { + "model": { + "class_path": "viscy.translation.engine.VSUNet", + "init_args": { + "architecture": "fcmae", + "model_config": { + "in_channels": 1, + "out_channels": 2, + "in_stack_depth": in_stack_depth, + "encoder_blocks": [1, 1, 1, 1], + "dims": dims, + "stem_kernel_size": [7, 4, 4], + "pretraining": False, + "decoder_conv_blocks": 1, + "head_conv": True, + "head_conv_expansion_ratio": 2, + "head_conv_pool": False, + }, + "test_time_augmentations": False, + "tta_type": "none", + 
"ckpt_path": None, + } + } + } + + x = torch.rand(1, 1, in_stack_depth, 64, 64) + pred = VS_inference_t2t(x, cfg) + + assert isinstance(pred, torch.Tensor) + assert pred.shape == (1, 2, in_stack_depth, 64, 64), f"Unexpected shape: {pred.shape}" diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index 10864ccd7..655aba048 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -531,28 +531,30 @@ def forward(self, x: Tensor) -> Tensor: return self.model(x) @torch.no_grad() - def inference_tiled(self, x: torch.Tensor) -> torch.Tensor: + def inference_tiled(self, x: torch.Tensor, out_channel: int = 2, step: int =1) -> torch.Tensor: """ - Run tiled inference over a 3D volume. - Args: - x: Input tensor of shape (B, C, Z, Y, X) - Returns: - Tensor of shape (B, C_out, Z, Y, X) + Example: + pred = VS_inference_t2t(input_tensor, config) + # input_tensor: torch.Tensor of shape (B, 1, Z, Y, X) + # pred: torch.Tensor of shape (B, 2, Z, Y, X) """ + self.eval() assert x.ndim == 5, f"Expected shape (B,C,Z,Y,X), got {x.shape}" B, _, Z, Y, X = x.shape in_stack_depth = self.model.out_stack_depth - C_out = 2 # model was trained with out_channels=2 - step = 1 - out_tensor = x.new_zeros((B, C_out, Z, Y, X)) + out_tensor = x.new_zeros((B, out_channel, Z, Y, X)) weights = x.new_zeros((1, 1, Z, 1, 1)) pad = getattr(self, "_predict_pad", None) if pad is None: raise RuntimeError("Missing _predict_pad; call on_predict_start() before inference.") + if in_stack_depth > Z: + raise ValueError( + f"Input stack depth {in_stack_depth} is larger than input Z dimension {Z}" + ) for start in range(0, Z, step): end = min(start + in_stack_depth, Z) From 22d5f4fb5918e9614540f6dcca015d212e7a7651 Mon Sep 17 00:00:00 2001 From: Taylla Theodoro Date: Tue, 26 Aug 2025 17:51:45 -0700 Subject: [PATCH 13/32] style --- tests/api/test_inference.py | 7 +++++-- viscy/api/inference.py | 15 ++++++++++----- viscy/translation/engine.py | 11 ++++++----- 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/tests/api/test_inference.py b/tests/api/test_inference.py index 55529116d..8e129070c 100644 --- a/tests/api/test_inference.py +++ b/tests/api/test_inference.py @@ -1,6 +1,7 @@ import torch from viscy.api.inference import VS_inference_t2t + def test_vs_inference_t2t(): in_stack_depth = 21 dims = [24, 48, 96, 192] # dims[0] must be divisible by ratio (24/3=8) @@ -26,7 +27,7 @@ def test_vs_inference_t2t(): "test_time_augmentations": False, "tta_type": "none", "ckpt_path": None, - } + }, } } @@ -34,4 +35,6 @@ def test_vs_inference_t2t(): pred = VS_inference_t2t(x, cfg) assert isinstance(pred, torch.Tensor) - assert pred.shape == (1, 2, in_stack_depth, 64, 64), f"Unexpected shape: {pred.shape}" + assert pred.shape == (1, 2, in_stack_depth, 64, 64), ( + f"Unexpected shape: {pred.shape}" + ) diff --git a/viscy/api/inference.py b/viscy/api/inference.py index 2587afbfe..3b56ced3a 100644 --- a/viscy/api/inference.py +++ b/viscy/api/inference.py @@ -2,6 +2,7 @@ import importlib from viscy.translation.engine import AugmentedPredictionVSUNet + @torch.no_grad() def VS_inference_t2t(x: torch.Tensor, cfg: dict) -> torch.Tensor: """ @@ -26,11 +27,15 @@ def VS_inference_t2t(x: torch.Tensor, cfg: dict) -> torch.Tensor: model = model_class(**init_args).to(x.device).eval() # Wrap with augmentation logic - wrapper = AugmentedPredictionVSUNet( - model=model.model, - forward_transforms=[lambda t: t], - inverse_transforms=[lambda t: t], - ).to(x.device).eval() + wrapper = ( + AugmentedPredictionVSUNet( + 
model=model.model,
+            forward_transforms=[lambda t: t],
+            inverse_transforms=[lambda t: t],
+        )
+        .to(x.device)
+        .eval()
+    )
 
     wrapper.on_predict_start()
     return wrapper.inference_tiled(x)
diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py
index 655aba048..6ff4528b1 100644
--- a/viscy/translation/engine.py
+++ b/viscy/translation/engine.py
@@ -531,7 +531,9 @@ def forward(self, x: Tensor) -> Tensor:
         return self.model(x)
 
     @torch.no_grad()
-    def inference_tiled(self, x: torch.Tensor, out_channel: int = 2, step: int =1) -> torch.Tensor:
+    def inference_tiled(
+        self, x: torch.Tensor, out_channel: int = 2, step: int = 1
+    ) -> torch.Tensor:
         """
         Example:
             pred = VS_inference_t2t(input_tensor, config)
@@ -550,7 +552,9 @@ def inference_tiled(self, x: torch.Tensor, out_channel: int = 2, step: int =1) -
 
         pad = getattr(self, "_predict_pad", None)
         if pad is None:
-            raise RuntimeError("Missing _predict_pad; call on_predict_start() before inference.")
+            raise RuntimeError(
+                "Missing _predict_pad; call on_predict_start() before inference."
+            )
         if in_stack_depth > Z:
             raise ValueError(
                 f"Input stack depth {in_stack_depth} is larger than input Z dimension {Z}"

From f5eacca2178043c20c66229a915f0ab787543020 Mon Sep 17 00:00:00 2001
From: Taylla Theodoro
Date: Tue, 26 Aug 2025 17:55:01 -0700
Subject: [PATCH 14/32] add test docstring

---
 tests/api/test_inference.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/api/test_inference.py b/tests/api/test_inference.py
index 8e129070c..0fc0ee718 100644
--- a/tests/api/test_inference.py
+++ b/tests/api/test_inference.py
@@ -3,6 +3,9 @@ def test_vs_inference_t2t():
+    """
+    Test the VS_inference_t2t function with a simple config and random input.
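+    No checkpoint is loaded; only the output type and shape are asserted.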
+ """ in_stack_depth = 21 dims = [24, 48, 96, 192] # dims[0] must be divisible by ratio (24/3=8) From 9cf5870a17a4f57d0f87b79c6017050f852395b8 Mon Sep 17 00:00:00 2001 From: Taylla Theodoro Date: Tue, 26 Aug 2025 17:55:55 -0700 Subject: [PATCH 15/32] style --- viscy/api/inference.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/viscy/api/inference.py b/viscy/api/inference.py index 3b56ced3a..e2e1363e9 100644 --- a/viscy/api/inference.py +++ b/viscy/api/inference.py @@ -1,5 +1,7 @@ -import torch import importlib + +import torch + from viscy.translation.engine import AugmentedPredictionVSUNet From 612db6465333b83ddaa78d6a837698832dede4a5 Mon Sep 17 00:00:00 2001 From: Taylla Theodoro Date: Thu, 28 Aug 2025 12:55:59 -0700 Subject: [PATCH 16/32] first pass of review corrections --- .../VS_model_inference/demo_api.py | 2 +- tests/api/test_inference.py | 2 +- viscy/api/__init__.py | 0 viscy/translation/engine.py | 44 ++++++++++++++----- viscy/{api => translation}/inference.py | 10 +++-- 5 files changed, 43 insertions(+), 15 deletions(-) delete mode 100644 viscy/api/__init__.py rename viscy/{api => translation}/inference.py (76%) diff --git a/examples/virtual_staining/VS_model_inference/demo_api.py b/examples/virtual_staining/VS_model_inference/demo_api.py index 09ac5ebb2..3dc35cd9c 100644 --- a/examples/virtual_staining/VS_model_inference/demo_api.py +++ b/examples/virtual_staining/VS_model_inference/demo_api.py @@ -7,7 +7,7 @@ from iohub import open_ome_zarr import napari -from viscy.api.inference import VS_inference_t2t +from viscy.translation.inference import VS_inference_t2t DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/tests/api/test_inference.py b/tests/api/test_inference.py index 0fc0ee718..1cf20edaa 100644 --- a/tests/api/test_inference.py +++ b/tests/api/test_inference.py @@ -1,5 +1,5 @@ import torch -from viscy.api.inference import VS_inference_t2t +from viscy.translation.inference import VS_inference_t2t def test_vs_inference_t2t(): diff --git a/viscy/api/__init__.py b/viscy/api/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index 6ff4528b1..31a66cf90 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -531,10 +531,27 @@ def forward(self, x: Tensor) -> Tensor: return self.model(x) @torch.no_grad() - def inference_tiled( + def predict_sliding_windows( self, x: torch.Tensor, out_channel: int = 2, step: int = 1 ) -> torch.Tensor: """ + Run inference on a 5D input tensor (B, C, Z, Y, X) using sliding windows + along the Z dimension with overlap and blending. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (B, C, Z, Y, X). + out_channel : int, optional + Number of output channels, by default 2. + step : int, optional + Step size for sliding window along Z, by default 1. + Returns + ------- + torch.Tensor + Output tensor of shape (B, out_channel, Z, Y, X). 
+ Notes + ----- Example: pred = VS_inference_t2t(input_tensor, config) # input_tensor: torch.Tensor of shape (B, 1, Z, Y, X) @@ -542,26 +559,29 @@ def inference_tiled( """ self.eval() - assert x.ndim == 5, f"Expected shape (B,C,Z,Y,X), got {x.shape}" - B, _, Z, Y, X = x.shape + if x.ndim != 5: + raise ValueError(f"Expected input with 5 dimensions (B, C, Z, Y, X), but got shape {x.shape}") + + batch_size, _, depth, height, width = x.shape + in_stack_depth = self.model.out_stack_depth - out_tensor = x.new_zeros((B, out_channel, Z, Y, X)) - weights = x.new_zeros((1, 1, Z, 1, 1)) + out_tensor = x.new_zeros((batch_size, out_channel, depth, height, width)) + weights = x.new_zeros((1, 1, depth, 1, 1)) pad = getattr(self, "_predict_pad", None) if pad is None: raise RuntimeError( "Missing _predict_pad; call on_predict_start() before inference." ) - if in_stack_depth > Z: + if in_stack_depth > depth: raise ValueError( - f"Input stack depth {in_stack_depth} is larger than input Z dimension {Z}" + f"Input stack depth {in_stack_depth} is larger than input Z dimension {depth}" ) - for start in range(0, Z, step): - end = min(start + in_stack_depth, Z) + for start in range(0, depth, step): + end = min(start + in_stack_depth, depth) slab = x[:, :, start:end] # pad if last slab is shorter @@ -580,7 +600,11 @@ def inference_tiled( weights[:, :, start:end] += 1.0 blended = out_tensor / weights.clamp_min(1e-8) - assert blended.shape[-3] == Z + if not blended.shape[-3] == depth: + raise ValueError( + f"Output depth {blended.shape[-3]} matches input depth {Z}, " + "something went wrong in sliding window inference" + ) return blended def setup(self, stage: str) -> None: diff --git a/viscy/api/inference.py b/viscy/translation/inference.py similarity index 76% rename from viscy/api/inference.py rename to viscy/translation/inference.py index e2e1363e9..1f3937657 100644 --- a/viscy/api/inference.py +++ b/viscy/translation/inference.py @@ -6,11 +6,15 @@ @torch.no_grad() -def VS_inference_t2t(x: torch.Tensor, cfg: dict) -> torch.Tensor: +def vs_inference_t2t(x: torch.Tensor, cfg: dict, gpu: bool = True) -> torch.Tensor: """ Run virtual staining using a config dictionary and 5D input tensor (B, C, Z, Y, X). Returns predicted tensor of shape (B, C_out, Z, Y, X). 
""" + if gpu: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + device = torch.device("cpu") # Extract model info model_cfg = cfg["model"].copy() @@ -26,7 +30,7 @@ def VS_inference_t2t(x: torch.Tensor, cfg: dict) -> torch.Tensor: model_class = getattr(importlib.import_module(module_path), class_name) # Instantiate model - model = model_class(**init_args).to(x.device).eval() + model = model_class(**init_args).to(device).eval() # Wrap with augmentation logic wrapper = ( @@ -40,4 +44,4 @@ def VS_inference_t2t(x: torch.Tensor, cfg: dict) -> torch.Tensor: ) wrapper.on_predict_start() - return wrapper.inference_tiled(x) + return wrapper.predict_sliding_windows(x) From 227ef0f66cbc5d46a9c8dd2a9db9ee1910d57d4f Mon Sep 17 00:00:00 2001 From: Taylla Theodoro Date: Thu, 28 Aug 2025 13:11:35 -0700 Subject: [PATCH 17/32] bug fix --- viscy/translation/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index 31a66cf90..435215da3 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -602,7 +602,7 @@ def predict_sliding_windows( blended = out_tensor / weights.clamp_min(1e-8) if not blended.shape[-3] == depth: raise ValueError( - f"Output depth {blended.shape[-3]} matches input depth {Z}, " + f"Output depth {blended.shape[-3]} matches input depth {depth}, " "something went wrong in sliding window inference" ) return blended From 1cb4b74cea7b95fa5cf2c5098610a4fbe2d5329e Mon Sep 17 00:00:00 2001 From: Taylla Theodoro Date: Thu, 28 Aug 2025 13:40:22 -0700 Subject: [PATCH 18/32] move to translation tests --- tests/api/__init__.py | 0 tests/{api => translation}/test_inference.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 tests/api/__init__.py rename tests/{api => translation}/test_inference.py (100%) diff --git a/tests/api/__init__.py b/tests/api/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/api/test_inference.py b/tests/translation/test_inference.py similarity index 100% rename from tests/api/test_inference.py rename to tests/translation/test_inference.py From 236dfc6a4a045dfdae76e06dac1faf2b8326bab6 Mon Sep 17 00:00:00 2001 From: Taylla Theodoro Date: Thu, 28 Aug 2025 14:36:38 -0700 Subject: [PATCH 19/32] use predict_step --- viscy/translation/engine.py | 104 +++++++++++++++++++----------------- 1 file changed, 54 insertions(+), 50 deletions(-) diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index 435215da3..b64a260a0 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -529,6 +529,8 @@ def __init__( def forward(self, x: Tensor) -> Tensor: return self.model(x) + + @torch.no_grad() def predict_sliding_windows( @@ -536,8 +538,9 @@ def predict_sliding_windows( ) -> torch.Tensor: """ Run inference on a 5D input tensor (B, C, Z, Y, X) using sliding windows - along the Z dimension with overlap and blending. - + along the Z dimension with overlap and blending. Applies test-time augmentations + and padding as defined in predict_step. + Parameters ---------- x : torch.Tensor @@ -546,65 +549,49 @@ def predict_sliding_windows( Number of output channels, by default 2. step : int, optional Step size for sliding window along Z, by default 1. + Returns ------- torch.Tensor Output tensor of shape (B, out_channel, Z, Y, X). 
- Notes - ----- - Example: - pred = VS_inference_t2t(input_tensor, config) - # input_tensor: torch.Tensor of shape (B, 1, Z, Y, X) - # pred: torch.Tensor of shape (B, 2, Z, Y, X) """ self.eval() if x.ndim != 5: - raise ValueError(f"Expected input with 5 dimensions (B, C, Z, Y, X), but got shape {x.shape}") + raise ValueError(f"Expected input with 5 dimensions (B, C, Z, Y, X), got {x.shape}") batch_size, _, depth, height, width = x.shape - in_stack_depth = self.model.out_stack_depth + if not hasattr(self, "_predict_pad"): + raise RuntimeError("Missing _predict_pad; make sure to call `on_predict_start()` before inference.") + if in_stack_depth > depth: + raise ValueError(f"in_stack_depth {in_stack_depth} > input depth {depth}") + out_tensor = x.new_zeros((batch_size, out_channel, depth, height, width)) weights = x.new_zeros((1, 1, depth, 1, 1)) - pad = getattr(self, "_predict_pad", None) - if pad is None: - raise RuntimeError( - "Missing _predict_pad; call on_predict_start() before inference." - ) - if in_stack_depth > depth: - raise ValueError( - f"Input stack depth {in_stack_depth} is larger than input Z dimension {depth}" - ) - + # Loop over Z with overlapping slabs for start in range(0, depth, step): end = min(start + in_stack_depth, depth) slab = x[:, :, start:end] - # pad if last slab is shorter if end - start < in_stack_depth: pad_z = in_stack_depth - (end - start) - slab = torch.nn.functional.pad(slab, (0, 0, 0, 0, 0, pad_z)) + slab = F.pad(slab, (0, 0, 0, 0, 0, pad_z)) - slab = pad(slab) - pred = self(slab) - pred = pad.inverse(pred) - - # clip prediction if padded in Z - pred = pred[:, :, : end - start] + # Use the same logic as predict_step (TTA + pad + model + unpad) + pred = self._predict_with_tta(slab) + pred = pred[:, :, : end - start] # Trim if Z was padded out_tensor[:, :, start:end] += pred weights[:, :, start:end] += 1.0 - blended = out_tensor / weights.clamp_min(1e-8) - if not blended.shape[-3] == depth: - raise ValueError( - f"Output depth {blended.shape[-3]} matches input depth {depth}, " - "something went wrong in sliding window inference" - ) + if (weights == 0).any(): + raise RuntimeError("Some Z slices were not covered during sliding window inference.") + + blended = out_tensor / weights return blended def setup(self, stage: str) -> None: @@ -620,26 +607,43 @@ def _reduce_predictions(self, preds: list[Tensor]) -> Tensor: elif self._reduction == "median": prediction = prediction.median(dim=0).values return prediction + + def _predict_with_tta(self, source: torch.Tensor) -> torch.Tensor: + preds = [] + for fwd_t, inv_t in zip(self._forward_transforms or [lambda x: x], + self._inverse_transforms or [lambda x: x]): + src = fwd_t(source) + src = self._predict_pad(src) + y = self.forward(src) + y = self._predict_pad.inverse(y) + preds.append(inv_t(y)) + return preds[0] if len(preds) == 1 else self._reduce_predictions(preds) def predict_step( self, batch: Sample, batch_idx: int, dataloader_idx: int = 0 - ) -> Tensor: + ) -> torch.Tensor: + """ + Lightning's built-in prediction step. This method is called by the Trainer during `.predict(...)`. + + Applies test-time augmentations (TTA) and padding logic to the input batch["source"]. + + Parameters + ---------- + batch : dict[str, Tensor] + A dictionary containing a "source" tensor of shape (B, C, Z, Y, X). + batch_idx : int + Index of the batch. + dataloader_idx : int + Index of the dataloader if multiple dataloaders are used. + + Returns + ------- + torch.Tensor + Prediction tensor of shape (B, out_channels, Z, Y, X). 
+ """ source = batch["source"] - preds = [] - for forward_t, inverse_t in zip( - self._forward_transforms, self._inverse_transforms - ): - source = forward_t(source) - source = self._predict_pad(source) - pred = self.forward(source) - pred = self._predict_pad.inverse(pred) - pred = inverse_t(pred) - preds.append(pred) - if len(preds) == 1: - prediction = preds[0] - else: - prediction = self._reduce_predictions(preds) - return prediction + return self._predict_with_tta(source) + class FcmaeUNet(VSUNet): From 5117de2e3a959273cef02cc7d03132f0e6f3c788 Mon Sep 17 00:00:00 2001 From: Taylla Theodoro Date: Thu, 28 Aug 2025 14:48:50 -0700 Subject: [PATCH 20/32] update example --- .../VS_model_inference/demo_api.py | 76 +------------------ 1 file changed, 4 insertions(+), 72 deletions(-) diff --git a/examples/virtual_staining/VS_model_inference/demo_api.py b/examples/virtual_staining/VS_model_inference/demo_api.py index 3dc35cd9c..11fdd27df 100644 --- a/examples/virtual_staining/VS_model_inference/demo_api.py +++ b/examples/virtual_staining/VS_model_inference/demo_api.py @@ -1,71 +1,4 @@ - -# %% -import time -from pathlib import Path -import numpy as np -import torch -from iohub import open_ome_zarr -import napari - -from viscy.translation.inference import VS_inference_t2t - -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - -# Configuration dictionary (from CLI .yaml) -config = { - "model": { - "class_path": "viscy.translation.engine.VSUNet", - "init_args": { - "architecture": "fcmae", - "model_config": { - "in_channels": 1, - "out_channels": 2, - "in_stack_depth": 21, - "encoder_blocks": [3, 3, 9, 3], - "dims": [96, 192, 384, 768], - "encoder_drop_path_rate": 0.0, - "stem_kernel_size": [7, 4, 4], - "decoder_conv_blocks": 2, - "pretraining": False, - "head_conv": True, - "head_conv_expansion_ratio": 4, - "head_conv_pool": False, - }, - }, - "test_time_augmentations": True, - "tta_type": "median", - }, - "ckpt_path": "/path/to/checkpoint.ckpt" -} - -# Load Phase3D input volume -path = Path("/path/to/your.zarr/0/1/000000") -with open_ome_zarr(path) as ds: - vol_np = np.asarray(ds.data[0, 0]) # (Z, Y, X) - -vol = torch.from_numpy(vol_np).unsqueeze(0).unsqueeze(0).float().to(DEVICE) # (B=1, C=1, Z, Y, X) - -# Run model -start = time.time() -pred = VS_inference_t2t(vol, config) -torch.cuda.synchronize() -print(f"Inference time: {time.time() - start:.2f} sec") - -# Visualize -pred_np = pred.cpu().numpy() -nuc, mem = pred_np[0, 0], pred_np[0, 1] - -viewer = napari.Viewer() -viewer.add_image(vol_np, name="phase_input", colormap="gray") -viewer.add_image(nuc, name="virt_nuclei", colormap="magenta") -viewer.add_image(mem, name="virt_membrane", colormap="cyan") -napari.run() - - #%% - -# examples/inference_manual.py -import time from pathlib import Path import numpy as np import torch @@ -77,7 +10,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Instantiate model manually -vs = VSUNet( +model = VSUNet( architecture="fcmae", model_config={ "in_channels": 1, @@ -97,12 +30,11 @@ tta_type="median", ).to(DEVICE).eval() -wrapper = AugmentedPredictionVSUNet( - model=vs.model, +vs = AugmentedPredictionVSUNet( + model=model.model, forward_transforms=[lambda t: t], inverse_transforms=[lambda t: t], ).to(DEVICE).eval() -wrapper.on_predict_start() # Load data path = Path("/path/to/your.zarr/0/1/000000") @@ -113,7 +45,7 @@ # Run inference with torch.no_grad(): - pred = wrapper.inference_tiled(vol) + pred = vs.predict_sliding_windows(vol) torch.cuda.synchronize() # 
Visualize From 7b1be81be705de27cb3d88e2f2c262604e058ebb Mon Sep 17 00:00:00 2001 From: Taylla Theodoro Date: Thu, 28 Aug 2025 14:51:27 -0700 Subject: [PATCH 21/32] move to shrimpy --- tests/translation/test_inference.py | 43 -------------------------- viscy/translation/inference.py | 47 ----------------------------- 2 files changed, 90 deletions(-) delete mode 100644 tests/translation/test_inference.py delete mode 100644 viscy/translation/inference.py diff --git a/tests/translation/test_inference.py b/tests/translation/test_inference.py deleted file mode 100644 index 1cf20edaa..000000000 --- a/tests/translation/test_inference.py +++ /dev/null @@ -1,43 +0,0 @@ -import torch -from viscy.translation.inference import VS_inference_t2t - - -def test_vs_inference_t2t(): - """ - Test the VS_inference_t2t function with a simple config and random input. - """ - in_stack_depth = 21 - dims = [24, 48, 96, 192] # dims[0] must be divisible by ratio (24/3=8) - - cfg = { - "model": { - "class_path": "viscy.translation.engine.VSUNet", - "init_args": { - "architecture": "fcmae", - "model_config": { - "in_channels": 1, - "out_channels": 2, - "in_stack_depth": in_stack_depth, - "encoder_blocks": [1, 1, 1, 1], - "dims": dims, - "stem_kernel_size": [7, 4, 4], - "pretraining": False, - "decoder_conv_blocks": 1, - "head_conv": True, - "head_conv_expansion_ratio": 2, - "head_conv_pool": False, - }, - "test_time_augmentations": False, - "tta_type": "none", - "ckpt_path": None, - }, - } - } - - x = torch.rand(1, 1, in_stack_depth, 64, 64) - pred = VS_inference_t2t(x, cfg) - - assert isinstance(pred, torch.Tensor) - assert pred.shape == (1, 2, in_stack_depth, 64, 64), ( - f"Unexpected shape: {pred.shape}" - ) diff --git a/viscy/translation/inference.py b/viscy/translation/inference.py deleted file mode 100644 index 1f3937657..000000000 --- a/viscy/translation/inference.py +++ /dev/null @@ -1,47 +0,0 @@ -import importlib - -import torch - -from viscy.translation.engine import AugmentedPredictionVSUNet - - -@torch.no_grad() -def vs_inference_t2t(x: torch.Tensor, cfg: dict, gpu: bool = True) -> torch.Tensor: - """ - Run virtual staining using a config dictionary and 5D input tensor (B, C, Z, Y, X). - Returns predicted tensor of shape (B, C_out, Z, Y, X). 
- """ - if gpu: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - else: - device = torch.device("cpu") - - # Extract model info - model_cfg = cfg["model"].copy() - init_args = model_cfg["init_args"] - class_path = model_cfg["class_path"] - - # Inject ckpt_path from top-level config if needed - if "ckpt_path" in cfg: - init_args["ckpt_path"] = cfg["ckpt_path"] - - # Import model class dynamically - module_path, class_name = class_path.rsplit(".", 1) - model_class = getattr(importlib.import_module(module_path), class_name) - - # Instantiate model - model = model_class(**init_args).to(device).eval() - - # Wrap with augmentation logic - wrapper = ( - AugmentedPredictionVSUNet( - model=model.model, - forward_transforms=[lambda t: t], - inverse_transforms=[lambda t: t], - ) - .to(x.device) - .eval() - ) - - wrapper.on_predict_start() - return wrapper.predict_sliding_windows(x) From e489ef55d93fa434cc15582e556f91016c213a04 Mon Sep 17 00:00:00 2001 From: Taylla Theodoro Date: Thu, 28 Aug 2025 14:55:07 -0700 Subject: [PATCH 22/32] docstring --- viscy/translation/engine.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index b64a260a0..b7579d896 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -529,8 +529,6 @@ def __init__( def forward(self, x: Tensor) -> Tensor: return self.model(x) - - @torch.no_grad() def predict_sliding_windows( @@ -538,8 +536,7 @@ def predict_sliding_windows( ) -> torch.Tensor: """ Run inference on a 5D input tensor (B, C, Z, Y, X) using sliding windows - along the Z dimension with overlap and blending. Applies test-time augmentations - and padding as defined in predict_step. + along the Z dimension with overlap and blending. Parameters ---------- @@ -559,20 +556,23 @@ def predict_sliding_windows( self.eval() if x.ndim != 5: - raise ValueError(f"Expected input with 5 dimensions (B, C, Z, Y, X), got {x.shape}") + raise ValueError( + f"Expected input with 5 dimensions (B, C, Z, Y, X), got {x.shape}" + ) batch_size, _, depth, height, width = x.shape in_stack_depth = self.model.out_stack_depth if not hasattr(self, "_predict_pad"): - raise RuntimeError("Missing _predict_pad; make sure to call `on_predict_start()` before inference.") + raise RuntimeError( + "Missing _predict_pad; make sure to call `on_predict_start()` before inference." + ) if in_stack_depth > depth: raise ValueError(f"in_stack_depth {in_stack_depth} > input depth {depth}") out_tensor = x.new_zeros((batch_size, out_channel, depth, height, width)) weights = x.new_zeros((1, 1, depth, 1, 1)) - # Loop over Z with overlapping slabs for start in range(0, depth, step): end = min(start + in_stack_depth, depth) slab = x[:, :, start:end] @@ -581,7 +581,6 @@ def predict_sliding_windows( pad_z = in_stack_depth - (end - start) slab = F.pad(slab, (0, 0, 0, 0, 0, pad_z)) - # Use the same logic as predict_step (TTA + pad + model + unpad) pred = self._predict_with_tta(slab) pred = pred[:, :, : end - start] # Trim if Z was padded @@ -589,7 +588,9 @@ def predict_sliding_windows( weights[:, :, start:end] += 1.0 if (weights == 0).any(): - raise RuntimeError("Some Z slices were not covered during sliding window inference.") + raise RuntimeError( + "Some Z slices were not covered during sliding window inference." 
+ ) blended = out_tensor / weights return blended @@ -607,11 +608,13 @@ def _reduce_predictions(self, preds: list[Tensor]) -> Tensor: elif self._reduction == "median": prediction = prediction.median(dim=0).values return prediction - + def _predict_with_tta(self, source: torch.Tensor) -> torch.Tensor: preds = [] - for fwd_t, inv_t in zip(self._forward_transforms or [lambda x: x], - self._inverse_transforms or [lambda x: x]): + for fwd_t, inv_t in zip( + self._forward_transforms or [lambda x: x], + self._inverse_transforms or [lambda x: x], + ): src = fwd_t(source) src = self._predict_pad(src) y = self.forward(src) @@ -621,7 +624,7 @@ def _predict_with_tta(self, source: torch.Tensor) -> torch.Tensor: def predict_step( self, batch: Sample, batch_idx: int, dataloader_idx: int = 0 - ) -> torch.Tensor: + ) -> torch.Tensor: """ Lightning's built-in prediction step. This method is called by the Trainer during `.predict(...)`. @@ -645,7 +648,6 @@ def predict_step( return self._predict_with_tta(source) - class FcmaeUNet(VSUNet): def __init__( self, From 5da4c991e88fbf22d45c55705ab44f4ccc562372 Mon Sep 17 00:00:00 2001 From: Taylla Milena Theodoro Date: Thu, 4 Sep 2025 17:15:57 -0700 Subject: [PATCH 23/32] Update viscy/translation/engine.py Co-authored-by: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> --- viscy/translation/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index b7579d896..5e8d5f71c 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -536,7 +536,7 @@ def predict_sliding_windows( ) -> torch.Tensor: """ Run inference on a 5D input tensor (B, C, Z, Y, X) using sliding windows - along the Z dimension with overlap and blending. + along the Z dimension with overlap and average blending. 
Parameters ---------- From 70dab298cefa22b78b8da5e7218fefaa03d31397 Mon Sep 17 00:00:00 2001 From: Taylla Milena Theodoro Date: Thu, 4 Sep 2025 17:16:06 -0700 Subject: [PATCH 24/32] Update viscy/translation/engine.py Co-authored-by: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> --- viscy/translation/engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index 5e8d5f71c..a338efed9 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -530,7 +530,6 @@ def __init__( def forward(self, x: Tensor) -> Tensor: return self.model(x) - @torch.no_grad() def predict_sliding_windows( self, x: torch.Tensor, out_channel: int = 2, step: int = 1 ) -> torch.Tensor: From ddbd83a1776426ea2e65124dc3c7ff4bf15f6d0d Mon Sep 17 00:00:00 2001 From: Taylla Milena Theodoro Date: Thu, 4 Sep 2025 17:16:21 -0700 Subject: [PATCH 25/32] Update examples/virtual_staining/VS_model_inference/demo_api.py Co-authored-by: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> --- examples/virtual_staining/VS_model_inference/demo_api.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/virtual_staining/VS_model_inference/demo_api.py b/examples/virtual_staining/VS_model_inference/demo_api.py index 11fdd27df..8eab03f77 100644 --- a/examples/virtual_staining/VS_model_inference/demo_api.py +++ b/examples/virtual_staining/VS_model_inference/demo_api.py @@ -26,8 +26,6 @@ "head_conv_pool": False, }, ckpt_path="/path/to/checkpoint.ckpt", - test_time_augmentations=True, - tta_type="median", ).to(DEVICE).eval() vs = AugmentedPredictionVSUNet( From 5b74bc23b89463f14a1e48cad067bbb30718fc92 Mon Sep 17 00:00:00 2001 From: Taylla Milena Theodoro Date: Thu, 4 Sep 2025 17:16:56 -0700 Subject: [PATCH 26/32] Update examples/virtual_staining/VS_model_inference/demo_api.py Co-authored-by: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> --- examples/virtual_staining/VS_model_inference/demo_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/virtual_staining/VS_model_inference/demo_api.py b/examples/virtual_staining/VS_model_inference/demo_api.py index 8eab03f77..885b01ca5 100644 --- a/examples/virtual_staining/VS_model_inference/demo_api.py +++ b/examples/virtual_staining/VS_model_inference/demo_api.py @@ -37,9 +37,9 @@ # Load data path = Path("/path/to/your.zarr/0/1/000000") with open_ome_zarr(path) as ds: - vol_np = np.asarray(ds.data[0, 0]) # (Z, Y, X) + vol_np = np.asarray(ds.data[0:1, 0:1]) # (1, 1, Z, Y, X) -vol = torch.from_numpy(vol_np).unsqueeze(0).unsqueeze(0).float().to(DEVICE) +vol = torch.from_numpy(vol_np).float().to(DEVICE) # Run inference with torch.no_grad(): From b93020bf2cde77b369ce53e42bd6de6529fedb94 Mon Sep 17 00:00:00 2001 From: Taylla Milena Theodoro Date: Thu, 4 Sep 2025 17:17:06 -0700 Subject: [PATCH 27/32] Update examples/virtual_staining/VS_model_inference/demo_api.py Co-authored-by: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> --- examples/virtual_staining/VS_model_inference/demo_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/virtual_staining/VS_model_inference/demo_api.py b/examples/virtual_staining/VS_model_inference/demo_api.py index 885b01ca5..17c16b479 100644 --- a/examples/virtual_staining/VS_model_inference/demo_api.py +++ b/examples/virtual_staining/VS_model_inference/demo_api.py @@ -42,7 +42,7 @@ vol = torch.from_numpy(vol_np).float().to(DEVICE) # Run inference -with torch.no_grad(): +with torch.inference_mode(): pred = 
vs.predict_sliding_windows(vol) torch.cuda.synchronize() From 1ba5bbad28ca5af2d0c68fe539c0461f492f5d4e Mon Sep 17 00:00:00 2001 From: Taylla Milena Theodoro Date: Thu, 4 Sep 2025 17:17:22 -0700 Subject: [PATCH 28/32] Update viscy/translation/engine.py Co-authored-by: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> --- viscy/translation/engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index a338efed9..daa5955c3 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -552,7 +552,6 @@ def predict_sliding_windows( Output tensor of shape (B, out_channel, Z, Y, X). """ - self.eval() if x.ndim != 5: raise ValueError( From ecd11b52e0fc35ae03beaafee1a2e78fb23d8425 Mon Sep 17 00:00:00 2001 From: Taylla Milena Theodoro Date: Thu, 4 Sep 2025 17:17:41 -0700 Subject: [PATCH 29/32] Update viscy/translation/engine.py Co-authored-by: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> --- viscy/translation/engine.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index daa5955c3..154fd2435 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -624,10 +624,6 @@ def predict_step( self, batch: Sample, batch_idx: int, dataloader_idx: int = 0 ) -> torch.Tensor: """ - Lightning's built-in prediction step. This method is called by the Trainer during `.predict(...)`. - - Applies test-time augmentations (TTA) and padding logic to the input batch["source"]. - Parameters ---------- batch : dict[str, Tensor] From 8f09e40116e6ef012e957173d8702b619877262c Mon Sep 17 00:00:00 2001 From: Taylla Theodoro Date: Thu, 4 Sep 2025 17:27:30 -0700 Subject: [PATCH 30/32] fallback logic to the init method --- viscy/translation/engine.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index 154fd2435..a57b7889a 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -515,16 +515,16 @@ class AugmentedPredictionVSUNet(LightningModule): def __init__( self, model: nn.Module, - forward_transforms: list[Callable[[Tensor], Tensor]], - inverse_transforms: list[Callable[[Tensor], Tensor]], + forward_transforms: list[Callable[[Tensor], Tensor]] | None = None, + inverse_transforms: list[Callable[[Tensor], Tensor]] | None = None, reduction: Literal["mean", "median"] = "mean", ) -> None: super().__init__() down_factor = 2**model.num_blocks self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) self.model = model - self._forward_transforms = forward_transforms - self._inverse_transforms = inverse_transforms + self._forward_transforms = forward_transforms or [lambda x: x] + self._inverse_transforms = inverse_transforms or [lambda x: x] self._reduction = reduction def forward(self, x: Tensor) -> Tensor: @@ -610,8 +610,8 @@ def _reduce_predictions(self, preds: list[Tensor]) -> Tensor: def _predict_with_tta(self, source: torch.Tensor) -> torch.Tensor: preds = [] for fwd_t, inv_t in zip( - self._forward_transforms or [lambda x: x], - self._inverse_transforms or [lambda x: x], + self._forward_transforms, + self._inverse_transforms, ): src = fwd_t(source) src = self._predict_pad(src) From f05ace3376049c1bc2b2cee0e8fd16f8faf267bc Mon Sep 17 00:00:00 2001 From: Taylla Theodoro Date: Thu, 4 Sep 2025 17:28:05 -0700 Subject: [PATCH 31/32] doc string --- viscy/translation/engine.py | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/viscy/translation/engine.py b/viscy/translation/engine.py
index a57b7889a..6eb550395 100644
--- a/viscy/translation/engine.py
+++ b/viscy/translation/engine.py
@@ -487,10 +487,12 @@ class AugmentedPredictionVSUNet(LightningModule):
     forward_transforms : list[Callable[[Tensor], Tensor]]
         A collection of transforms to apply to the input image before passing it
         to the model. Each one is applied independently.
+        If None, no forward transforms are applied, falling back to the identity transform.
         For example, resizing the input to match the expected voxel size of the model.
     inverse_transforms : list[Callable[[Tensor], Tensor]]
         Inverse transforms to apply to the model output before reduction.
         They should be the inverse of each forward transform.
+        If None, no inverse transforms are applied, falling back to the identity transform.
         For example, resizing the output to match the original input shape for storage.
     reduction : Literal["mean", "median"], optional
         The reduction method to apply to the predictions, by default "mean"

From 35b70d147ae1462433071352f42824a1aea43311 Mon Sep 17 00:00:00 2001
From: Taylla Theodoro
Date: Thu, 4 Sep 2025 17:30:11 -0700
Subject: [PATCH 32/32] remove torch.cuda.synchronize()

---
 examples/virtual_staining/VS_model_inference/demo_api.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/virtual_staining/VS_model_inference/demo_api.py b/examples/virtual_staining/VS_model_inference/demo_api.py
index 17c16b479..9de9bdb03 100644
--- a/examples/virtual_staining/VS_model_inference/demo_api.py
+++ b/examples/virtual_staining/VS_model_inference/demo_api.py
@@ -44,7 +44,6 @@
 # Run inference
 with torch.inference_mode():
     pred = vs.predict_sliding_windows(vol)
-torch.cuda.synchronize()
 
 # Visualize
 pred_np = pred.cpu().numpy()
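The end state of the series in one self-contained sketch: AugmentedPredictionVSUNet defaults to identity forward/inverse transforms (PATCH 30), builds its DivisiblePad in __init__, and exposes predict_sliding_windows for sliding-window inference along Z. The configuration below is the small fcmae setup from the test added in PATCH 12, with random weights since no checkpoint is loaded; the shapes follow that test's assertions.

import torch

from viscy.translation.engine import AugmentedPredictionVSUNet, VSUNet

# Small fcmae configuration copied from the shape test in PATCH 12;
# random weights, since ckpt_path=None skips checkpoint loading.
model = VSUNet(
    architecture="fcmae",
    model_config={
        "in_channels": 1,
        "out_channels": 2,
        "in_stack_depth": 21,
        "encoder_blocks": [1, 1, 1, 1],
        "dims": [24, 48, 96, 192],  # dims[0] must be divisible by ratio (24/3=8)
        "stem_kernel_size": [7, 4, 4],
        "pretraining": False,
        "decoder_conv_blocks": 1,
        "head_conv": True,
        "head_conv_expansion_ratio": 2,
        "head_conv_pool": False,
    },
    ckpt_path=None,
)

# Identity transforms are the default after PATCH 30, so no test-time
# augmentation is applied; reduction defaults to "mean".
vs = AugmentedPredictionVSUNet(model=model.model).eval()

x = torch.rand(1, 1, 21, 64, 64)  # (B, C=1 phase, Z, Y, X)
with torch.inference_mode():
    pred = vs.predict_sliding_windows(x)  # overlapping Z slabs, averaged
assert pred.shape == (1, 2, 21, 64, 64)  # two channels: nuclei and membrane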