2929from deep_neurographs .utils .gnn_util import toCPU
3030from deep_neurographs .utils .graph_util import GraphLoader
3131
32- BATCH_SIZE = 2000
33- CONFIDENCE_THRESHOLD = 0.7
34-
3532
3633class InferencePipeline :
3734 """
@@ -132,9 +129,9 @@ def __init__(
132129 self .model_path ,
133130 self .ml_config .model_type ,
134131 self .graph_config .search_radius ,
132+ accept_threshold = self .ml_config .threshold ,
135133 anisotropy = self .ml_config .anisotropy ,
136134 batch_size = self .ml_config .batch_size ,
137- confidence_threshold = self .ml_config .threshold ,
138135 device = device ,
139136 multiscale = self .ml_config .multiscale ,
140137 labels_path = labels_path ,
@@ -178,21 +175,27 @@ def run(self, fragments_pointer):
178175 # Finish
179176 self .report ("Final Graph..." )
180177 self .report_graph ()
181-
182178 t , unit = util .time_writer (time () - t0 )
183179 self .report (f"Total Runtime: { round (t , 4 )} { unit } \n " )
184180
185181 def run_schedule (self , fragments_pointer , radius_schedule ):
186- t0 = time ()
182+ # Initializations
187183 self .log_experiment ()
184+ self .write_metadata ()
185+ t0 = time ()
186+
187+ # Main
188188 self .build_graph (fragments_pointer )
189189 for round_id , radius in enumerate (radius_schedule ):
190- self .report (f"--- Round { round_id + 1 } : Radius = { radius } ---" )
191190 round_id += 1
191+ self .report (f"--- Round { round_id } : Radius = { radius } ---" )
192192 self .generate_proposals (radius )
193193 self .run_inference ()
194194 self .save_results (round_id = round_id )
195195
196+ # Finish
197+ self .report ("Final Graph..." )
198+ self .report_graph ()
196199 t , unit = util .time_writer (time () - t0 )
197200 self .report (f"Total Runtime: { round (t , 4 )} { unit } \n " )
198201
@@ -212,7 +215,7 @@ def build_graph(self, fragments_pointer):
212215 None
213216
214217 """
215- self .report ("(1) Building FragmentGraph" )
218+ self .report ("Step 1: Building FragmentGraph" )
216219 t0 = time ()
217220
218221 # Initialize Graph
@@ -233,31 +236,27 @@ def build_graph(self, fragments_pointer):
233236 self .graph .save_labels (labels_path )
234237 self .report (f"# SWCs Saved: { n_saved } " )
235238
236- # Report runtime
239+ # Report results
237240 t , unit = util .time_writer (time () - t0 )
238241 self .report (f"Module Runtime: { round (t , 4 )} { unit } " )
239-
240- # Report graph overview
241242 self .report ("\n Initial Graph..." )
242243 self .report_graph ()
243244
244245 def filter_fragments (self ):
245- # Filter curvy fragments
246+ # Curvy fragments
246247 n_curvy = fragment_filtering .remove_curvy (self .graph , 200 )
247- n_curvy = util .reformat_number (n_curvy )
248248
249- # Filter doubles
249+ # Double fragments
250250 if self .graph_config .remove_doubles_bool :
251251 n_doubles = fragment_filtering .remove_doubles (
252252 self .graph , 200 , self .graph_config .node_spacing
253253 )
254- n_doubles = util .reformat_number (n_doubles )
255254 self .report (f"# Double Fragments Deleted: { n_doubles } " )
256255 self .report (f"# Curvy Fragments Deleted: { n_curvy } " )
257256
258257 def generate_proposals (self , radius = None ):
259258 """
260- Generates proposals for the fragment graph based on the specified
259+ Generates proposals for the fragments graph based on the specified
261260 configuration.
262261
263262 Parameters
@@ -270,7 +269,7 @@ def generate_proposals(self, radius=None):
270269
271270 """
272271 # Initializations
273- self .report ("(2) Generate Proposals" )
272+ self .report ("Step 2: Generate Proposals" )
274273 if radius is None :
275274 radius = self .graph_config .search_radius
276275
@@ -307,17 +306,21 @@ def run_inference(self):
307306 None
308307
309308 """
310- self .report ("(3) Run Inference" )
309+ # Initializations
310+ self .report ("Step 3: Run Inference" )
311+ proposals = self .graph .list_proposals ()
312+ n_proposals = max (len (proposals ), 1 )
313+
314+ # Main
311315 t0 = time ()
312- n_proposals = max (self .graph .n_proposals (), 1 )
313- self .graph , accepts = self .inference_engine .run (
314- self .graph , self .graph .list_proposals ()
315- )
316+ self .graph , accepts = self .inference_engine .run (self .graph , proposals )
316317 self .accepted_proposals .extend (accepts )
317- self .report (f"# Accepted: { util .reformat_number (len (accepts ))} " )
318- self .report (f"% Accepted: { round (len (accepts ) / n_proposals , 4 )} " )
319318
319+ # Report results
320320 t , unit = util .time_writer (time () - t0 )
321+ n_accepts = len (self .accepted_proposals )
322+ self .report (f"# Accepted: { util .reformat_number (n_accepts )} " )
323+ self .report (f"% Accepted: { round (n_accepts / n_proposals , 4 )} " )
321324 self .report (f"Module Runtime: { round (t , 4 )} { unit } \n " )
322325
323326 def save_results (self , round_id = None ):
@@ -334,15 +337,15 @@ def save_results(self, round_id=None):
334337 None
335338
336339 """
337- # Save result locally
340+ # Save result on local machine
338341 suffix = f"-{ round_id } " if round_id else ""
339342 filename = f"corrected-processed-swcs{ suffix } .zip"
340343 path = os .path .join (self .output_dir , filename )
341344 self .graph .to_zipped_swcs (path )
342345 self .save_connections (round_id = round_id )
343346 self .write_metadata ()
344347
345- # Save result on s3
348+ # Save result on s3 (if applicable)
346349 filename = f"corrected-processed-swcs-s3.zip"
347350 path = os .path .join (self .output_dir , filename )
348351 self .graph .to_zipped_swcs (path , min_size = 50 )
@@ -373,7 +376,8 @@ def save_to_s3(self):
373376 # --- io ---
374377 def save_connections (self , round_id = None ):
375378 """
376- Saves predicted connections between connected components in a txt file.
379+ Writes the accepted proposals from the graph to a text file. Each line
380+ contains the two swc ids as comma separated values.
377381
378382 Parameters
379383 ----------
@@ -414,7 +418,7 @@ def write_metadata(self):
414418 "long_range_bool" : self .graph_config .long_range_bool ,
415419 "proposals_per_leaf" : self .graph_config .proposals_per_leaf ,
416420 "search_radius" : f"{ self .graph_config .search_radius } um" ,
417- "confidence_threshold " : self .ml_config .threshold ,
421+ "accept_threshold " : self .ml_config .threshold ,
418422 "node_spacing" : self .graph_config .node_spacing ,
419423 "remove_doubles" : self .graph_config .remove_doubles_bool ,
420424 }
@@ -475,9 +479,9 @@ def __init__(
475479 model_path ,
476480 model_type ,
477481 radius ,
482+ accept_threshold = 0.7 ,
478483 anisotropy = [1.0 , 1.0 , 1.0 ],
479- batch_size = BATCH_SIZE ,
480- confidence_threshold = CONFIDENCE_THRESHOLD ,
484+ batch_size = 2000 ,
481485 device = None ,
482486 multiscale = 1 ,
483487 labels_path = None ,
@@ -490,22 +494,27 @@ def __init__(
490494 Parameters
491495 ----------
492496 img_path : str
493- Path to image stored in a GCS bucket .
497+ Path to image.
494498 model_path : str
495- Path to machine learning model parameters .
499+ Path to machine learning model weights .
496500 model_type : str
497501 Type of machine learning model used to perform inference.
498502 radius : float
499503 Search radius used to generate proposals.
504+ accept_threshold : float, optional
505+ Threshold for accepting proposals, where proposals with predicted
506+ likelihood above this threshold are accepted. The default is 0.7.
507+ anisotropy : List[float], optional
508+ ...
500509 batch_size : int, optional
501- Number of proposals to generate features and classify per batch.
502- The default is the global varaible "BATCH_SIZE".
503- confidence_threshold : float, optional
504- Threshold on acceptance probability for proposals. The default is
505- the global variable "CONFIDENCE_THRESHOLD".
510+ Number of proposals to classify in each batch.The default is 2000.
506511 multiscale : int, optional
507512 Level in the image pyramid that voxel coordinates must index into.
508513 The default is 1.
514+ labels_path : str or None, optional
515+ ...
516+ is_multimodal : bool, optional
517+ ...
509518
510519 Returns
511520 -------
@@ -517,7 +526,7 @@ def __init__(
517526 self .device = "cpu" if device is None else device
518527 self .is_gnn = True if "Graph" in model_type else False
519528 self .radius = radius
520- self .threshold = confidence_threshold
529+ self .threshold = accept_threshold
521530
522531 # Features
523532 self .feature_generator = FeatureGenerator (
0 commit comments