
Commit 057ae81

allow choosing different training datasets in depth trainer
1 parent 128bc42 commit 057ae81


tasks/depth/train_depth.py

Lines changed: 150 additions & 36 deletions
@@ -12,21 +12,25 @@
     BasicNCATrainer,
     WEIGHTS_PATH,
     get_compute_device,
+    fix_random_seed,
 )

 import click

-import torch
+import cv2
+
 import numpy as np
+import pandas as pd

 import albumentations as A  # type: ignore[import-untyped]
 from albumentations.pytorch import ToTensorV2  # type: ignore[import-untyped]

+import torch
 from torch.utils.tensorboard import SummaryWriter
 from torch.utils.data import Dataset
 from PIL import Image

-from config import KID_DATASET_PATH
+from config import KID_DATASET_PATH, KVASIR_CAPSULE_DATASET_PATH


 TASK_PATH = Path(__file__).parent
@@ -45,8 +49,62 @@ def __len__(self):

     def __getitem__(self, index) -> Any:
         filename = self.image_filenames[index]
-        image_filename = KID_DATASET_PATH / "all" / filename
-        mask_filename = KID_DATASET_PATH / "depth" / filename
+        image_filename = self.path / "all" / filename
+        mask_filename = self.path / "depth" / filename
+        image = Image.open(image_filename).convert("RGB")
+        mask = Image.open(mask_filename).convert("L")
+        image_arr = np.asarray(image, dtype=np.float32) / 255.0
+        mask_arr = np.asarray(mask, dtype=np.float32) / 255.0
+        image_arr[self.vignette == 0] = 0
+        mask_arr[self.vignette == 0] = 0
+        sample = {"image": image_arr, "mask": mask_arr}
+        if self.transform is not None:
+            sample = self.transform(**sample)
+        return sample["image"], sample["mask"]
+
+
+class KvasirCapsuleDataset(Dataset):
+    def __init__(self, path: Path | PosixPath, filenames, transform=None) -> None:
+        super().__init__()
+        self.path = path
+        self.image_filenames = filenames
+        self.transform = transform
+        self.vignette = cv2.imread(str(path / "vignette_kvasir_capsule.png"))[..., 0]
+
+    def __len__(self):
+        return len(self.image_filenames)
+
+    def __getitem__(self, index):
+        filename = self.image_filenames[index]
+        image_filename = self.path / "images" / "Any" / filename
+        mask_filename = self.path / "depth" / filename
+        image = Image.open(image_filename).convert("RGB")
+        mask = Image.open(mask_filename).convert("L")
+        image_arr = np.asarray(image, dtype=np.float32) / 255.0
+        mask_arr = np.asarray(mask, dtype=np.float32) / 255.0
+        image_arr[self.vignette == 0] = 0
+        mask_arr[self.vignette == 0] = 0
+        sample = {"image": image_arr, "mask": mask_arr}
+        if self.transform is not None:
+            sample = self.transform(**sample)
+        return sample["image"], sample["mask"]
+
+
+class EndoSLAMDataset(Dataset):
+    def __init__(self, path: Path | PosixPath, filenames, transform=None) -> None:
+        super().__init__()
+        self.path = path
+        self.image_filenames = filenames
+        self.transform = transform
+        self.vignette = np.asarray(Image.open(path / "vignette_unity.png"))[..., 0]
+
+    def __len__(self):
+        return len(self.image_filenames)
+
+    def __getitem__(self, index) -> Any:
+        filename = self.image_filenames[index]
+        image_filename = self.path / "Frames" / filename
+        mask_filename = self.path / "Pixelwise Depths" / ("aov_" + filename)
         image = Image.open(image_filename).convert("RGB")
         mask = Image.open(mask_filename).convert("L")
         image_arr = np.asarray(image, dtype=np.float32) / 255.0
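As an aside, a minimal sketch of how the dataset classes added above are consumed: indexing returns an (image, mask) pair with the vignette region zeroed out and the albumentations pipeline applied to both. The filename below is hypothetical, and the transform simply mirrors the one defined later in this file; KVASIR_CAPSULE_DATASET_PATH is the project's configured dataset root.

import albumentations as A
from albumentations.pytorch import ToTensorV2

# Mirror of the training transform defined further down in this file.
preview_transform = A.Compose(
    [
        A.CenterCrop(300, 300),
        A.Resize(64, 64),
        ToTensorV2(),
    ]
)

# "frame_000001.png" is a hypothetical filename; any entry from the split CSV would do.
dataset = KvasirCapsuleDataset(
    KVASIR_CAPSULE_DATASET_PATH,
    filenames=["frame_000001.png"],
    transform=preview_transform,
)
image, mask = dataset[0]  # image: 3xHxW float tensor in [0, 1]; mask: HxW depth map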
@@ -61,7 +119,7 @@ def __getitem__(self, index) -> Any:

 def train_depth_KID(batch_size: int, hidden_channels: int):
     writer = SummaryWriter()
-
+    fix_random_seed()
     device = get_compute_device("cuda:0")

     nca = DepthNCAModel(
@@ -72,66 +130,122 @@ def train_depth_KID(batch_size: int, hidden_channels: int):
         lambda_activity=0.00,
     )

+    INPUT_SIZE = 64
+
     T = A.Compose(
         [
-            A.CenterCrop(320, 320),
-            A.Resize(80, 80),
+            A.CenterCrop(300, 300),
+            A.Resize(INPUT_SIZE, INPUT_SIZE),
             A.RandomRotate90(),
             ToTensorV2(),
         ]
     )
-    import pandas as pd
-
-    split = pd.read_csv(TASK_PATH / "split_normal_small_bowel.csv")
-    train_filenames = split[split.split != "val"].filename.values
-    train_filenames = [
-        filename
-        for filename in train_filenames
-        if (KID_DATASET_PATH / "depth" / filename).exists()
-    ]
-    train_dataset = KIDDataset(
-        KID_DATASET_PATH,
-        filenames=train_filenames,
-        transform=T,
-    )
-    val_filenames = split[split.split == "val"].filename.values
-    val_filenames = [
-        filename
-        for filename in val_filenames
-        if (KID_DATASET_PATH / "depth" / filename).exists()
-    ]
-    val_dataset = KIDDataset(
-        KID_DATASET_PATH,
-        filenames=val_filenames,
-        transform=T,
+    T_val = A.Compose(
+        [
+            A.CenterCrop(300, 300),
+            A.Resize(INPUT_SIZE, INPUT_SIZE),
+            A.RandomRotate90(),
+            ToTensorV2(),
+        ]
     )

+    dataset_id = "kvasircapsule"
+
+    if dataset_id == "kid":
+        split = pd.read_csv(KID_DATASET_PATH / "split_depth.csv")
+        train_filenames = split[split.split != "val"].filename.values
+        train_filenames = [
+            filename
+            for filename in train_filenames
+            if (KID_DATASET_PATH / "depth" / filename).exists()
+        ]
+        train_dataset = KIDDataset(
+            KID_DATASET_PATH,
+            filenames=train_filenames,
+            transform=T,
+        )
+        val_filenames = split[split.split == "val"].filename.values
+        val_filenames = [
+            filename
+            for filename in val_filenames
+            if (KID_DATASET_PATH / "depth" / filename).exists()
+        ]
+        val_dataset = KIDDataset(
+            KID_DATASET_PATH,
+            filenames=val_filenames,
+            transform=T_val,
+        )
+    elif dataset_id == "kvasircapsule":
+        split = pd.read_csv(KVASIR_CAPSULE_DATASET_PATH / "split_depth.csv")
+        train_filenames = split[split.split != "val"].filename.values
+        train_filenames = [
+            filename
+            for filename in train_filenames
+            if (KVASIR_CAPSULE_DATASET_PATH / "depth" / filename).exists()
+        ]
+        train_dataset = KvasirCapsuleDataset(
+            KVASIR_CAPSULE_DATASET_PATH,
+            filenames=train_filenames,
+            transform=T,
+        )
+        val_filenames = split[split.split == "val"].filename.values
+        val_filenames = [
+            filename
+            for filename in val_filenames
+            if (KVASIR_CAPSULE_DATASET_PATH / "depth" / filename).exists()
+        ]
+        val_dataset = KvasirCapsuleDataset(
+            KVASIR_CAPSULE_DATASET_PATH,
+            filenames=val_filenames,
+            transform=T_val,
+        )
+    elif dataset_id == "endoslam":
+        endoslam_path = Path("~/EndoSLAM/data").expanduser()
+        filenames = [
+            f.name
+            for i, f in enumerate(sorted((endoslam_path / "Frames").glob("*.png")))
+            if i % 100 == 0
+        ]
+        train_filenames = filenames[: int(len(filenames) * 0.8)]
+        val_filenames = filenames[len(train_filenames) :]
+        train_dataset = EndoSLAMDataset(
+            endoslam_path,
+            train_filenames,
+            transform=T,
+        )
+        val_dataset = EndoSLAMDataset(
+            endoslam_path,
+            val_filenames,
+            transform=T,
+        )
+
     loader_train = torch.utils.data.DataLoader(
         train_dataset, shuffle=True, batch_size=batch_size, drop_last=True
     )
     loader_val = torch.utils.data.DataLoader(
         val_dataset, shuffle=True, batch_size=batch_size, drop_last=True
     )
+    nca.vignette = train_dataset.vignette

     trainer = BasicNCATrainer(
         nca,
         WEIGHTS_PATH / "depth_KID2_normal_small_bowel.pth",
-        max_epochs=500,
-        pad_noise=False,
-        steps_range=(64, 96),
-        steps_validation=80,
+        max_epochs=3500,
+        steps_range=(96, 110),
+        steps_validation=100,
     )
     trainer.train_basic_nca(
         loader_train,
         loader_val,
         summary_writer=writer,
+        save_every=1,
     )
     writer.close()


 @click.command()
 @click.option("--batch-size", "-b", default=8, type=int)
-@click.option("--hidden-channels", "-H", default=18, type=int)
+@click.option("--hidden-channels", "-H", default=20, type=int)
 def main(batch_size, hidden_channels):
     train_depth_KID(batch_size=batch_size, hidden_channels=hidden_channels)

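Note that within train_depth_KID the dataset is currently picked via the hard-coded dataset_id string. A minimal, hypothetical sketch of how that choice could instead be driven from the existing click interface (the --dataset option and the dataset_id parameter of train_depth_KID are assumptions, not part of this commit):

# Hypothetical sketch: pass the dataset choice through the CLI instead of
# hard-coding dataset_id inside train_depth_KID.
@click.command()
@click.option("--batch-size", "-b", default=8, type=int)
@click.option("--hidden-channels", "-H", default=20, type=int)
@click.option(
    "--dataset",
    "-d",
    default="kvasircapsule",
    type=click.Choice(["kid", "kvasircapsule", "endoslam"]),
)
def main(batch_size, hidden_channels, dataset):
    # train_depth_KID would need to accept dataset_id as a parameter for this to work.
    train_depth_KID(
        batch_size=batch_size,
        hidden_channels=hidden_channels,
        dataset_id=dataset,
    )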