# -*- coding: utf-8 -*-

import itertools

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data.sampler import Sampler

class H5DataSets(Dataset):
    """
    Dataset for loading images stored in HDF5 format. It generates
    4D tensors with dimension order [C, D, H, W] for 3D images, and
    3D tensors with dimension order [C, H, W] for 2D images.
    """

    def __init__(self, root_dir, sample_list_name, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        with open(sample_list_name, 'r') as f:
            lines = f.readlines()
        self.sample_list = [item.replace('\n', '') for item in lines]

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, idx):
        sample_name = self.sample_list[idx]
        # read the arrays into memory and close the file handle afterwards
        with h5py.File(self.root_dir + '/' + sample_name, 'r') as h5f:
            image = h5f['image'][:]
            label = h5f['label'][:]
        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        # sample["idx"] = idx
        return sample

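# The loader above assumes each HDF5 file holds one 'image' and one 'label'
# dataset. Below is a minimal sketch of a compatible transform, assuming 2D
# slices stored as [H, W] numpy arrays; "ToTensorExample" is a hypothetical
# name, not part of this file.
class ToTensorExample(object):
    """Convert the numpy sample dict to torch tensors with a channel axis."""

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        # [H, W] -> [1, H, W], matching the [C, H, W] order in the docstring
        image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)
        label = torch.from_numpy(label.astype(np.int64))
        return {'image': image, 'label': label}
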
class TwoStreamBatchSampler(Sampler):
    """Iterate two sets of indices

    An 'epoch' is one iteration through the primary indices.
    During the epoch, the secondary indices are iterated through
    as many times as needed.
    """

    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices
        self.secondary_batch_size = secondary_batch_size
        self.primary_batch_size = batch_size - secondary_batch_size

        assert len(self.primary_indices) >= self.primary_batch_size > 0
        assert len(self.secondary_indices) >= self.secondary_batch_size > 0

    def __iter__(self):
        # primary indices are shuffled once per epoch; secondary indices
        # are reshuffled and repeated for as long as the epoch lasts
        primary_iter = iterate_once(self.primary_indices)
        secondary_iter = iterate_eternally(self.secondary_indices)
        return (
            primary_batch + secondary_batch
            for (primary_batch, secondary_batch)
            in zip(grouper(primary_iter, self.primary_batch_size),
                   grouper(secondary_iter, self.secondary_batch_size))
        )

    def __len__(self):
        return len(self.primary_indices) // self.primary_batch_size


def iterate_once(iterable):
    return np.random.permutation(iterable)


def iterate_eternally(indices):
    def infinite_shuffles():
        while True:
            yield np.random.permutation(indices)
    return itertools.chain.from_iterable(infinite_shuffles())


def grouper(iterable, n):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3) --> ABC DEF
    args = [iter(iterable)] * n
    return zip(*args)


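# A hedged usage sketch (hypothetical helper, not part of the original file):
# TwoStreamBatchSampler yields complete index lists, so it is passed to the
# DataLoader as batch_sampler, without batch_size or shuffle. The assumption
# that the first `labeled_num` samples in the list are labeled is for
# illustration only.
def example_two_stream_loader(dataset, labeled_num, batch_size, secondary_batch_size):
    labeled_idxs = list(range(0, labeled_num))
    unlabeled_idxs = list(range(labeled_num, len(dataset)))
    batch_sampler = TwoStreamBatchSampler(
        labeled_idxs, unlabeled_idxs, batch_size, secondary_batch_size)
    return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler)

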
if __name__ == "__main__":
    root_dir = "/home/guotai/disk2t/projects/semi_supervise/SSL4MIS/data/ACDC/data/slices"
    file_name = "/home/guotai/disk2t/projects/semi_supervise/slices.txt"
    dataset = H5DataSets(root_dir, file_name)
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=4, shuffle=True, num_workers=1)
    for sample in train_loader:
        image = sample['image']
        label = sample['label']
        print(image.shape, label.shape)
        print(image.min(), image.max(), label.max())