Commit 5537112

add overlay in animation of segmentation masks
1 parent 1aa8274 commit 5537112

6 files changed, +192 -56 lines changed

ncalab/prediction.py

Lines changed: 19 additions & 10 deletions
@@ -22,7 +22,7 @@ def __init__(self, model, steps: int, output_image: torch.Tensor):
         self.steps = steps
         assert output_image.shape[1] == model.num_channels
         self.output_image = output_image
-        self.output_array: Optional[np.ndarray] = None
+        self._output_array: Optional[np.ndarray] = None
 
     @property
     def image_channels(self) -> torch.Tensor:
@@ -56,23 +56,32 @@ def output_channels(self) -> torch.Tensor:
             :,
         ]
 
+    @property
+    def output_array(self) -> np.ndarray:
+        """
+        :returns [np.ndarray]: BCWH
+        """
+        if self._output_array is None:
+            self._output_array = self.output_image.detach().cpu().numpy()
+        return self._output_array
+
     @property
     def image_channels_np(self) -> np.ndarray:
         """
         :returns [np.ndarray]: BCWH
         """
-        if self.output_array is None:
-            self.output_array = self.output_image.detach().cpu().numpy()
-        return self.output_array[:, : self.model.num_image_channels, :, :]
+        if self._output_array is None:
+            self._output_array = self.output_image.detach().cpu().numpy()
+        return self._output_array[:, : self.model.num_image_channels, :, :]
 
     @property
     def hidden_channels_np(self) -> np.ndarray:
         """
         :returns [np.ndarray]: BCWH
         """
-        if self.output_array is None:
-            self.output_array = self.output_image.detach().cpu().numpy()
-        return self.output_array[
+        if self._output_array is None:
+            self._output_array = self.output_image.detach().cpu().numpy()
+        return self._output_array[
             :,
             self.model.num_image_channels : self.model.num_hidden_channels
             + self.model.num_hidden_channels,
@@ -85,9 +94,9 @@ def output_channels_np(self) -> np.ndarray:
         """
         :returns [np.ndarray]: BCWH
         """
-        if self.output_array is None:
-            self.output_array = self.output_image.detach().cpu().numpy()
-        return self.output_array[
+        if self._output_array is None:
+            self._output_array = self.output_image.detach().cpu().numpy()
+        return self._output_array[
             :,
             -self.model.num_output_channels :,
             :,
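
The diff above replaces the public output_array attribute with a lazily cached, read-only property, so the tensor-to-NumPy conversion happens at most once per prediction. A minimal standalone sketch of that caching pattern (the class name and tensor shape are illustrative, not NCALab's actual Prediction API):

from typing import Optional

import numpy as np
import torch


class PredictionSketch:
    """Illustrative stand-in for the Prediction class touched above."""

    def __init__(self, output_image: torch.Tensor):
        self.output_image = output_image  # BCWH tensor, possibly on the GPU
        self._output_array: Optional[np.ndarray] = None  # lazily filled cache

    @property
    def output_array(self) -> np.ndarray:
        # Convert once, then reuse the cached NumPy copy on later accesses.
        if self._output_array is None:
            self._output_array = self.output_image.detach().cpu().numpy()
        return self._output_array


p = PredictionSketch(torch.rand(1, 4, 8, 8))
assert p.output_array is p.output_array  # both accesses return the cached array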

ncalab/visualization/animation.py

Lines changed: 26 additions & 10 deletions
@@ -15,32 +15,31 @@ def __init__(
         interval=100,
         repeat=True,
         repeat_delay=3000,
+        overlay=False,
     ):
         nca.eval()
 
         fig, ax = plt.subplots()
         fig.set_size_inches(2, 2)
 
         # first frame is input image
-        if nca.immutable_image_channels:
+        if nca.immutable_image_channels and not overlay:
             first_frame = seed[0, -nca.num_output_channels :]
         else:
             first_frame = seed[0, : nca.num_image_channels]
-        first_frame = first_frame.permute(1, 2, 0).detach().cpu().numpy()
+        first_frame_np = first_frame.permute(1, 2, 0).detach().cpu().numpy()
+        first_frame_np = np.clip(first_frame, 0, 1)
+
         im = ax.imshow(
-            first_frame,
+            first_frame_np,
             animated=True,
         )
 
         predictions = nca.record(seed, steps)
         images = []
         for prediction in predictions:
-            if nca.immutable_image_channels:
-                output_image = prediction.output_channels_np[0]
-            else:
-                output_image = prediction.image_channels_np[0]
+            output_image = prediction.output_array[0]
             output_image = output_image.transpose(1, 2, 0)
-            output_image = np.clip(output_image, 0, 1)
             images.append(output_image)
 
         ax.set_axis_off()
@@ -49,8 +48,25 @@ def __init__(
         plt.tight_layout()
 
         def update(i):
-            nonlocal images
-            im.set_array(images[i])
+            nonlocal images, nca
+            arr = images[i]
+            if not nca.immutable_image_channels:
+                arr = arr[:, :, : nca.num_image_channels]
+            elif overlay:
+                A = np.clip(arr[:, :, : nca.num_image_channels], 0, 1)
+                B = np.clip(arr[:, :, -nca.num_output_channels :].squeeze(-1), 0, 1)
+                alpha = 0.8
+                threshold = 0.2
+                beta = 0.8
+                blue = A[:, :, 2]
+                blue[B > threshold] = beta * (
+                    alpha * B[B > threshold] + (1 - alpha) * blue[B > threshold]
+                )
+                A[:, :, 2] = blue
+                arr = A
+            else:
+                arr = arr[:, :, -nca.num_output_channels :]
+            im.set_array(arr)
             return (im,)
 
         self.animation_fig = animation.FuncAnimation(
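
When overlay is set, update() blends the soft segmentation output into the blue channel of the input image wherever the mask exceeds a threshold. A standalone NumPy sketch of that blend (the function name and random example data are made up; only the thresholded blue-channel arithmetic mirrors the diff):

import numpy as np


def overlay_mask_blue(rgb, mask, alpha=0.8, beta=0.8, threshold=0.2):
    """Blend a soft segmentation mask into the blue channel of an HWC RGB frame."""
    out = np.clip(rgb, 0, 1).copy()
    m = np.clip(mask, 0, 1)
    blue = out[:, :, 2]
    hit = m > threshold  # overlay only where the mask is confident enough
    blue[hit] = beta * (alpha * m[hit] + (1 - alpha) * blue[hit])
    out[:, :, 2] = blue
    return out


frame = np.random.rand(64, 64, 3).astype(np.float32)  # HWC image in [0, 1]
mask = np.random.rand(64, 64).astype(np.float32)      # soft mask in [0, 1]
blended = overlay_mask_blue(frame, mask)               # ready for imshow / GIF frames
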
tasks/segmentation_kvasir_seg/.gitignore (new file)

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+figures/
+weights/

tasks/segmentation_kvasir_seg/dataset_kvasir_seg.py (new file)

Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+
+from pathlib import Path, PosixPath
+from typing import Any
+
+import numpy as np
+from PIL import Image
+
+from torch.utils.data import Dataset
+
+
+class KvasirSegDataset(Dataset):
+    def __init__(self, path: Path | PosixPath, transform) -> None:
+        super().__init__()
+        self.path = path
+        self.image_filenames = sorted((path / "Kvasir-SEG" / "images").glob("*.jpg"))
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.image_filenames)
+
+    def __getitem__(self, index) -> Any:
+        filename = self.image_filenames[index].name
+        image_filename = (self.path / "Kvasir-SEG" / "images" / filename).resolve()
+        mask_filename = (self.path / "Kvasir-SEG" / "masks" / filename).resolve()
+        image = Image.open(image_filename).convert("RGB")
+        mask = Image.open(mask_filename).convert("L")
+        bbox = image.getbbox()
+        image = image.crop(bbox)
+        mask = mask.crop(bbox)
+        image_arr = np.asarray(image, dtype=np.float32) / 255.0
+        mask_arr = np.asarray(mask, dtype=np.float32) / 255.0
+        sample = {"image": image_arr, "mask": mask_arr}
+        sample = self.transform(**sample)
+        return sample["image"], sample["mask"]
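
A minimal usage sketch of the new dataset class (KVASIR_SEG_PATH normally comes from download_kvasir_seg; the local path and the 256x256 resize are assumptions for illustration):

from pathlib import Path

import albumentations as A
from albumentations.pytorch import ToTensorV2

from dataset_kvasir_seg import KvasirSegDataset  # the new module added above

# Assumed location of the extracted Kvasir-SEG archive.
KVASIR_SEG_PATH = Path("data")

T = A.Compose(
    [
        A.Resize(256, 256),
        ToTensorV2(),  # image becomes a CHW float tensor, mask stays HW
    ]
)
dataset = KvasirSegDataset(KVASIR_SEG_PATH, transform=T)
image, mask = dataset[0]  # image: 3x256x256 tensor, mask: 256x256 tensor
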
tasks/segmentation_kvasir_seg/eval_segmentation_kvasir_seg.py (new file)

Lines changed: 86 additions & 0 deletions

@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+import os
+import sys
+from pathlib import Path
+
+root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
+sys.path.append(root_dir)
+
+from ncalab import (
+    Animator,
+    SegmentationNCAModel,
+    CascadeNCA,
+    get_compute_device,
+    print_NCALab_banner,
+    fix_random_seed
+)
+
+from download_kvasir_seg import KVASIR_SEG_PATH  # type: ignore[import-untyped]
+from dataset_kvasir_seg import KvasirSegDataset
+
+import albumentations as A  # type: ignore[import-untyped]
+from albumentations.pytorch import ToTensorV2  # type: ignore[import-untyped]
+import click
+
+import torch
+
+
+TASK_PATH = Path(__file__).parent
+FIGURE_PATH = TASK_PATH / "figures"
+FIGURE_PATH.mkdir(exist_ok=True)
+WEIGHTS_PATH = TASK_PATH / "weights"
+WEIGHTS_PATH.mkdir(exist_ok=True)
+
+
+@click.command()
+@click.option("--hidden-channels", "-H", default=18, type=int)
+@click.option(
+    "--gpu/--no-gpu", is_flag=True, default=True, help="Try using the GPU if available."
+)
+@click.option(
+    "--gpu-index", type=int, default=0, help="Index of GPU to use, if --gpu in use."
+)
+def eval_segmentation_kvasir_seg(hidden_channels: int, gpu: bool, gpu_index: int):
+    print_NCALab_banner()
+    fix_random_seed()
+
+    device = get_compute_device(f"cuda:{gpu_index}" if gpu else "cpu")
+
+    nca = SegmentationNCAModel(
+        device,
+        num_image_channels=3,
+        num_hidden_channels=hidden_channels,
+        num_classes=1,
+        pad_noise=True,
+        fire_rate=0.8,
+    )
+    cascade = CascadeNCA(nca, [8, 4, 2, 1], [70, 20, 10, 5])
+
+    T = A.Compose(
+        [
+            A.RandomCrop(300, 300),
+            A.Resize(256, 256),
+            A.RandomRotate90(),
+            A.HorizontalFlip(),
+            ToTensorV2(),
+        ]
+    )
+    dataset = KvasirSegDataset(KVASIR_SEG_PATH, transform=T)
+
+    cascade.load_state_dict(
+        torch.load(
+            WEIGHTS_PATH / "segmentation_kvasir_seg" / "last_model.pth",
+            weights_only=True,
+        )
+    )
+
+    seed = dataset[0][0].unsqueeze(0).to(device)
+    animator = Animator(cascade, seed, overlay=True)
+
+    out_path = FIGURE_PATH / "segmentation_kvasir_seg.gif"
+    animator.save(out_path)
+    click.secho(f"Done. You'll find the generated GIF in {out_path} .")
+
+
+if __name__ == "__main__":
+    eval_segmentation_kvasir_seg()
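
Both the evaluation script above and the updated training script below route the new --gpu/--gpu-index options into get_compute_device(f"cuda:{gpu_index}" if gpu else "cpu"). A plain-torch sketch of equivalent device selection (this helper is an assumption for illustration, not NCALab's get_compute_device implementation):

import torch


def pick_device(gpu: bool, gpu_index: int) -> torch.device:
    """Use the requested CUDA device when available, otherwise fall back to CPU."""
    if gpu and torch.cuda.is_available():
        return torch.device(f"cuda:{gpu_index}")
    return torch.device("cpu")


device = pick_device(gpu=True, gpu_index=0)  # mirrors the default CLI options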

tasks/segmentation_kvasir_seg/train_segmentation_kvasir_seg.py

Lines changed: 25 additions & 36 deletions
@@ -5,64 +5,44 @@
 root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
 sys.path.append(root_dir)
 
-from pathlib import Path, PosixPath
-from typing import Any
+from pathlib import Path
+
 
 from ncalab import (
     SegmentationNCAModel,
     CascadeNCA,
     BasicNCATrainer,
     get_compute_device,
     print_mascot,
+    print_NCALab_banner,
+    fix_random_seed,
 )
 
 from download_kvasir_seg import download_and_extract, KVASIR_SEG_PATH  # type: ignore[import-untyped]
+from dataset_kvasir_seg import KvasirSegDataset
 
 import albumentations as A  # type: ignore[import-untyped]
 from albumentations.pytorch import ToTensorV2  # type: ignore[import-untyped]
 import click
-import numpy as np
-from PIL import Image
+
 from sklearn.model_selection import train_test_split  # type: ignore[import-untyped]
 import torch
 from torch.utils.tensorboard import SummaryWriter
-from torch.utils.data import Dataset, Subset
+from torch.utils.data import Subset
+
 
 TASK_PATH = Path(__file__).parent.resolve()
 WEIGHTS_PATH = TASK_PATH / "weights"
 WEIGHTS_PATH.mkdir(exist_ok=True)
 
 
-class KvasirSegDataset(Dataset):
-    def __init__(self, path: Path | PosixPath, transform) -> None:
-        super().__init__()
-        self.path = path
-        self.image_filenames = sorted((path / "Kvasir-SEG" / "images").glob("*.jpg"))
-        self.transform = transform
-
-    def __len__(self):
-        return len(self.image_filenames)
-
-    def __getitem__(self, index) -> Any:
-        filename = self.image_filenames[index].name
-        image_filename = (self.path / "Kvasir-SEG" / "images" / filename).resolve()
-        mask_filename = (self.path / "Kvasir-SEG" / "masks" / filename).resolve()
-        image = Image.open(image_filename).convert("RGB")
-        mask = Image.open(mask_filename).convert("L")
-        bbox = image.getbbox()
-        image = image.crop(bbox)
-        mask = mask.crop(bbox)
-        image_arr = np.asarray(image, dtype=np.float32) / 255.0
-        mask_arr = np.asarray(mask, dtype=np.float32) / 255.0
-        sample = {"image": image_arr, "mask": mask_arr}
-        sample = self.transform(**sample)
-        return sample["image"], sample["mask"]
-
-
-def train_segmentation_kvasir_seg(batch_size: int, hidden_channels: int):
+def train_segmentation_kvasir_seg(
+    batch_size: int, hidden_channels: int, gpu: bool, gpu_index: int
+):
     writer = SummaryWriter(comment="Segmentation Kvasir-SEG")
-
-    device = get_compute_device("cuda:0")
+    print_NCALab_banner()
+    fix_random_seed()
+    device = get_compute_device(f"cuda:{gpu_index}" if gpu else "cpu")
 
     nca = SegmentationNCAModel(
         device,
@@ -114,7 +94,13 @@ def train_segmentation_kvasir_seg(batch_size: int, hidden_channels: int):
 @click.command()
 @click.option("--batch-size", "-b", default=8, type=int)
 @click.option("--hidden-channels", "-H", default=18, type=int)
-def main(batch_size, hidden_channels):
+@click.option(
+    "--gpu/--no-gpu", is_flag=True, default=True, help="Try using the GPU if available."
+)
+@click.option(
+    "--gpu-index", type=int, default=0, help="Index of GPU to use, if --gpu in use."
+)
+def main(batch_size, hidden_channels, gpu, gpu_index):
     print_mascot(
         "You're training NCAs on a medical dataset now.\n"
         "\n"
@@ -129,7 +115,10 @@ def main(batch_size, hidden_channels):
     download_and_extract()
 
     train_segmentation_kvasir_seg(
-        batch_size=batch_size, hidden_channels=hidden_channels
+        batch_size=batch_size,
+        hidden_channels=hidden_channels,
+        gpu=gpu,
+        gpu_index=gpu_index,
     )
 
 