11from typing import Optional , Tuple
22
33import numpy as np
4+ import uuid
5+ import h5py
46import torch
57import torch_em
68import torch_em .self_training as self_training
79from torchvision import transforms
810
11+ from synapse_net .file_utils import read_mrc
912from .supervised_training import get_2d_model , get_3d_model , get_supervised_loader , _determine_ndim
1013
1114
@@ -28,14 +31,36 @@ def weak_augmentations(p: float = 0.75) -> callable:
2831 ])
2932 return torch_em .transform .raw .get_raw_transform (normalizer = norm , augmentation1 = aug )
3033
31-
34+ def drop_mask_channel (x ):
35+ x = x [:1 ]
36+ return x
37+
38+ class ComposedTransform :
39+ def __init__ (self , * funcs ):
40+ self .funcs = funcs
41+
42+ def __call__ (self , x ):
43+ for f in self .funcs :
44+ x = f (x )
45+ return x
46+
47+ class ChannelSplitterSampler :
48+ def __init__ (self , sampler ):
49+ self .sampler = sampler
50+
51+ def __call__ (self , x ):
52+ raw , mask = x [0 ], x [1 ]
53+ return self .sampler (raw , mask )
54+
3255def get_unsupervised_loader (
3356 data_paths : Tuple [str ],
3457 raw_key : str ,
3558 patch_shape : Tuple [int , int , int ],
3659 batch_size : int ,
3760 n_samples : Optional [int ],
38- exclude_top_and_bottom : bool = False ,
61+ boundary_mask_paths : Optional [Tuple [str ]] = None ,
62+ sampler : Optional [callable ] = None ,
63+ exclude_top_and_bottom : bool = False , # TODO this seems unneccesary if we have a boundary mask - remove?
3964) -> torch .utils .data .DataLoader :
4065 """Get a dataloader for unsupervised segmentation training.
4166
@@ -50,19 +75,46 @@ def get_unsupervised_loader(
5075 based on the patch_shape and size of the volumes used for training.
5176 exclude_top_and_bottom: Whether to exluce the five top and bottom slices to
5277 avoid artifacts at the border of tomograms.
78+ boundary_mask_paths: The filepaths to the corresponding boundary masks for each tomogram.
79+ sampler: Accept or reject patches based on a condition.
5380
5481 Returns:
5582 The PyTorch dataloader.
5683 """
57-
5884 # We exclude the top and bottom slices where the tomogram reconstruction is bad.
85+ # TODO this seems unneccesary if we have a boundary mask - remove?
5986 if exclude_top_and_bottom :
6087 roi = np .s_ [5 :- 5 , :, :]
6188 else :
6289 roi = None
90+ # stack tomograms and masks and write to temp files to use as input to RawDataset()
91+ if boundary_mask_paths is not None :
92+ assert len (data_paths ) == len (boundary_mask_paths ), \
93+ f"Expected equal number of data_paths and and boundary_masks_paths, got { len (data_paths )} data paths and { len (boundary_mask_paths )} mask paths."
94+
95+ stacked_paths = []
96+ for i , (data_path , mask_path ) in enumerate (zip (data_paths , boundary_mask_paths )):
97+ raw = read_mrc (data_path )[0 ]
98+ mask = read_mrc (mask_path )[0 ]
99+ stacked = np .stack ([raw , mask ], axis = 0 )
100+
101+ tmp_path = f"/tmp/stacked{ i } _{ uuid .uuid4 ().hex } .h5"
102+ with h5py .File (tmp_path , "w" ) as f :
103+ f .create_dataset ("raw" , data = stacked , compression = "gzip" )
104+ stacked_paths .append (tmp_path )
105+
106+ # update variables for RawDataset()
107+ data_paths = tuple (stacked_paths )
108+ base_transform = torch_em .transform .get_raw_transform ()
109+ raw_transform = ComposedTransform (base_transform , drop_mask_channel )
110+ sampler = ChannelSplitterSampler (sampler )
111+ with_channels = True
112+ else :
113+ raw_transform = torch_em .transform .get_raw_transform ()
114+ with_channels = False
115+ sampler = None
63116
64117 _ , ndim = _determine_ndim (patch_shape )
65- raw_transform = torch_em .transform .get_raw_transform ()
66118 transform = torch_em .transform .get_augmentations (ndim = ndim )
67119
68120 if n_samples is None :
@@ -71,18 +123,19 @@ def get_unsupervised_loader(
71123 n_samples_per_ds = int (n_samples / len (data_paths ))
72124
73125 augmentations = (weak_augmentations (), weak_augmentations ())
126+
74127 datasets = [
75- torch_em .data .RawDataset (path , raw_key , patch_shape , raw_transform , transform ,
76- augmentations = augmentations , roi = roi , ndim = ndim , n_samples = n_samples_per_ds )
128+ torch_em .data .RawDataset (path , raw_key , patch_shape , raw_transform , transform , roi = roi ,
129+ n_samples = n_samples_per_ds , sampler = sampler , ndim = ndim , with_channels = with_channels , augmentations = augmentations )
77130 for path in data_paths
78131 ]
79132 ds = torch .utils .data .ConcatDataset (datasets )
80133
81134 num_workers = 4 * batch_size
82- loader = torch_em .segmentation .get_data_loader (ds , batch_size = batch_size , num_workers = num_workers , shuffle = True )
135+ loader = torch_em .segmentation .get_data_loader (ds , batch_size = batch_size ,
136+ num_workers = num_workers , shuffle = True )
83137 return loader
84138
85-
86139# TODO: use different paths for supervised and unsupervised training
87140# (We are currently not using this functionality directly, so this is not a high priority)
88141def semisupervised_training (
0 commit comments