@@ -30,7 +30,6 @@ def weak_augmentations(p: float = 0.75) -> callable:
3030 ])
3131 return torch_em .transform .raw .get_raw_transform (normalizer = norm , augmentation1 = aug )
3232
33- < << << << HEAD
3433class DropChannel :
3534 def __init__ (self , channel : int ):
3635 self .channel = channel
@@ -59,12 +58,6 @@ def __call__(self, data):
5958 output [self .transform_channel ] = self .base_transform (data [self .transform_channel ])
6059 return output
6160
62- == == == =
63- def drop_mask_channel (x ):
64- x = x [:1 ]
65- return x
66-
67- > >> >> >> 9 c252ed35b26397634947e7bec01ccb222751af6
6861class ComposedTransform :
6962 def __init__ (self , * funcs ):
7063 self .funcs = funcs
@@ -74,7 +67,6 @@ def __call__(self, x):
7467 x = f (x )
7568 return x
7669
77- < << << << HEAD
7870class ChannelWiseAugmentations :
7971 def __init__ (self , base_augmentations : callable , transform_channel : int = 0 ):
8072 self .base_augmentations = base_augmentations
@@ -91,40 +83,29 @@ def __call__(self, data):
9183 output [self .transform_channel ] = self .base_augmentations (data [self .transform_channel ])
9284 return output
9385
94- class ChannelSplitterSampler :
95- == == == =
9686class ChannelSplitterSampler :
97- > >> >> >> 9 c252ed35b26397634947e7bec01ccb222751af6
9887 def __init__ (self , sampler ):
9988 self .sampler = sampler
10089
10190 def __call__ (self , x ):
10291 raw , mask = x [0 ], x [1 ]
10392 return self .sampler (raw , mask )
104- < << << << HEAD
10593
10694def get_stacked_path (inputs : List [np .ndarray ]):
10795 stacked = np .stack (inputs , axis = 0 )
10896 tmp_path = f"/tmp/stacked_{ uuid .uuid4 ().hex } .h5"
10997 with h5py .File (tmp_path , "w" ) as f :
11098 f .create_dataset ("raw" , data = stacked , compression = "gzip" )
11199 return tmp_path
112- == == == =
113- >> >> >> > 9 c252ed35b26397634947e7bec01ccb222751af6
114100
115101def get_unsupervised_loader (
116102 data_paths : Tuple [str ],
117103 raw_key : str ,
118104 patch_shape : Tuple [int , int , int ],
119105 batch_size : int ,
120- << << << < HEAD
121106 n_samples : Optional [int ] = None ,
122107 sample_mask_paths : Optional [Tuple [str ]] = None ,
123108 background_mask_paths : Tuple [str ] = None ,
124- == == == =
125- n_samples : Optional [int ],
126- sample_mask_paths : Optional [Tuple [str ]] = None ,
127- >> >> >> > 9 c252ed35b26397634947e7bec01ccb222751af6
128109 sampler : Optional [callable ] = None ,
129110 exclude_top_and_bottom : bool = False ,
130111) -> torch .utils .data .DataLoader :
@@ -142,10 +123,7 @@ def get_unsupervised_loader(
142123 exclude_top_and_bottom: Whether to exluce the five top and bottom slices to
143124 avoid artifacts at the border of tomograms.
144125 sample_mask_paths: The filepaths to the corresponding sample masks for each tomogram.
145- <<<<<<< HEAD
146126 background_mask_paths: TODO add description
147- =======
148- >>>>>>> 9c252ed35b26397634947e7bec01ccb222751af6
149127 sampler: Accept or reject patches based on a condition.
150128
151129 Returns:
@@ -157,40 +135,11 @@ def get_unsupervised_loader(
157135 roi = np .s_ [5 :- 5 , :, :]
158136 else :
159137 roi = None
160- << << << < HEAD
161138
162139 # initialize class instances
163140 base_transform = torch_em .transform .get_raw_transform ()
164141 channelwise_raw_transform = ChannelWiseRawTransform (base_transform )
165142 drop_channel = DropChannel (channel = 1 )
166- == == == =
167- # stack tomograms and masks and write to temp files to use as input to RawDataset()
168- if sample_mask_paths is not None :
169- assert len (data_paths ) == len (sample_mask_paths ), \
170- f"Expected equal number of data_paths and and sample_masks_paths, got { len (data_paths )} data paths and { len (sample_mask_paths )} mask paths."
171-
172- stacked_paths = []
173- for i , (data_path , mask_path ) in enumerate (zip (data_paths , sample_mask_paths )):
174- raw = read_mrc (data_path )[0 ]
175- mask = read_mrc (mask_path )[0 ]
176- stacked = np .stack ([raw , mask ], axis = 0 )
177-
178- tmp_path = f"/tmp/stacked{ i } _{ uuid .uuid4 ().hex } .h5"
179- with h5py .File (tmp_path , "w" ) as f :
180- f .create_dataset ("raw" , data = stacked , compression = "gzip" )
181- stacked_paths .append (tmp_path )
182-
183- # update variables for RawDataset()
184- data_paths = tuple (stacked_paths )
185- base_transform = torch_em .transform .get_raw_transform ()
186- raw_transform = ComposedTransform (base_transform , drop_mask_channel )
187- sampler = ChannelSplitterSampler (sampler )
188- with_channels = True
189- else :
190- raw_transform = torch_em .transform .get_raw_transform ()
191- with_channels = False
192- sampler = None
193- >> >> >> > 9 c252ed35b26397634947e7bec01ccb222751af6
194143
195144 # get configurations
196145 has_sample_mask = sample_mask_paths is not None
@@ -256,11 +205,6 @@ def get_unsupervised_loader(
256205 else :
257206 n_samples_per_ds = int (n_samples / len (data_paths ))
258207
259- < << << << HEAD
260- == == == =
261- augmentations = (weak_augmentations (), weak_augmentations ())
262-
263- >> >> >> > 9 c252ed35b26397634947e7bec01ccb222751af6
264208 datasets = [
265209 torch_em .data .RawDataset (path , raw_key , patch_shape , raw_transform , transform , roi = roi ,
266210 n_samples = n_samples_per_ds , sampler = sampler , ndim = ndim , with_channels = with_channels , augmentations = augmentations )
0 commit comments