77from torch_em .util import ensure_tensor_with_channels
88
99
10- # Process labels stored in json napari style.
11- # I don't actually think that we need the epsilon here, but will leave it for now.
12- def process_labels (label_path , shape , sigma , eps , bb = None ):
13- points = pd .read_csv (label_path )
10+ class MinPointSampler :
11+ """A sampler to reject samples with a low fraction of foreground pixels in the labels.
12+
13+ Args:
14+ min_fraction: The minimal fraction of foreground pixels for accepting a sample.
15+ background_id: The id of the background label.
16+ p_reject: The probability for rejecting a sample that does not meet the criterion.
17+ """
18+ def __init__ (self , min_points : int , p_reject : float = 1.0 ):
19+ self .min_points = min_points
20+ self .p_reject = p_reject
21+
22+ def __call__ (self , x : np .ndarray , n_points : int ) -> bool :
23+ """Check the sample.
24+
25+ Args:
26+ x: The raw data.
27+ y: The label data.
28+
29+ Returns:
30+ Whether to accept this sample.
31+ """
32+
33+ if n_points > self .min_points :
34+ return True
35+ else :
36+ return np .random .rand () > self .p_reject
1437
15- if bb :
16- (z_min , z_max ), (y_min , y_max ), (x_min , x_max ) = [(s .start , s .stop ) for s in bb ]
17- restricted_shape = (z_max - z_min , y_max - y_min , x_max - x_min )
18- labels = np .zeros (restricted_shape , dtype = "float32" )
19- shape = restricted_shape
20- else :
21- labels = np .zeros (shape , dtype = "float32" )
2238
39+ def load_labels (label_path , shape , bb ):
40+ points = pd .read_csv (label_path )
2341 assert len (points .columns ) == len (shape )
24- z_coords , y_coords , x_coords = points ["axis-0" ], points ["axis-1" ], points ["axis-2" ]
42+ z_coords , y_coords , x_coords = points ["axis-0" ].values , points ["axis-1" ].values , points ["axis-2" ].values
43+
2544 if bb is not None :
45+ (z_min , z_max ), (y_min , y_max ), (x_min , x_max ) = [(s .start , s .stop ) for s in bb ]
2646 z_coords -= z_min
2747 y_coords -= y_min
2848 x_coords -= x_min
@@ -32,13 +52,31 @@ def process_labels(label_path, shape, sigma, eps, bb=None):
3252 np .logical_and (x_coords >= 0 , x_coords < (x_max - x_min )),
3353 ])
3454 z_coords , y_coords , x_coords = z_coords [mask ], y_coords [mask ], x_coords [mask ]
55+ restricted_shape = (z_max - z_min , y_max - y_min , x_max - x_min )
56+ shape = restricted_shape
3557
58+ n_points = len (z_coords )
3659 coords = tuple (
3760 np .clip (np .round (coord ).astype ("int" ), 0 , coord_max - 1 ) for coord , coord_max in zip (
3861 (z_coords , y_coords , x_coords ), shape
3962 )
4063 )
4164
65+ return coords , n_points
66+
67+
68+ # Process labels stored in json napari style.
69+ # I don't actually think that we need the epsilon here, but will leave it for now.
70+ def process_labels (coords , shape , sigma , eps , bb = None ):
71+
72+ if bb :
73+ (z_min , z_max ), (y_min , y_max ), (x_min , x_max ) = [(s .start , s .stop ) for s in bb ]
74+ restricted_shape = (z_max - z_min , y_max - y_min , x_max - x_min )
75+ labels = np .zeros (restricted_shape , dtype = "float32" )
76+ shape = restricted_shape
77+ else :
78+ labels = np .zeros (shape , dtype = "float32" )
79+
4280 labels [coords ] = 1
4381 labels = gaussian (labels , sigma )
4482 # TODO better normalization?
@@ -124,16 +162,10 @@ def _get_sample(self, index):
124162 raw , label_path = self .raw_path , self .label_path
125163
126164 raw = zarr .open (raw )[self .raw_key ]
165+ have_raw_channels = raw .ndim == 4 # 3D with channels
127166 shape = raw .shape
128167
129168 bb = self ._sample_bounding_box (shape )
130- label = process_labels (label_path , shape , self .sigma , self .eps , bb = bb )
131-
132- have_raw_channels = raw .ndim == 4 # 3D with channels
133- have_label_channels = label .ndim == 4
134- if have_label_channels :
135- raise NotImplementedError ("Multi-channel labels are not supported." )
136-
137169 prefix_box = tuple ()
138170 if have_raw_channels :
139171 if shape [- 1 ] < 16 :
@@ -143,18 +175,25 @@ def _get_sample(self, index):
143175 prefix_box = (slice (None ), )
144176
145177 raw_patch = np .array (raw [prefix_box + bb ])
146- label_patch = np .array (label )
147178
179+ coords , n_points = load_labels (label_path , shape , bb )
148180 if self .sampler is not None :
149- assert False , "Sampler not implemented"
150- # sample_id = 0
151- # while not self.sampler(raw_patch, label_patch):
152- # bb = self._sample_bounding_box(shape)
153- # raw_patch = np.array(raw[prefix_box + bb])
154- # label_patch = np.array(label[bb])
155- # sample_id += 1
156- # if sample_id > self.max_sampling_attempts:
157- # raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
181+ sample_id = 0
182+ while not self .sampler (raw_patch , n_points ):
183+ bb = self ._sample_bounding_box (shape )
184+ raw_patch = np .array (raw [prefix_box + bb ])
185+ coords , n_points = load_labels (label_path , shape , bb )
186+ sample_id += 1
187+ if sample_id > self .max_sampling_attempts :
188+ raise RuntimeError (f"Could not sample a valid batch in { self .max_sampling_attempts } attempts" )
189+
190+ label = process_labels (coords , shape , self .sigma , self .eps , bb = bb )
191+
192+ have_label_channels = label .ndim == 4
193+ if have_label_channels :
194+ raise NotImplementedError ("Multi-channel labels are not supported." )
195+
196+ label_patch = np .array (label )
158197
159198 if have_raw_channels and len (prefix_box ) == 0 :
160199 raw_patch = raw_patch .transpose ((3 , 0 , 1 , 2 )) # Channels, Depth, Height, Width
0 commit comments