Predict volume #280
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open · ieivanov wants to merge 33 commits into main from predict_volume
+155 −20
Commits (33) · The diff below shows changes from 23 of the 33 commits.
c698999 psudocode (ieivanov)
4f703b7 check right shape (ieivanov)
f542d2c first try (tayllatheodoro)
da31198 add channel (tayllatheodoro)
a3679f8 current stage (tayllatheodoro)
5b81b34 working stage (tayllatheodoro)
8ac3941 Merge branch 'main' into predict_volume (tayllatheodoro)
30b6fbf api wrapper (tayllatheodoro)
433d195 add examples (tayllatheodoro)
324cdf1 reorder inputs (tayllatheodoro)
8b4d243 rename input tensor (tayllatheodoro)
04cbbb1 remove comment (tayllatheodoro)
214bbeb test (tayllatheodoro)
22d5f4f style (tayllatheodoro)
f5eacca add test docstrig (tayllatheodoro)
9cf5870 style (tayllatheodoro)
612db64 first pass of review corrections (tayllatheodoro)
227ef0f bug fix (tayllatheodoro)
1cb4b74 move to translation tests (tayllatheodoro)
236dfc6 use predict_step (tayllatheodoro)
5117de2 update example (tayllatheodoro)
7b1be81 move to shrimpy (tayllatheodoro)
e489ef5 docstring (tayllatheodoro)
5da4c99 Update viscy/translation/engine.py (tayllatheodoro)
70dab29 Update viscy/translation/engine.py (tayllatheodoro)
ddbd83a Update examples/virtual_staining/VS_model_inference/demo_api.py (tayllatheodoro)
5b74bc2 Update examples/virtual_staining/VS_model_inference/demo_api.py (tayllatheodoro)
b93020b Update examples/virtual_staining/VS_model_inference/demo_api.py (tayllatheodoro)
1ba5bba Update viscy/translation/engine.py (tayllatheodoro)
ecd11b5 Update viscy/translation/engine.py (tayllatheodoro)
8f09e40 fallback logic to the init method (tayllatheodoro)
f05ace3 doc string (tayllatheodoro)
35b70d1 remove torch.cuda.synchronize() (tayllatheodoro)
File: examples/virtual_staining/VS_model_inference/demo_api.py (new file)
@@ -0,0 +1,59 @@
#%%
from pathlib import Path
import numpy as np
import torch
from iohub import open_ome_zarr
import napari

from viscy.translation.engine import VSUNet, AugmentedPredictionVSUNet

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate model manually
model = VSUNet(
    architecture="fcmae",
    model_config={
        "in_channels": 1,
        "out_channels": 2,
        "in_stack_depth": 21,
        "encoder_blocks": [3, 3, 9, 3],
        "dims": [96, 192, 384, 768],
        "decoder_conv_blocks": 2,
        "stem_kernel_size": [7, 4, 4],
        "pretraining": False,
        "head_conv": True,
        "head_conv_expansion_ratio": 4,
        "head_conv_pool": False,
    },
    ckpt_path="/path/to/checkpoint.ckpt",
    test_time_augmentations=True,
    tta_type="median",
).to(DEVICE).eval()

vs = AugmentedPredictionVSUNet(
    model=model.model,
    forward_transforms=[lambda t: t],
    inverse_transforms=[lambda t: t],
).to(DEVICE).eval()

# Load data
path = Path("/path/to/your.zarr/0/1/000000")
with open_ome_zarr(path) as ds:
    vol_np = np.asarray(ds.data[0, 0])  # (Z, Y, X)

vol = torch.from_numpy(vol_np).unsqueeze(0).unsqueeze(0).float().to(DEVICE)

# Run inference
with torch.no_grad():
    pred = vs.predict_sliding_windows(vol)
torch.cuda.synchronize()

# Visualize
pred_np = pred.cpu().numpy()
nuc, mem = pred_np[0, 0], pred_np[0, 1]

viewer = napari.Viewer()
viewer.add_image(vol_np, name="phase_input", colormap="gray")
viewer.add_image(nuc, name="virt_nuclei", colormap="magenta")
viewer.add_image(mem, name="virt_membrane", colormap="cyan")
napari.run()
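
Note on the manual call above: `predict_sliding_windows` (added in viscy/translation/engine.py below) raises a RuntimeError when `self._predict_pad` has not been set and points the caller to `on_predict_start()`. When driving the module by hand rather than through a Lightning Trainer, that hook would therefore likely need to run first. A minimal sketch, assuming `on_predict_start()` can be called directly on the wrapper:

vs.on_predict_start()  # assumed to set up the _predict_pad used inside the sliding-window loop

with torch.no_grad():
    pred = vs.predict_sliding_windows(vol, out_channel=2, step=1)  # defaults made explicit
# pred has shape (B, out_channel, Z, Y, X), here (1, 2, Z, Y, X)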
File: viscy/translation/engine.py
@@ -530,6 +530,71 @@ def __init__(
    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

    @torch.no_grad()
    def predict_sliding_windows(
        self, x: torch.Tensor, out_channel: int = 2, step: int = 1
    ) -> torch.Tensor:
        """
        Run inference on a 5D input tensor (B, C, Z, Y, X) using sliding windows
        along the Z dimension with overlap and blending.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (B, C, Z, Y, X).
        out_channel : int, optional
            Number of output channels, by default 2.
        step : int, optional
            Step size for the sliding window along Z, by default 1.

        Returns
        -------
        torch.Tensor
            Output tensor of shape (B, out_channel, Z, Y, X).
        """
        self.eval()

        if x.ndim != 5:
            raise ValueError(
                f"Expected input with 5 dimensions (B, C, Z, Y, X), got {x.shape}"
            )

        batch_size, _, depth, height, width = x.shape
        in_stack_depth = self.model.out_stack_depth

        if not hasattr(self, "_predict_pad"):
            raise RuntimeError(
                "Missing _predict_pad; make sure to call `on_predict_start()` before inference."
            )
Review comment on lines +566 to +569: This attribute is assigned in …
        if in_stack_depth > depth:
            raise ValueError(f"in_stack_depth {in_stack_depth} > input depth {depth}")

        out_tensor = x.new_zeros((batch_size, out_channel, depth, height, width))
        weights = x.new_zeros((1, 1, depth, 1, 1))

        for start in range(0, depth, step):
            end = min(start + in_stack_depth, depth)
            slab = x[:, :, start:end]

            if end - start < in_stack_depth:
                pad_z = in_stack_depth - (end - start)
                slab = F.pad(slab, (0, 0, 0, 0, 0, pad_z))

            pred = self._predict_with_tta(slab)

            pred = pred[:, :, : end - start]  # Trim if Z was padded
            out_tensor[:, :, start:end] += pred
            weights[:, :, start:end] += 1.0

        if (weights == 0).any():
            raise RuntimeError(
                "Some Z slices were not covered during sliding window inference."
            )

        blended = out_tensor / weights
        return blended

    def setup(self, stage: str) -> None:
        if stage != "predict":
            raise NotImplementedError(
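
To make the overlap-and-blend bookkeeping above concrete, here is a small standalone sketch, not part of the PR, that reproduces only the Z-coverage counting; the depth and window size are made-up example values:

import torch

# Illustrative values: depth of 5 slices, window of 3, step of 1.
depth, in_stack_depth, step = 5, 3, 1
weights = torch.zeros(depth)
for start in range(0, depth, step):
    end = min(start + in_stack_depth, depth)
    weights[start:end] += 1.0
print(weights)  # tensor([1., 2., 3., 3., 3.])

Dividing the accumulated predictions by these per-slice counts is what averages the overlapping windows; the `(weights == 0).any()` check guards against any slice being skipped entirely.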
@@ -544,25 +609,43 @@ def _reduce_predictions(self, preds: list[Tensor]) -> Tensor:
            prediction = prediction.median(dim=0).values
        return prediction

    def _predict_with_tta(self, source: torch.Tensor) -> torch.Tensor:
        preds = []
        for fwd_t, inv_t in zip(
            self._forward_transforms or [lambda x: x],
            self._inverse_transforms or [lambda x: x],
        ):
            src = fwd_t(source)
            src = self._predict_pad(src)
            y = self.forward(src)
            y = self._predict_pad.inverse(y)
            preds.append(inv_t(y))
        return preds[0] if len(preds) == 1 else self._reduce_predictions(preds)

    def predict_step(
        self, batch: Sample, batch_idx: int, dataloader_idx: int = 0
    ) -> Tensor:
    ) -> torch.Tensor:
        """
        Lightning's built-in prediction step. This method is called by the Trainer during `.predict(...)`.

        Applies test-time augmentations (TTA) and padding logic to the input batch["source"].

        Parameters
        ----------
        batch : dict[str, Tensor]
            A dictionary containing a "source" tensor of shape (B, C, Z, Y, X).
        batch_idx : int
            Index of the batch.
        dataloader_idx : int
            Index of the dataloader if multiple dataloaders are used.

        Returns
        -------
        torch.Tensor
            Prediction tensor of shape (B, out_channels, Z, Y, X).
        """
        source = batch["source"]
        preds = []
        for forward_t, inverse_t in zip(
            self._forward_transforms, self._inverse_transforms
        ):
            source = forward_t(source)
            source = self._predict_pad(source)
            pred = self.forward(source)
            pred = self._predict_pad.inverse(pred)
            pred = inverse_t(pred)
            preds.append(pred)
        if len(preds) == 1:
            prediction = preds[0]
        else:
            prediction = self._reduce_predictions(preds)
        return prediction
        return self._predict_with_tta(source)


class FcmaeUNet(VSUNet):
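
The demo script wires identity transforms into `AugmentedPredictionVSUNet`, so `_predict_with_tta` runs a single pass. A hedged sketch of non-trivial forward/inverse pairs, using flips that are an illustrative choice rather than transforms prescribed by this PR; each inverse must undo its forward so the predictions align before reduction:

import torch
from viscy.translation.engine import AugmentedPredictionVSUNet

# Illustrative TTA pairs for (B, C, Z, Y, X) tensors: identity, X-flip, Y-flip.
forward_transforms = [
    lambda t: t,
    lambda t: torch.flip(t, dims=[-1]),  # flip along X
    lambda t: torch.flip(t, dims=[-2]),  # flip along Y
]
inverse_transforms = [
    lambda t: t,
    lambda t: torch.flip(t, dims=[-1]),  # flipping again restores orientation
    lambda t: torch.flip(t, dims=[-2]),
]

vs_tta = AugmentedPredictionVSUNet(
    model=model.model,  # `model` as built in the demo script above
    forward_transforms=forward_transforms,
    inverse_transforms=inverse_transforms,
)

With more than one pair, `_predict_with_tta` (used by both `predict_step` and `predict_sliding_windows`) stacks the per-transform predictions and reduces them through `_reduce_predictions`, e.g. a median as in the hunk above.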