-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
65 lines (52 loc) · 1.87 KB
/
dataset.py
File metadata and controls
65 lines (52 loc) · 1.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from load_onedata import load_data
class BasicDataset(Dataset):
    """Paired image/mask volume dataset driven by a text file of image paths.

    Each non-blank line of ``filePath`` is the path to an image volume. The
    matching ground-truth mask path is derived by replacing ``.nii.gz`` with
    ``_gt.nii.gz`` (e.g. ``205_SA_ED.nii.gz`` -> ``205_SA_ED_gt.nii.gz``).
    Both volumes are loaded through ``load_data`` with the configured crop and
    pad shapes.

    Args:
        filePath: Path to a UTF-8 text file listing one image path per line.
        crop_shape: Shape forwarded to ``load_data`` as ``crop_shape``.
        pad_shape: Shape forwarded to ``load_data`` as ``pad_shape``.
    """

    def __init__(self, filePath, crop_shape=(256, 256, 8), pad_shape=(256, 256, 16)):
        # Copy into fresh lists: avoids the shared-mutable-default pitfall and
        # keeps the stored shapes independent of objects the caller may mutate.
        self.crop_shape = list(crop_shape)
        self.pad_shape = list(pad_shape)
        # The context manager guarantees the file is closed. strip() removes the
        # newline and stray whitespace; blank lines are skipped so they never
        # reach load_data as empty paths.
        with open(filePath, encoding="utf-8") as f:
            stripped = (line.strip() for line in f)
            self.ids = [path for path in stripped if path]

    def __len__(self):
        """Return the number of listed image volumes."""
        return len(self.ids)

    @staticmethod
    def _mask_path(img_file):
        """Return the ground-truth mask path paired with ``img_file``."""
        return img_file.replace(".nii.gz", "_gt.nii.gz")

    def __getitem__(self, i):
        """Load sample ``i`` as a dict of tensors.

        Returns:
            A dict with two entries:

            - ``'image'``: float32 tensor. It is the loaded image with its last
              axis moved to the front and a leading channel axis of size 1 added.
            - ``'mask'``: float32 tensor with its last axis moved to the front.
              The original author labels this layout D, H, W. No channel axis
              is added.
        """
        img_file = self.ids[i]
        mask_file = self._mask_path(img_file)
        img, mask = load_data(
            img_file,
            mask_file,
            crop_shape=self.crop_shape,
            pad_shape=self.pad_shape,
        )
        # Assumes load_data returns (H, W, D) arrays — TODO confirm. Move the
        # last (slice) axis to the front, then add a channel axis to the image.
        img = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
        # The int64 cast truncates any non-integer label values before transposing.
        mask = np.transpose(mask.astype(np.int64), (2, 0, 1))
        # NOTE(review): the mask is returned as float32 to match existing callers.
        # Switch to .long() if a loss expecting class indices (e.g.
        # CrossEntropyLoss) is used.
        return {
            'image': torch.from_numpy(img).float(),
            'mask': torch.from_numpy(mask).float(),
        }
# File list used when the script is run without a command-line argument.
DEFAULT_FILE_LIST = '/Users/f111f/Desktop/train_ex/trainfile.txt'


def main(filepath=DEFAULT_FILE_LIST):
    """Load the first sample listed in ``filepath`` and print its shapes and dtypes.

    Args:
        filepath: Text file listing one image path per line. Defaults to the
            author's local training list.

    Raises:
        ValueError: If ``filepath`` lists no image paths.
    """
    dataset = BasicDataset(filepath, crop_shape=[256, 256, 8], pad_shape=[256, 256, 16])
    if len(dataset) == 0:
        # Fail with a clear message instead of a bare IndexError from dataset[0].
        raise ValueError(f"No image paths listed in {filepath!r}")
    sample = dataset[0]
    image = sample['image']
    mask = sample['mask']
    print(f"img: {image.shape}")
    print(f"mask: {mask.shape}")
    print(f"img_type: {image.dtype}")
    print(f"mask_type: {mask.dtype}")


if __name__ == "__main__":
    import sys

    # An optional first command-line argument overrides the default file list.
    main(sys.argv[1] if len(sys.argv) > 1 else DEFAULT_FILE_LIST)
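# Example output. It was presumably recorded before __getitem__ added the
# np.expand_dims channel axis; with that axis, the image shape is expected to be
# torch.Size([1, 16, 256, 256]). The first two lines appear to be printed by
# load_data itself, not by main().
# (256, 256, 16)
# torch.Size([256, 256, 16])
# img: torch.Size([16, 256, 256])
# mask: torch.Size([16, 256, 256])
# img_type: torch.float32
# mask_type: torch.float32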