Skip to content

Commit 70c1011

Browse files
authored
Merge pull request #3 from WayScience/visualize_3D
Refactored Framework and Added 3D Visualization Capabilities
2 parents 6e177e1 + 23b5e93 commit 70c1011

22 files changed

+3103
-1725
lines changed

MLproject

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@ name: cell_segmentation_gff
22

33
entry_points:
44
train_model:
5-
command: "python3 train.py"
5+
command: "uv run train.py"

callbacks/Callbacks.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def _on_epoch_end(
167167
("train", train_dataloader),
168168
("validation", val_dataloader),
169169
]:
170+
f = 0
170171

171172
self._log_epoch_metrics(
172173
model=model,
@@ -179,9 +180,11 @@ def _on_epoch_end(
179180

180181
# Images can be saved in different ways if desired in the future
181182
if self.image_savers is not None and not isinstance(self.image_savers, list):
182-
self.image_savers(
183-
dataset=val_dataloader.dataset.dataset, model=model, epoch=epoch
184-
)
183+
self.image_savers(model=model, epoch=epoch)
184+
185+
else:
186+
for image_saver in self.image_savers:
187+
image_saver(model=model, epoch=epoch)
185188

186189
val_sample = next(iter(val_dataloader))
187190
val_sample = val_sample["input"]

callbacks/utils/SaveEpochSlices.py

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import tifffile
88
import torch
99

10+
from .save_utils import save_image_mlflow
11+
1012

1113
class SaveEpochSlices:
1214
"""
@@ -15,30 +17,24 @@ class SaveEpochSlices:
1517

1618
def __init__(
1719
self,
18-
image_dataset_idxs: list[int],
19-
data_split: str,
20+
image_dataset: torch.utils.data.Dataset,
2021
image_postprocessor: Any = lambda x: x,
22+
image_dataset_idxs: Optional[list[int]] = None,
2123
) -> None:
2224

25+
self.image_dataset = image_dataset
2326
self.image_dataset_idxs = image_dataset_idxs
24-
self.data_split = data_split
2527
self.crop_key_order = ["height_start", "height_end", "width_start", "width_end"]
2628
self.image_postprocessor = image_postprocessor
2729

28-
def save_image_mlflow(
29-
self,
30-
image: torch.Tensor,
31-
save_image_path_folder: str,
32-
image_filename: str,
33-
) -> None:
34-
35-
with tempfile.TemporaryDirectory() as tmp_dir:
36-
save_path = pathlib.Path(tmp_dir) / image_filename
37-
tifffile.imwrite(save_path, image.astype(np.uint8))
30+
self.epoch = None
31+
self.metadata = None
3832

39-
mlflow.log_artifact(
40-
local_path=save_path, artifact_path=save_image_path_folder
41-
)
33+
self.image_dataset_idxs = (
34+
range(len(image_dataset))
35+
if image_dataset_idxs is None
36+
else image_dataset_idxs
37+
)
4238

4339
def save_image(
4440
self,
@@ -67,20 +63,18 @@ def save_image(
6763
for k in self.crop_key_order
6864
)
6965

70-
filename = (
71-
f"{image_path.stem}__{image_type}{image_path.suffix}"
72-
if image_type == "generated_prediction"
73-
else image_path.name
74-
)
66+
image_suffix = ".tiff" if ".tif" in image_path.suffix else image_path.suffix
7567

76-
image_filename = f"{crop_name}__{filename}"
68+
image_filename = (
69+
f"3D_{image_type}_{image_path.stem}__{crop_name}__{image_suffix}"
70+
)
7771

7872
fov_well_name = image_path.parent.name
7973
patient_name = image_path.parents[2].name
8074

81-
save_image_path_folder = f"epoch_{self.epoch:02}/{patient_name}/{fov_well_name}/{input_slices_name}__{target_slices_name}"
75+
save_image_path_folder = f"cropped_images/epoch_{self.epoch:02}/{patient_name}/{fov_well_name}/{input_slices_name}__{target_slices_name}"
8276

83-
self.save_image_mlflow(
77+
save_image_mlflow(
8478
image=image,
8579
save_image_path_folder=save_image_path_folder,
8680
image_filename=image_filename,
@@ -93,12 +87,10 @@ def predict_target(
9387
) -> torch.Tensor:
9488
return self.image_postprocessor(model(image.unsqueeze(0)).squeeze(0))
9589

96-
def __call__(
97-
self, dataset: torch.utils.data.Dataset, model: torch.nn.Module, epoch: int
98-
) -> None:
90+
def __call__(self, model: torch.nn.Module, epoch: int) -> None:
9991
self.epoch = epoch
10092
for sample_idx in self.image_dataset_idxs:
101-
sample = dataset[sample_idx]
93+
sample = self.image_dataset[sample_idx]
10294
self.metadata = sample["metadata"]
10395

10496
sample_image = self.save_image(
@@ -121,6 +113,6 @@ def __call__(
121113

122114
self.save_image(
123115
image_path=sample["target_path"],
124-
image_type="generated_prediction",
116+
image_type="generated-prediction",
125117
image=generated_prediction,
126118
)

callbacks/utils/SaveWholeSlices.py

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
import pathlib
2+
from typing import Any, Optional
3+
4+
import numpy as np
5+
import pandas as pd
6+
import tifffile
7+
import torch
8+
9+
from .image_padding_specs import compute_patch_mapping
10+
from .save_utils import save_image_locally, save_image_mlflow
11+
12+
13+
class SaveWholeSlices:
14+
"""
15+
Saves chosen images, and all voxels from those images, to a 3D tiff format either locally or in MLflow.
16+
"""
17+
18+
def __init__(
19+
self,
20+
image_dataset: torch.utils.data.Dataset,
21+
image_dataset_idxs: list[int],
22+
image_specs: dict[str, Any],
23+
stride: tuple[int],
24+
crop_shape: tuple[int],
25+
pad_mode="reflect",
26+
image_postprocessor: Any = lambda x: x,
27+
local_save_path: Optional[pathlib.Path] = None,
28+
):
29+
30+
self.image_dataset = image_dataset
31+
self.image_dataset_idxs = image_dataset_idxs
32+
self.image_specs = image_specs
33+
self.stride = stride
34+
self.crop_shape = crop_shape
35+
self.pad_mode = pad_mode
36+
self.image_postprocessor = image_postprocessor
37+
self.local_save_path = local_save_path
38+
39+
self.unique_image_dataset_idxs = []
40+
self.reduce_dataset_idxs(image_dataset=image_dataset)
41+
42+
self.pad_width, self.original_crop_coords = None, None
43+
self.epoch = None
44+
45+
def reduce_dataset_idxs(self, image_dataset: torch.utils.data.Dataset):
46+
"""
47+
For reducing the dataset to only unique indices.
48+
We don't want to save redundant images.
49+
Dataset indices reflect crop samples, and not whole image samples prior to this function.
50+
"""
51+
self.unique_image_dataset_idxs = []
52+
53+
for sample_idx in self.image_dataset_idxs:
54+
if (
55+
image_dataset[sample_idx]["metadata"]["Metadata_ID"]
56+
not in self.unique_image_dataset_idxs
57+
):
58+
self.unique_image_dataset_idxs.append(sample_idx)
59+
60+
def predict_target(
61+
self, padded_image: torch.Tensor, model: torch.nn.Module
62+
) -> torch.Tensor:
63+
"""
64+
padded_image:
65+
Expects image of shape: (Z, H, W)
66+
Z -> Number of Z slices
67+
H -> Image Height
68+
W -> Image Width
69+
"""
70+
71+
output = torch.zeros(
72+
*padded_image.shape,
73+
dtype=torch.float32,
74+
device=padded_image.device,
75+
)
76+
weight = torch.zeros_like(output)
77+
78+
spatial_ranges = [
79+
range(0, s - c, st)
80+
for s, c, st in zip(padded_image.shape, self.crop_shape, self.stride)
81+
]
82+
83+
for idx in torch.cartesian_prod(
84+
*[torch.tensor(list(r)) for r in spatial_ranges]
85+
):
86+
start = idx.tolist()
87+
end = [s + c for s, c in zip(start, self.crop_shape)]
88+
89+
slices = tuple(slice(s, e) for s, e in zip(start, end))
90+
crop = padded_image[slices].unsqueeze(0) # add batch dim
91+
92+
with torch.no_grad():
93+
generated_prediction = self.image_postprocessor(
94+
generated_prediction=model(crop)
95+
).squeeze(0)
96+
97+
output[slices] += generated_prediction
98+
weight[slices] += 1.0
99+
100+
output /= weight
101+
102+
return output[self.original_crop_coords]
103+
104+
def pad_image(self, input_image: torch.Tensor) -> torch.Tensor:
105+
"""
106+
input_image:
107+
Expects image of shape: (Z, H, W)
108+
Z -> Number of Z slices
109+
H -> Image Height
110+
W -> Image Width
111+
"""
112+
113+
padded_image = np.pad(
114+
input_image.detach().cpu().numpy(),
115+
pad_width=self.pad_width,
116+
mode=self.pad_mode,
117+
)
118+
119+
padded_image = torch.from_numpy(padded_image).to(
120+
dtype=torch.float32, device=input_image.device
121+
)
122+
123+
return padded_image
124+
125+
def save_image(
126+
self,
127+
image_path: pathlib.Path,
128+
image_type: str,
129+
image: torch.Tensor,
130+
) -> bool:
131+
"""
132+
- Determines if the image is completely black or not.
133+
- Saves images in the correct format to the hardcoded path.
134+
"""
135+
136+
if not ((image > 0.0) & (image < 1.0)).any():
137+
if image_type == "input":
138+
raise ValueError("Pixels should be between 0 and 1 in the input image")
139+
140+
if image_type == "target":
141+
image = (image != 0).float()
142+
143+
image = (image * 255).byte().cpu().numpy()
144+
145+
# Black images will not be saved
146+
if np.max(image) == 0:
147+
return False
148+
149+
image_suffix = ".tiff" if ".tif" in image_path.suffix else image_path.suffix
150+
151+
filename = f"3D_{image_type}_{image_path.stem}{image_suffix}"
152+
153+
fov_well_name = image_path.parent.name
154+
patient_name = image_path.parents[2].name
155+
156+
save_image_path_folder = f"{patient_name}/{fov_well_name}"
157+
save_image_path_folder = (
158+
f"whole_images/epoch_{self.epoch:02}/{save_image_path_folder}"
159+
if self.epoch is not None
160+
else save_image_path_folder
161+
)
162+
163+
if self.local_save_path is None:
164+
save_image_mlflow(
165+
image=image,
166+
save_image_path_folder=save_image_path_folder,
167+
image_filename=filename,
168+
)
169+
else:
170+
save_image_path_folder = self.local_save_path / save_image_path_folder
171+
save_image_locally(
172+
image=image,
173+
save_image_path_folder=save_image_path_folder,
174+
image_filename=filename,
175+
)
176+
177+
return True
178+
179+
def __call__(
180+
self,
181+
model: torch.nn.Module,
182+
epoch: Optional[int] = None,
183+
) -> None:
184+
185+
self.epoch = epoch
186+
for sample_idx in self.unique_image_dataset_idxs:
187+
188+
self.image_specs["image_shape"][0] = tifffile.imread(
189+
self.image_dataset[sample_idx]["input_path"]
190+
).shape[0]
191+
192+
# For computing image padding and original crop coordinates
193+
# Only the z-padding and the z-crop coordinates need to be computed
194+
# each time, because the number of z-slices isn't consistent across
195+
# 3D images.
196+
self.pad_width, self.original_crop_coords = compute_patch_mapping(
197+
image_specs=self.image_specs,
198+
crop_shape=self.crop_shape,
199+
stride=self.stride,
200+
pad_slices=True,
201+
)
202+
203+
sample_image = self.save_image(
204+
image_path=self.image_dataset[sample_idx]["target_path"],
205+
image_type="target",
206+
image=self.image_dataset[sample_idx]["target"],
207+
)
208+
209+
# Only save these images if the segmentation mask isn't black
210+
# We expect the model to generate black segmentation crops,
211+
# which will present regardless of weather or not the whole segmented image
212+
# is black or not.
213+
if sample_image:
214+
padded_image = self.pad_image(
215+
input_image=self.image_dataset[sample_idx]["input"]
216+
)
217+
218+
generated_prediction = self.predict_target(
219+
padded_image=padded_image, model=model
220+
)
221+
222+
self.save_image(
223+
image_path=self.image_dataset[sample_idx]["input_path"],
224+
image_type="input",
225+
image=self.image_dataset[sample_idx]["input"],
226+
)
227+
228+
self.save_image(
229+
image_path=self.image_dataset[sample_idx]["target_path"],
230+
image_type="generated-prediction",
231+
image=generated_prediction,
232+
)

0 commit comments

Comments
 (0)