1515from torch .utils .data import Dataset
1616from tqdm import tqdm
1717
18+ import fastremap
1819import numpy as np
1920import random
2021import tensorstore as ts
@@ -178,44 +179,68 @@ def sample_voxel(self, brain_id):
178179 def sample_foreground_voxel (self , brain_id ):
179180 if self .skeletons [brain_id ] is not None and np .random .random () > 0.5 :
180181 return self .sample_skeleton_voxel (brain_id )
181- #elif self.segmentations[brain_id] is not None:
182- # return self.sample_segmentation_voxel(brain_id)
183- else :
184- return self .sample_bright_voxel (brain_id )
182+ else : # self.segmentations[brain_id] is not None:
183+ return self .sample_segmentation_voxel (brain_id )
184+
185+ #else:
186+ # return self.sample_bright_voxel(brain_id)
185187
186188 def sample_skeleton_voxel (self , brain_id ):
187189 idx = random .randint (0 , len (self .foreground [brain_id ]) - 1 )
188190 shift = np .random .randint (0 , 16 , size = 3 )
189191 return tuple (self .foreground [brain_id ][idx ] + shift )
190192
191193 def sample_segmentation_voxel (self , brain_id ):
194+ while self .segmentations [brain_id ] is None :
195+ brain_id = self .sample_brain ()
196+
192197 cnt = 0
193- while cnt < 32 :
198+ best_voxel , max_volume = None , 0
199+ while max_volume < 4000 :
194200 # Read random image patch
195201 voxel = self .sample_interior_voxel (brain_id )
196202 labels_patch = self .read_precomputed_patch (brain_id , voxel )
197203
198204 # Check if labels patch has large enough object
199- # --> call fastremap
200- # --> find largest object
201- return voxel
205+ vals , cnts = fastremap .unique (labels_patch , return_counts = True )
206+ if len (cnts ) > 1 :
207+ volume = np .max (cnts [1 :])
208+ if volume > max_volume :
209+ best_voxel = voxel
210+ max_volume = volume
211+
212+ # Check number of tries
213+ cnt += 1
214+ if cnt > 32 :
215+ break
216+
217+ print ("Brain_ID:" , brain_id )
218+ print ("Largest Volume:" , max_volume )
219+ print ("Voxel:" , best_voxel )
220+ print ("# Attempts:" , cnt )
221+ return best_voxel
202222
203223 def sample_bright_voxel (self , brain_id ):
204224 cnt = 0
205225 brightest_voxel , max_brightness = None , 0
206- while cnt < 32 :
226+ while max_brightness < self . min_brightness :
207227 # Read random image patch
208228 voxel = self .sample_interior_voxel (brain_id )
209229 img_patch = self .read_patch (brain_id , voxel )
210230
211231 # Check if image patch is bright enough
212232 brightness = np .max (img_patch )
213- if brightness >= self .min_brightness :
214- return voxel
215- elif brightness > max_brightness :
233+ if brightness > max_brightness :
216234 brightest_voxel = voxel
217235 max_brightness = brightness
236+
237+ # Check number of tries
218238 cnt += 1
239+ if cnt > 32 :
240+ break
241+ print ("Brain_ID:" , brain_id )
242+ print ("Max Brightness:" , max_brightness )
243+ print ("# Attempts:" , cnt )
219244 return brightest_voxel
220245
221246 def sample_interior_voxel (self , brain_id ):
@@ -238,21 +263,57 @@ def __len__(self):
238263 return self .n_examples_per_epoch
239264
240265 def read_patch (self , brain_id , center ):
266+ """
267+ Reads an image patch from a Zarr array.
268+
269+ Parameters
270+ ----------
271+ brain_id : str
272+ Unique identifier of the sampled brain.
273+ center : Tuple[int]
274+ Center of image patch to be read.
275+
276+ Returns
277+ -------
278+ numpy.ndarray
279+ Image patch.
280+ """
241281 s = img_util .get_slices (center , self .patch_shape )
242282 return self .imgs [brain_id ][(0 , 0 , * s )]
243283
244284 def read_precomputed_patch (self , brain_id , center ):
245285 """
246- Reads an image patch from a precomputed array.
286+ Reads an image patch from a Precomputed array.
247287
248288 Parameters
249289 ----------
250- ...
290+ brain_id : str
291+ Unique identifier of the sampled brain.
292+ center : Tuple[int]
293+ Center of image patch to be read.
294+
295+ Returns
296+ -------
297+ numpy.ndarray
298+ Image patch.
251299 """
252300 s = img_util .get_slices (center , self .patch_shape )
253- return self .segmentations [brain_id ][( 0 , 0 , * s ) ].read ().result ()
301+ return self .segmentations [brain_id ][s ].read ().result ()
254302
255303 def to_voxels (self , xyz_arr ):
304+ """
305+ Converts 3D points from physical to voxel coordinates.
306+
307+ Parameters
308+ ----------
309+ xyz_arr : numpy.ndarray
310+ Array with shape (n, 3) that contains 3D points.
311+
312+ Returns
313+ -------
314+ numpy.ndarray
315+ 3D Points converted to voxel coordinates.
316+ """
256317 for i in range (3 ):
257318 xyz_arr [:, i ] = xyz_arr [:, i ] / self .anisotropy [i ]
258319 return np .flip (xyz_arr , axis = 1 ).astype (int )
@@ -442,7 +503,7 @@ def init_datasets(
442503 n_train_examples_per_epoch = 100 ,
443504 n_validate_examples = 0 ,
444505 segmentation_prefixes_path = None ,
445- sigma_bm4d = 30 ,
506+ sigma_bm4d = 16 ,
446507 swc_pointers = None
447508):
448509 # Initializations
0 commit comments