99
1010# Process labels stored in json napari style.
1111# I don't actually think that we need the epsilon here, but will leave it for now.
12- def process_labels (label_path , shape , sigma , eps ):
13- labels = np .zeros (shape , dtype = "float32" )
12+ def process_labels (label_path , shape , sigma , eps , bb = None ):
1413 points = pd .read_csv (label_path )
14+
15+ if bb :
16+ (z_min , z_max ), (y_min , y_max ), (x_min , x_max ) = [(s .start , s .stop ) for s in bb ]
17+ restricted_shape = (z_max - z_min , y_max - y_min , x_max - x_min )
18+ labels = np .zeros (restricted_shape , dtype = "float32" )
19+ shape = restricted_shape
20+ else :
21+ labels = np .zeros (shape , dtype = "float32" )
22+
1523 assert len (points .columns ) == len (shape )
24+ z_coords , y_coords , x_coords = points ["axis-0" ], points ["axis-1" ], points ["axis-2" ]
25+ if bb is not None :
26+ z_coords -= z_min
27+ y_coords -= y_min
28+ x_coords -= x_min
29+ mask = np .logical_and .reduce ([
30+ np .logical_and (z_coords >= 0 , z_coords < (z_max - z_min )),
31+ np .logical_and (y_coords >= 0 , y_coords < (y_max - y_min )),
32+ np .logical_and (x_coords >= 0 , x_coords < (x_max - x_min )),
33+ ])
34+ z_coords , y_coords , x_coords = z_coords [mask ], y_coords [mask ], x_coords [mask ]
35+
1636 coords = tuple (
17- np .clip (np .round (points [ax ].values ).astype ("int" ), 0 , shape [i ] - 1 )
18- for i , ax in enumerate (points .columns )
37+ np .clip (np .round (coord ).astype ("int" ), 0 , coord_max - 1 ) for coord , coord_max in zip (
38+ (z_coords , y_coords , x_coords ), shape
39+ )
1940 )
41+
2042 labels [coords ] = 1
2143 labels = gaussian (labels , sigma )
2244 # TODO better normalization?
23- labels /= labels .max ()
45+ labels /= (labels .max () + 1e-7 )
46+ labels *= 4
2447 return labels
2548
2649
2750class DetectionDataset (torch .utils .data .Dataset ):
2851 max_sampling_attempts = 500
2952
53+ @staticmethod
54+ def compute_len (shape , patch_shape ):
55+ if patch_shape is None :
56+ return 1
57+ else :
58+ n_samples = int (np .prod ([float (sh / csh ) for sh , csh in zip (shape , patch_shape )]))
59+ return n_samples
60+
3061 def __init__ (
3162 self ,
32- raw_image_paths ,
33- label_paths ,
63+ raw_path ,
64+ label_path ,
3465 patch_shape ,
66+ raw_key ,
3567 raw_transform = None ,
3668 label_transform = None ,
3769 transform = None ,
@@ -43,10 +75,9 @@ def __init__(
4375 sigma = None ,
4476 ** kwargs ,
4577 ):
46- self .raw_images = raw_image_paths
47- # TODO make this a parameter
48- self .raw_key = "raw"
49- self .label_images = label_paths
78+ self .raw_path = raw_path
79+ self .label_path = label_path
80+ self .raw_key = raw_key
5081 self ._ndim = 3
5182
5283 assert len (patch_shape ) == self ._ndim
@@ -63,12 +94,13 @@ def __init__(
6394 self .eps = eps
6495 self .sigma = sigma
6596
97+ with zarr .open (self .raw_path , "r" ) as f :
98+ self .shape = f [self .raw_key ].shape
99+
66100 if n_samples is None :
67- self ._len = len (self .raw_images )
68- self .sample_random_index = False
101+ self ._len = self .compute_len (self .shape , self .patch_shape ) if n_samples is None else n_samples
69102 else :
70103 self ._len = n_samples
71- self .sample_random_index = True
72104
73105 def __len__ (self ):
74106 return self ._len
@@ -89,21 +121,19 @@ def _sample_bounding_box(self, shape):
89121 return tuple (slice (start , start + psh ) for start , psh in zip (bb_start , self .patch_shape ))
90122
91123 def _get_sample (self , index ):
92- if self .sample_random_index :
93- index = np .random .randint (0 , len (self .raw_images ))
94- raw , label = self .raw_images [index ], self .label_images [index ]
124+ raw , label_path = self .raw_path , self .label_path
95125
96126 raw = zarr .open (raw )[self .raw_key ]
97- # Note: this is quite inefficient, because we process the full crop rather than
98- # just the requested bounding box.
99- label = process_labels (label , raw .shape , self .sigma , self .eps )
127+ shape = raw .shape
128+
129+ bb = self ._sample_bounding_box (shape )
130+ label = process_labels (label_path , shape , self .sigma , self .eps , bb = bb )
100131
101132 have_raw_channels = raw .ndim == 4 # 3D with channels
102133 have_label_channels = label .ndim == 4
103134 if have_label_channels :
104135 raise NotImplementedError ("Multi-channel labels are not supported." )
105136
106- shape = raw .shape
107137 prefix_box = tuple ()
108138 if have_raw_channels :
109139 if shape [- 1 ] < 16 :
@@ -112,19 +142,19 @@ def _get_sample(self, index):
112142 shape = shape [1 :]
113143 prefix_box = (slice (None ), )
114144
115- bb = self ._sample_bounding_box (shape )
116145 raw_patch = np .array (raw [prefix_box + bb ])
117- label_patch = np .array (label [ bb ] )
146+ label_patch = np .array (label )
118147
119148 if self .sampler is not None :
120- sample_id = 0
121- while not self .sampler (raw_patch , label_patch ):
122- bb = self ._sample_bounding_box (shape )
123- raw_patch = np .array (raw [prefix_box + bb ])
124- label_patch = np .array (label [bb ])
125- sample_id += 1
126- if sample_id > self .max_sampling_attempts :
127- raise RuntimeError (f"Could not sample a valid batch in { self .max_sampling_attempts } attempts" )
149+ assert False , "Sampler not implemented"
150+ # sample_id = 0
151+ # while not self.sampler(raw_patch, label_patch):
152+ # bb = self._sample_bounding_box(shape)
153+ # raw_patch = np.array(raw[prefix_box + bb])
154+ # label_patch = np.array(label[bb])
155+ # sample_id += 1
156+ # if sample_id > self.max_sampling_attempts:
157+ # raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
128158
129159 if have_raw_channels and len (prefix_box ) == 0 :
130160 raw_patch = raw_patch .transpose ((3 , 0 , 1 , 2 )) # Channels, Depth, Height, Width
0 commit comments