1212 BasicNCATrainer ,
1313 WEIGHTS_PATH ,
1414 get_compute_device ,
15+ fix_random_seed ,
1516)
1617
1718import click
1819
19- import torch
20+ import cv2
21+
2022import numpy as np
23+ import pandas as pd
2124
2225import albumentations as A # type: ignore[import-untyped]
2326from albumentations .pytorch import ToTensorV2 # type: ignore[import-untyped]
2427
28+ import torch
2529from torch .utils .tensorboard import SummaryWriter
2630from torch .utils .data import Dataset
2731from PIL import Image
2832
29- from config import KID_DATASET_PATH
33+ from config import KID_DATASET_PATH , KVASIR_CAPSULE_DATASET_PATH
3034
3135
3236TASK_PATH = Path (__file__ ).parent
@@ -45,8 +49,62 @@ def __len__(self):
4549
4650 def __getitem__ (self , index ) -> Any :
4751 filename = self .image_filenames [index ]
48- image_filename = KID_DATASET_PATH / "all" / filename
49- mask_filename = KID_DATASET_PATH / "depth" / filename
52+ image_filename = self .path / "all" / filename
53+ mask_filename = self .path / "depth" / filename
54+ image = Image .open (image_filename ).convert ("RGB" )
55+ mask = Image .open (mask_filename ).convert ("L" )
56+ image_arr = np .asarray (image , dtype = np .float32 ) / 255.0
57+ mask_arr = np .asarray (mask , dtype = np .float32 ) / 255.0
58+ image_arr [self .vignette == 0 ] = 0
59+ mask_arr [self .vignette == 0 ] = 0
60+ sample = {"image" : image_arr , "mask" : mask_arr }
61+ if self .transform is not None :
62+ sample = self .transform (** sample )
63+ return sample ["image" ], sample ["mask" ]
64+
65+
66+ class KvasirCapsuleDataset (Dataset ):
67+ def __init__ (self , path : Path | PosixPath , filenames , transform = None ) -> None :
68+ super ().__init__ ()
69+ self .path = path
70+ self .image_filenames = filenames
71+ self .transform = transform
72+ self .vignette = cv2 .imread (str (path / "vignette_kvasir_capsule.png" ))[..., 0 ]
73+
74+ def __len__ (self ):
75+ return len (self .image_filenames )
76+
77+ def __getitem__ (self , index ):
78+ filename = self .image_filenames [index ]
79+ image_filename = self .path / "images" / "Any" / filename
80+ mask_filename = self .path / "depth" / filename
81+ image = Image .open (image_filename ).convert ("RGB" )
82+ mask = Image .open (mask_filename ).convert ("L" )
83+ image_arr = np .asarray (image , dtype = np .float32 ) / 255.0
84+ mask_arr = np .asarray (mask , dtype = np .float32 ) / 255.0
85+ image_arr [self .vignette == 0 ] = 0
86+ mask_arr [self .vignette == 0 ] = 0
87+ sample = {"image" : image_arr , "mask" : mask_arr }
88+ if self .transform is not None :
89+ sample = self .transform (** sample )
90+ return sample ["image" ], sample ["mask" ]
91+
92+
93+ class EndoSLAMDataset (Dataset ):
94+ def __init__ (self , path : Path | PosixPath , filenames , transform = None ) -> None :
95+ super ().__init__ ()
96+ self .path = path
97+ self .image_filenames = filenames
98+ self .transform = transform
99+ self .vignette = np .asarray (Image .open (path / "vignette_unity.png" ))[..., 0 ]
100+
101+ def __len__ (self ):
102+ return len (self .image_filenames )
103+
104+ def __getitem__ (self , index ) -> Any :
105+ filename = self .image_filenames [index ]
106+ image_filename = self .path / "Frames" / filename
107+ mask_filename = self .path / "Pixelwise Depths" / ("aov_" + filename )
50108 image = Image .open (image_filename ).convert ("RGB" )
51109 mask = Image .open (mask_filename ).convert ("L" )
52110 image_arr = np .asarray (image , dtype = np .float32 ) / 255.0
@@ -61,7 +119,7 @@ def __getitem__(self, index) -> Any:
61119
62120def train_depth_KID (batch_size : int , hidden_channels : int ):
63121 writer = SummaryWriter ()
64-
122+ fix_random_seed ()
65123 device = get_compute_device ("cuda:0" )
66124
67125 nca = DepthNCAModel (
@@ -72,66 +130,122 @@ def train_depth_KID(batch_size: int, hidden_channels: int):
72130 lambda_activity = 0.00 ,
73131 )
74132
133+ INPUT_SIZE = 64
134+
75135 T = A .Compose (
76136 [
77- A .CenterCrop (320 , 320 ),
78- A .Resize (80 , 80 ),
137+ A .CenterCrop (300 , 300 ),
138+ A .Resize (INPUT_SIZE , INPUT_SIZE ),
79139 A .RandomRotate90 (),
80140 ToTensorV2 (),
81141 ]
82142 )
83- import pandas as pd
84-
85- split = pd .read_csv (TASK_PATH / "split_normal_small_bowel.csv" )
86- train_filenames = split [split .split != "val" ].filename .values
87- train_filenames = [
88- filename
89- for filename in train_filenames
90- if (KID_DATASET_PATH / "depth" / filename ).exists ()
91- ]
92- train_dataset = KIDDataset (
93- KID_DATASET_PATH ,
94- filenames = train_filenames ,
95- transform = T ,
96- )
97- val_filenames = split [split .split == "val" ].filename .values
98- val_filenames = [
99- filename
100- for filename in val_filenames
101- if (KID_DATASET_PATH / "depth" / filename ).exists ()
102- ]
103- val_dataset = KIDDataset (
104- KID_DATASET_PATH ,
105- filenames = val_filenames ,
106- transform = T ,
143+ T_val = A .Compose (
144+ [
145+ A .CenterCrop (300 , 300 ),
146+ A .Resize (INPUT_SIZE , INPUT_SIZE ),
147+ A .RandomRotate90 (),
148+ ToTensorV2 (),
149+ ]
107150 )
108151
152+ dataset_id = "kvasircapsule"
153+
154+ if dataset_id == "kid" :
155+ split = pd .read_csv (KID_DATASET_PATH / "split_depth.csv" )
156+ train_filenames = split [split .split != "val" ].filename .values
157+ train_filenames = [
158+ filename
159+ for filename in train_filenames
160+ if (KID_DATASET_PATH / "depth" / filename ).exists ()
161+ ]
162+ train_dataset = KIDDataset (
163+ KID_DATASET_PATH ,
164+ filenames = train_filenames ,
165+ transform = T ,
166+ )
167+ val_filenames = split [split .split == "val" ].filename .values
168+ val_filenames = [
169+ filename
170+ for filename in val_filenames
171+ if (KID_DATASET_PATH / "depth" / filename ).exists ()
172+ ]
173+ val_dataset = KIDDataset (
174+ KID_DATASET_PATH ,
175+ filenames = val_filenames ,
176+ transform = T_val ,
177+ )
178+ elif dataset_id == "kvasircapsule" :
179+ split = pd .read_csv (KVASIR_CAPSULE_DATASET_PATH / "split_depth.csv" )
180+ train_filenames = split [split .split != "val" ].filename .values
181+ train_filenames = [
182+ filename
183+ for filename in train_filenames
184+ if (KVASIR_CAPSULE_DATASET_PATH / "depth" / filename ).exists ()
185+ ]
186+ train_dataset = KvasirCapsuleDataset (
187+ KVASIR_CAPSULE_DATASET_PATH ,
188+ filenames = train_filenames ,
189+ transform = T ,
190+ )
191+ val_filenames = split [split .split == "val" ].filename .values
192+ val_filenames = [
193+ filename
194+ for filename in val_filenames
195+ if (KVASIR_CAPSULE_DATASET_PATH / "depth" / filename ).exists ()
196+ ]
197+ val_dataset = KvasirCapsuleDataset (
198+ KVASIR_CAPSULE_DATASET_PATH ,
199+ filenames = val_filenames ,
200+ transform = T_val ,
201+ )
202+ elif dataset_id == "endoslam" :
203+ endoslam_path = Path ("~/EndoSLAM/data" ).expanduser ()
204+ filenames = [
205+ f .name
206+ for i , f in enumerate (sorted ((endoslam_path / "Frames" ).glob ("*.png" )))
207+ if i % 100 == 0
208+ ]
209+ train_filenames = filenames [: int (len (filenames ) * 0.8 )]
210+ val_filenames = filenames [len (train_filenames ) :]
211+ train_dataset = EndoSLAMDataset (
212+ endoslam_path ,
213+ train_filenames ,
214+ transform = T ,
215+ )
216+ val_dataset = EndoSLAMDataset (
217+ endoslam_path ,
218+ val_filenames ,
219+ transform = T ,
220+ )
221+
109222 loader_train = torch .utils .data .DataLoader (
110223 train_dataset , shuffle = True , batch_size = batch_size , drop_last = True
111224 )
112225 loader_val = torch .utils .data .DataLoader (
113226 val_dataset , shuffle = True , batch_size = batch_size , drop_last = True
114227 )
228+ nca .vignette = train_dataset .vignette
115229
116230 trainer = BasicNCATrainer (
117231 nca ,
118232 WEIGHTS_PATH / "depth_KID2_normal_small_bowel.pth" ,
119- max_epochs = 500 ,
120- pad_noise = False ,
121- steps_range = (64 , 96 ),
122- steps_validation = 80 ,
233+ max_epochs = 3500 ,
234+ steps_range = (96 , 110 ),
235+ steps_validation = 100 ,
123236 )
124237 trainer .train_basic_nca (
125238 loader_train ,
126239 loader_val ,
127240 summary_writer = writer ,
241+ save_every = 1 ,
128242 )
129243 writer .close ()
130244
131245
@click.command()
@click.option("--batch-size", "-b", default=8, type=int)
@click.option("--hidden-channels", "-H", default=20, type=int)
def main(batch_size, hidden_channels):
    """CLI entry point: train the depth NCA with the given hyper-parameters."""
    train_depth_KID(batch_size=batch_size, hidden_channels=hidden_channels)
137251
0 commit comments