1717
1818import numpy as np
1919import random
20+ import tensorstore as ts
2021import torch
2122
2223from aind_exaspim_image_compression .utils import img_util , util
@@ -37,10 +38,12 @@ def __init__(
3738 self ,
3839 patch_shape ,
3940 anisotropy = (0.748 , 0.748 , 1.0 ),
40- boundary_buffer = 4000 ,
41+ boundary_buffer = 5000 ,
4142 foreground_sampling_rate = 0.2 ,
42- n_examples_per_epoch = 200 ,
43- sigma_bm4d = 50 ,
43+ min_brightness = 200 ,
44+ n_examples_per_epoch = 300 ,
45+ normalization_percentiles = [0.5 , 99.9 ],
46+ sigma_bm4d = 30 ,
4447 ):
4548 # Call parent class
4649 super (TrainDataset , self ).__init__ ()
@@ -49,51 +52,110 @@ def __init__(
4952 self .anisotropy = anisotropy
5053 self .boundary_buffer = boundary_buffer
5154 self .foreground_sampling_rate = foreground_sampling_rate
55+ self .min_brightness = min_brightness
5256 self .n_examples_per_epoch = n_examples_per_epoch
57+ self .normalization_percentiles = normalization_percentiles
5358 self .patch_shape = patch_shape
5459 self .sigma_bm4d = sigma_bm4d
5560 self .swc_reader = Reader ()
5661
5762 # Data structures
58- self .foreground = dict ()
63+ self .segmentations = dict ()
64+ self .skeletons = dict ()
5965 self .imgs = dict ()
6066
6167 # --- Ingest data ---
62- def ingest_brain (self , brain_id , img_path , swc_pointer ):
63- self .foreground [brain_id ] = self .load_swcs (swc_pointer )
def ingest_brain(self, brain_id, img_path, segmentation_path, swc_pointer):
    """
    Loads a brain image, label mask, and skeletons, then stores each in
    the internal dictionaries under the key "brain_id".

    Parameters
    ----------
    brain_id : hashable
        Unique identifier for the brain corresponding to the image.
    img_path : str or Path
        Path to whole-brain image to be read.
    segmentation_path : str or None
        Path to segmentation, forwarded to "load_segmentation". When it
        is falsy, None is stored for this brain.
    swc_pointer : Any
        Pointer to SWC files, forwarded to "load_swcs". When it is falsy
        or no points are found, None is stored for this brain.
    """
    self.segmentations[brain_id] = self.load_segmentation(segmentation_path)
    self.imgs[brain_id] = img_util.read(img_path)
    self.skeletons[brain_id] = self.load_swcs(swc_pointer)
87+
def load_segmentation(self, segmentation_path, bucket="allen-nd-goog"):
    """
    Opens a segmentation mask generated by Google Applied Sciences (GAS)
    that is stored in the Neuroglancer precomputed format on GCS.

    Parameters
    ----------
    segmentation_path : str or None
        Path to segmentation within the bucket. A falsy value means that
        no segmentation is available.
    bucket : str, optional
        Name of the GCS bucket that holds the segmentation. Default is
        "allen-nd-goog".

    Returns
    -------
    tensorstore.TensorStore or None
        Opened label mask with channel 0 selected and the remaining axes
        in reversed order, or None if "segmentation_path" is falsy.
    """
    # Nothing to load
    if not segmentation_path:
        return None

    # Open image
    spec = {
        "driver": "neuroglancer_precomputed",
        "kvstore": {
            "driver": "gcs",
            "bucket": bucket,
            "path": segmentation_path,
        },
        "context": {
            "cache_pool": {"total_bytes_limit": 1000000000},
            "cache_pool#remote": {"total_bytes_limit": 1000000000},
            "data_copy_concurrency": {"limit": 8},
        },
        "recheck_cached_data": "open",
    }
    label_mask = ts.open(spec).result()

    # Permute axes to be consistent with raw image: drop the channel axis,
    # then move the leading axis to the back twice, which reverses the
    # order of the remaining axes (presumably xyz -> zyx; confirm).
    label_mask = label_mask[ts.d["channel"][0]]
    label_mask = label_mask[ts.d[0].transpose[2]]
    label_mask = label_mask[ts.d[0].transpose[1]]
    return label_mask
65127
def load_swcs(self, swc_pointer):
    """
    Reads SWC files and stacks the points from all of them into a single
    array of voxel coordinates.

    Parameters
    ----------
    swc_pointer : Any
        Pointer to SWC files that is handed to the SWC reader. A falsy
        value means that no skeletons are available.

    Returns
    -------
    numpy.ndarray or None
        Array of shape (n_points, 3) with dtype int32 that contains the
        voxel coordinates of every skeleton point, or None if there is no
        pointer or no points were read.
    """
    # Nothing to load
    if not swc_pointer:
        return None

    # Count points in each SWC file
    swc_dicts = self.swc_reader.read(swc_pointer)
    counts = [len(swc_dict["xyz"]) for swc_dict in swc_dicts]
    total = np.sum(counts)
    if not total > 0:
        return None

    # Extract skeleton voxels
    voxels = np.zeros((total, 3), dtype=np.int32)
    offset = 0
    for swc_dict, count in zip(swc_dicts, counts):
        voxels[offset:offset + count] = self.to_voxels(swc_dict["xyz"])
        offset += count
    return voxels
82144
83145 # --- Core Routines ---
def __getitem__(self, dummy_input):
    """
    Generates a training example from a randomly sampled image patch.

    Parameters
    ----------
    dummy_input : Any
        Unused; a new random patch is drawn on every call.

    Returns
    -------
    tuple
        Normalized noisy patch, normalized BM4D-denoised patch, and the
        pair (mn, mx) of percentile values used for the normalization.
    """
    # Sample image patch
    brain_id = self.sample_brain()
    center = self.sample_voxel(brain_id)
    noise = self.read_patch(brain_id, center)
    mn, mx = np.percentile(noise, self.normalization_percentiles)

    # Denoise image patch (uses the raw, unnormalized intensities)
    denoised = bm4d(noise, self.sigma_bm4d)

    # Normalize image patches; the scale is clamped so it is never below 1
    scale = max(mx - mn, 1)
    noise = (noise - mn) / scale
    denoised = (denoised - mn) / scale
    return noise, denoised, (mn, mx)
98160
99161 def sample_brain (self ):
@@ -114,12 +176,47 @@ def sample_voxel(self, brain_id):
114176 return self .sample_interior_voxel (brain_id )
115177
def sample_foreground_voxel(self, brain_id):
    """
    Samples a voxel that is likely to be in the foreground.

    If skeletons are available for the brain, a skeleton voxel is used
    when a uniform random draw exceeds 0.5; in every other case a bright
    voxel is searched for instead.

    Parameters
    ----------
    brain_id : hashable
        Unique identifier of the brain to sample from.

    Returns
    -------
    Tuple[int]
        Voxel coordinates of the sampled patch center.
    """
    # TODO: add a branch that calls "sample_segmentation_voxel" when a
    # segmentation is available for this brain.
    has_skeletons = self.skeletons[brain_id] is not None
    if has_skeletons and np.random.random() > 0.5:
        return self.sample_skeleton_voxel(brain_id)
    return self.sample_bright_voxel(brain_id)
185+
def sample_skeleton_voxel(self, brain_id):
    """
    Samples a voxel near a randomly chosen skeleton point.

    Parameters
    ----------
    brain_id : hashable
        Unique identifier of the brain to sample from. Skeletons must
        have been loaded for this brain (i.e. the entry is not None).

    Returns
    -------
    Tuple[int]
        Voxel coordinates of a random skeleton point, offset along each
        axis by a random integer in [0, 16).
    """
    # Skeleton points are stored in "self.skeletons"; the "foreground"
    # attribute no longer exists.
    skeleton = self.skeletons[brain_id]
    idx = random.randint(0, len(skeleton) - 1)
    shift = np.random.randint(0, 16, size=3)
    return tuple(skeleton[idx] + shift)
190+
def sample_segmentation_voxel(self, brain_id):
    """
    Samples an interior voxel whose segmentation patch contains at least
    one labeled (non-zero) voxel.

    Parameters
    ----------
    brain_id : hashable
        Unique identifier of the brain to sample from.

    Returns
    -------
    Tuple[int]
        Voxel coordinates of the first sampled patch center that contains
        a non-zero label. If none is found within 32 attempts, the last
        sampled voxel is returned.
    """
    voxel = None
    # Bounded loop: the number of attempts is capped so that a brain with
    # a (nearly) empty segmentation cannot stall the data loader.
    for _ in range(32):
        # Read random labels patch
        voxel = self.sample_interior_voxel(brain_id)
        labels_patch = self.read_precomputed_patch(brain_id, voxel)

        # Check if labels patch contains an object
        # TODO: require a large enough object (e.g. find the largest
        # connected label) instead of any non-zero label.
        if np.any(labels_patch):
            return voxel
    return voxel
202+
def sample_bright_voxel(self, brain_id):
    """
    Samples an interior voxel whose image patch is bright enough.

    Parameters
    ----------
    brain_id : hashable
        Unique identifier of the brain to sample from.

    Returns
    -------
    Tuple[int]
        Voxel coordinates of the first sampled patch center whose maximum
        intensity is at least "self.min_brightness". If none is found
        within 32 attempts, the center of the brightest patch seen.
    """
    # Start below every possible intensity so that the first patch always
    # becomes the running best; starting at 0 would leave the result as
    # None when every sampled patch is completely dark.
    brightest_voxel, max_brightness = None, -np.inf
    for _ in range(32):
        # Read random image patch
        voxel = self.sample_interior_voxel(brain_id)
        brightness = np.max(self.read_patch(brain_id, voxel))

        # Check if image patch is bright enough
        if brightness >= self.min_brightness:
            return voxel
        if brightness > max_brightness:
            brightest_voxel, max_brightness = voxel, brightness
    return brightest_voxel
123220
124221 def sample_interior_voxel (self , brain_id ):
125222 voxel = list ()
@@ -140,36 +237,54 @@ def __len__(self):
140237 """
141238 return self .n_examples_per_epoch
142239
143- def get_patch (self , brain_id , voxel ):
144- s , e = img_util .get_start_end (voxel , self .patch_shape )
145- return self .imgs [brain_id ][0 , 0 , s [0 ]: e [0 ], s [1 ]: e [1 ], s [2 ]: e [2 ]]
def read_patch(self, brain_id, center):
    """
    Reads an image patch centered at the given voxel.

    Parameters
    ----------
    brain_id : hashable
        Unique identifier of the brain whose image is read.
    center : Tuple[int]
        Voxel coordinates of the patch center.

    Returns
    -------
    Image patch with shape "self.patch_shape" (assuming the slices lie
    within the image bounds).
    """
    patch_slices = img_util.get_slices(center, self.patch_shape)
    img = self.imgs[brain_id]
    # The two leading axes are fixed at index 0 (presumably time and
    # channel; confirm against the image format).
    return img[(0, 0, *patch_slices)]
243+
def read_precomputed_patch(self, brain_id, center):
    """
    Reads a patch from the precomputed segmentation of the given brain.

    Parameters
    ----------
    brain_id : hashable
        Unique identifier of the brain whose segmentation is read. A
        segmentation must have been loaded for this brain.
    center : Tuple[int]
        Voxel coordinates of the patch center.

    Returns
    -------
    numpy.ndarray
        Labels patch read from the segmentation.
    """
    s = img_util.get_slices(center, self.patch_shape)
    # "load_segmentation" already selects channel 0, so the stored array
    # only has the spatial axes. Unlike the raw image, it must not be
    # indexed with two extra leading zeros.
    return self.segmentations[brain_id][tuple(s)].read().result()
146254
def to_voxels(self, xyz_arr):
    """
    Converts physical coordinates to voxel coordinates.

    Each column is divided by the matching entry of "self.anisotropy",
    the column order is reversed, and the result is truncated to
    integers. The input array is left unmodified.

    Parameters
    ----------
    xyz_arr : numpy.ndarray
        Array of shape (n, 3) that contains physical coordinates.

    Returns
    -------
    numpy.ndarray
        Integer array of shape (n, 3) with the voxel coordinates, columns
        in reversed order relative to the input.
    """
    # Divide into a new array rather than in place, so that converting
    # the same points twice gives the same voxels.
    scaled = np.asarray(xyz_arr) / np.asarray(self.anisotropy)
    return np.flip(scaled, axis=1).astype(int)
151259
152- def update_foreground_sampling_rate (self , foreground_sampling_rate ):
153- self .foreground_sampling_rate = foreground_sampling_rate
154-
155260
156261class ValidateDataset (Dataset ):
157262
158- def __init__ (self , patch_shape , sigma_bm4d = 50 ):
263+ def __init__ (
264+ self ,
265+ patch_shape ,
266+ normalization_percentiles = [0.5 , 99.9 ],
267+ sigma_bm4d = 30 ,
268+ ):
159269 """
160270 Instantiates a ValidateDataset object.
161271
162272 Parameters
163273 ----------
164274 patch_shape : Tuple[int]
165275 Shape of image patches to be extracted.
166- sigma_bm4d : float
167- Smoothing parameter used in the BM4D denoising algorithm.
276+ normalization_percentiles : List[float], optional
277+ Lower and upper percentiles used to normalize the input image.
278+ Default is [0.5, 99.9].
279+ sigma_bm4d : float, optional
280+ Smoothing parameter used in the BM4D denoising algorithm. Default
281+ is 30.
168282 """
169283 # Call parent class
170284 super (ValidateDataset , self ).__init__ ()
171285
172286 # Instance attributes
287+ self .normalization_percentiles = normalization_percentiles
173288 self .patch_shape = patch_shape
174289 self .sigma_bm4d = sigma_bm4d
175290
@@ -217,13 +332,13 @@ def ingest_example(self, brain_id, voxel):
217332 Voxel coordinates of the patch center in the brain volume.
218333 """
219334 # Get image patches
220- noise = self .get_patch (brain_id , voxel )
221- mn , mx = np .percentile (noise , 5 ), np . percentile ( noise , 99.9 )
335+ noise = self .read_patch (brain_id , voxel )
336+ mn , mx = np .percentile (noise , self . normalization_percentiles )
222337 denoised = bm4d (noise , self .sigma_bm4d )
223338
224339 # Normalize image patches
225- noise = (noise - mn ) / max (mx , 1 )
226- denoised = (denoised - mn ) / max (mx , 1 )
340+ noise = (noise - mn ) / max (mx - mn , 1 )
341+ denoised = (denoised - mn ) / max (mx - mn , 1 )
227342
228343 # Store results
229344 self .example_ids .append ((brain_id , voxel ))
@@ -251,9 +366,9 @@ def __getitem__(self, idx):
251366 """
252367 return self .noise [idx ], self .denoised [idx ], self .mn_mxs [idx ]
253368
254- def get_patch (self , brain_id , voxel ):
255- s , e = img_util .get_start_end ( voxel , self .patch_shape )
256- return self .imgs [brain_id ][0 , 0 , s [ 0 ]: e [ 0 ], s [ 1 ]: e [ 1 ], s [ 2 ]: e [ 2 ] ]
def read_patch(self, brain_id, center):
    """
    Reads an image patch centered at the given voxel.

    Parameters
    ----------
    brain_id : hashable
        Unique identifier of the brain whose image is read.
    center : Tuple[int]
        Voxel coordinates of the patch center.

    Returns
    -------
    Image patch with shape "self.patch_shape" (assuming the slices lie
    within the image bounds).
    """
    img = self.imgs[brain_id]
    # The two leading axes are fixed at index 0 (presumably time and
    # channel; confirm against the image format).
    index = (0, 0, *img_util.get_slices(center, self.patch_shape))
    return img[index]
257372
258373
259374# --- Custom Dataloader ---
@@ -326,8 +441,9 @@ def init_datasets(
326441 foreground_sampling_rate = 0.5 ,
327442 n_train_examples_per_epoch = 100 ,
328443 n_validate_examples = 0 ,
329- sigma_bm4d = 50 ,
330- swc_dict = None
444+ segmentation_prefixes_path = None ,
445+ sigma_bm4d = 30 ,
446+ swc_pointers = None
331447):
332448 # Initializations
333449 train_dataset = TrainDataset (
@@ -338,19 +454,35 @@ def init_datasets(
338454 )
339455 val_dataset = ValidateDataset (patch_shape )
340456
457+ # Read segmentation path lookup (if applicable)
458+ if segmentation_prefixes_path :
459+ segmentation_paths = util .read_json (segmentation_prefixes_path )
460+ else :
461+ segmentation_paths = dict ()
462+
341463 # Load data
342464 for brain_id in tqdm (brain_ids , desc = "Load Data" ):
343- # Set paths
465+ # Set image path
344466 img_path = get_img_prefix (brain_id , img_paths_json ) + str (0 )
345- if swc_dict :
346- swc_pointer = deepcopy (swc_dict )
467+
468+ # Set segmentation path
469+ if brain_id in segmentation_paths :
470+ segmentation_path = segmentation_paths [brain_id ]
471+ else :
472+ segmentation_path = None
473+
474+ # Set SWC pointer
475+ if swc_pointers :
476+ swc_pointer = deepcopy (swc_pointers )
347477 swc_pointer ["path" ] += f"/{ brain_id } /world"
348478 else :
349479 swc_pointer = None
350480
351481 # Ingest data
352- train_dataset .ingest_brain (brain_id , img_path , swc_pointer )
353482 val_dataset .ingest_brain (brain_id , img_path )
483+ train_dataset .ingest_brain (
484+ brain_id , img_path , segmentation_path , swc_pointer
485+ )
354486
355487 # Generate validation examples
356488 for _ in range (n_validate_examples ):
0 commit comments