forked from millicentaumaomondi/medicalsegmentation
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathbrain_dataset.py
More file actions
56 lines (43 loc) · 1.84 KB
/
brain_dataset.py
File metadata and controls
56 lines (43 loc) · 1.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from PIL import Image
from torch.utils.data import Dataset
import os
import torch
class ProdBrainDataset(Dataset):
    """Dataset of grayscale images read from a single directory.

    Only files whose names end in ``.png``, ``.jpg`` or ``.jpeg``
    (case-sensitive) are included. The file list is captured once at
    construction time, in the order returned by ``os.listdir``.

    Args:
        root_dir: Stored on the instance as ``root_dir``; not otherwise
            used by this class.
        img_dir: Directory that is scanned for image files.
        transform: Optional callable applied to the grayscale PIL image
            before it is returned from ``__getitem__``.
    """

    # File-name suffixes that mark a directory entry as an image.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, img_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.img_dir = img_dir
        self.img_files = [
            f for f in os.listdir(self.img_dir)
            if f.endswith(self._IMAGE_EXTENSIONS)
        ]

    def __len__(self) -> int:
        """Return the number of image files found in ``img_dir``."""
        return len(self.img_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as a grayscale image.

        Returns:
            The mode-``'L'`` PIL image, or the result of ``transform`` applied
            to it when a transform was supplied.
        """
        img_path = os.path.join(self.img_dir, self.img_files[idx])
        # Image.open keeps the file handle open lazily; the context manager
        # closes it as soon as the pixel data has been converted (convert()
        # returns a new, fully loaded image).
        with Image.open(img_path) as img:
            img_gray = img.convert('RGB').convert('L')
        if self.transform:
            img_gray = self.transform(img_gray)
        return img_gray

    def get_original_size(self, idx):
        """Return the ``size`` attribute of image ``idx`` as stored on disk."""
        img_path = os.path.join(self.img_dir, self.img_files[idx])
        # Close the file handle promptly instead of leaking it until GC.
        with Image.open(img_path) as img:
            return img.size
class BrainDataset(ProdBrainDataset):
    """Image dataset that additionally yields a segmentation mask per image.

    Images and masks are paired purely by position: item ``idx`` combines the
    ``idx``-th image file name with the ``idx``-th mask file name, both taken
    in sorted order.
    NOTE(review): this assumes image and mask file names sort into the same
    relative order (e.g. identical names in both directories) -- confirm
    against the data layout.

    Args:
        root_dir: Forwarded to ``ProdBrainDataset``.
        img_dir: Directory containing the input images.
        mask_dir: Optional directory containing mask images. When omitted,
            every item is returned with an all-zero mask.
        transform: Optional callable applied separately to the image and to
            the mask.
            NOTE(review): a transform with random behaviour would presumably
            produce misaligned image/mask pairs -- confirm with callers.
    """

    def __init__(self, root_dir, img_dir, mask_dir=None, transform=None):
        super().__init__(root_dir, img_dir, transform)
        self.mask_dir = mask_dir
        # os.listdir returns names in arbitrary order, so sort the image list
        # (and the mask list below) to make index-based pairing deterministic.
        self.img_files = sorted(self.img_files)
        # Always define the attribute so it exists even without a mask_dir.
        self.mask_files = []
        if mask_dir:
            self.mask_files = sorted(
                f for f in os.listdir(self.mask_dir)
                if f.endswith(('.png', '.jpg', '.jpeg'))
            )

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for item ``idx``.

        When no mask directory was given, or it holds no mask at position
        ``idx``, the mask is an all-zero stand-in shaped like the image.

        Raises:
            FileNotFoundError: If the listed mask file no longer exists.
        """
        img_gray = super().__getitem__(idx)

        if self.mask_dir and len(self.mask_files) > idx:
            mask_path = os.path.join(self.mask_dir, self.mask_files[idx])
            if not os.path.exists(mask_path):
                raise FileNotFoundError(f"Mask file '{mask_path}' does not exist.")
            # Close the file handle once the pixel data has been converted.
            with Image.open(mask_path) as mask_file:
                mask = mask_file.convert('L')
            if self.transform:
                mask = self.transform(mask)
            return img_gray, mask

        # No mask available for this index: return an empty mask.
        if isinstance(img_gray, Image.Image):
            # Without a tensor-producing transform the image is still a PIL
            # image, which torch.zeros_like cannot handle; return a black
            # 'L' image of the same size instead.
            return img_gray, Image.new('L', img_gray.size)
        return img_gray, torch.zeros_like(img_gray)