Skip to content

Commit 6f861c4

Browse files
anna-grim
and anna-grim authored
Optimize graph loader (#291)
* refactor: optimized graph loader * refactor: improve inference clarity --------- Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
1 parent 0aa0608 commit 6f861c4

File tree

2 files changed

+52
-43
lines changed

2 files changed

+52
-43
lines changed

src/deep_neurographs/fragment_filtering.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@
1010
"""
1111

1212
from collections import defaultdict
13+
from tqdm import tqdm
1314

1415
import networkx as nx
1516
import numpy as np
16-
from tqdm import tqdm
1717

1818
from deep_neurographs import geometry
19+
from deep_neurographs.utils import util
1920

2021
QUERY_DIST = 15
2122

@@ -46,15 +47,14 @@ def remove_curvy(fragments_graph, max_length, ratio=0.5):
4647
4748
"""
4849
deleted_ids = set()
49-
components = get_line_components(fragments_graph)
50-
for nodes in tqdm(components, desc="Filter Curvy Fragments"):
50+
for nodes in get_line_components(fragments_graph):
5151
i, j = tuple(nodes)
5252
length = fragments_graph.edges[i, j]["length"]
5353
endpoint_dist = fragments_graph.dist(i, j)
5454
if endpoint_dist / length < ratio and length < max_length:
5555
deleted_ids.add(fragments_graph.edges[i, j]["swc_id"])
5656
delete_fragment(fragments_graph, i, j)
57-
return len(deleted_ids)
57+
return util.reformat_number(len(deleted_ids))
5858

5959

6060
# --- Doubles Removal ---
@@ -96,7 +96,7 @@ def remove_doubles(fragments_graph, max_length, node_spacing):
9696
if check_doubles_criteria(hits, n_points):
9797
delete_fragment(fragments_graph, i, j)
9898
deleted_ids.add(swc_id)
99-
return len(deleted_ids)
99+
return util.reformat_number(len(deleted_ids))
100100

101101

102102
def compute_projections(fragments_graph, kdtree, edge):

src/deep_neurographs/inference.py

Lines changed: 47 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,6 @@
2929
from deep_neurographs.utils.gnn_util import toCPU
3030
from deep_neurographs.utils.graph_util import GraphLoader
3131

32-
BATCH_SIZE = 2000
33-
CONFIDENCE_THRESHOLD = 0.7
34-
3532

3633
class InferencePipeline:
3734
"""
@@ -132,9 +129,9 @@ def __init__(
132129
self.model_path,
133130
self.ml_config.model_type,
134131
self.graph_config.search_radius,
132+
accept_threshold=self.ml_config.threshold,
135133
anisotropy=self.ml_config.anisotropy,
136134
batch_size=self.ml_config.batch_size,
137-
confidence_threshold=self.ml_config.threshold,
138135
device=device,
139136
multiscale=self.ml_config.multiscale,
140137
labels_path=labels_path,
@@ -178,21 +175,27 @@ def run(self, fragments_pointer):
178175
# Finish
179176
self.report("Final Graph...")
180177
self.report_graph()
181-
182178
t, unit = util.time_writer(time() - t0)
183179
self.report(f"Total Runtime: {round(t, 4)} {unit}\n")
184180

185181
def run_schedule(self, fragments_pointer, radius_schedule):
186-
t0 = time()
182+
# Initializations
187183
self.log_experiment()
184+
self.write_metadata()
185+
t0 = time()
186+
187+
# Main
188188
self.build_graph(fragments_pointer)
189189
for round_id, radius in enumerate(radius_schedule):
190-
self.report(f"--- Round {round_id + 1}: Radius = {radius} ---")
191190
round_id += 1
191+
self.report(f"--- Round {round_id}: Radius = {radius} ---")
192192
self.generate_proposals(radius)
193193
self.run_inference()
194194
self.save_results(round_id=round_id)
195195

196+
# Finish
197+
self.report("Final Graph...")
198+
self.report_graph()
196199
t, unit = util.time_writer(time() - t0)
197200
self.report(f"Total Runtime: {round(t, 4)} {unit}\n")
198201

@@ -212,7 +215,7 @@ def build_graph(self, fragments_pointer):
212215
None
213216
214217
"""
215-
self.report("(1) Building FragmentGraph")
218+
self.report("Step 1: Building FragmentGraph")
216219
t0 = time()
217220

218221
# Initialize Graph
@@ -233,31 +236,27 @@ def build_graph(self, fragments_pointer):
233236
self.graph.save_labels(labels_path)
234237
self.report(f"# SWCs Saved: {n_saved}")
235238

236-
# Report runtime
239+
# Report results
237240
t, unit = util.time_writer(time() - t0)
238241
self.report(f"Module Runtime: {round(t, 4)} {unit}")
239-
240-
# Report graph overview
241242
self.report("\nInitial Graph...")
242243
self.report_graph()
243244

244245
def filter_fragments(self):
245-
# Filter curvy fragments
246+
# Curvy fragments
246247
n_curvy = fragment_filtering.remove_curvy(self.graph, 200)
247-
n_curvy = util.reformat_number(n_curvy)
248248

249-
# Filter doubles
249+
# Double fragments
250250
if self.graph_config.remove_doubles_bool:
251251
n_doubles = fragment_filtering.remove_doubles(
252252
self.graph, 200, self.graph_config.node_spacing
253253
)
254-
n_doubles = util.reformat_number(n_doubles)
255254
self.report(f"# Double Fragments Deleted: {n_doubles}")
256255
self.report(f"# Curvy Fragments Deleted: {n_curvy}")
257256

258257
def generate_proposals(self, radius=None):
259258
"""
260-
Generates proposals for the fragment graph based on the specified
259+
Generates proposals for the fragments graph based on the specified
261260
configuration.
262261
263262
Parameters
@@ -270,7 +269,7 @@ def generate_proposals(self, radius=None):
270269
271270
"""
272271
# Initializations
273-
self.report("(2) Generate Proposals")
272+
self.report("Step 2: Generate Proposals")
274273
if radius is None:
275274
radius = self.graph_config.search_radius
276275

@@ -307,17 +306,21 @@ def run_inference(self):
307306
None
308307
309308
"""
310-
self.report("(3) Run Inference")
309+
# Initializations
310+
self.report("Step 3: Run Inference")
311+
proposals = self.graph.list_proposals()
312+
n_proposals = max(len(proposals), 1)
313+
314+
# Main
311315
t0 = time()
312-
n_proposals = max(self.graph.n_proposals(), 1)
313-
self.graph, accepts = self.inference_engine.run(
314-
self.graph, self.graph.list_proposals()
315-
)
316+
self.graph, accepts = self.inference_engine.run(self.graph, proposals)
316317
self.accepted_proposals.extend(accepts)
317-
self.report(f"# Accepted: {util.reformat_number(len(accepts))}")
318-
self.report(f"% Accepted: {round(len(accepts) / n_proposals, 4)}")
319318

319+
# Report results
320320
t, unit = util.time_writer(time() - t0)
321+
n_accepts = len(self.accepted_proposals)
322+
self.report(f"# Accepted: {util.reformat_number(n_accepts)}")
323+
self.report(f"% Accepted: {round(n_accepts / n_proposals, 4)}")
321324
self.report(f"Module Runtime: {round(t, 4)} {unit}\n")
322325

323326
def save_results(self, round_id=None):
@@ -334,15 +337,15 @@ def save_results(self, round_id=None):
334337
None
335338
336339
"""
337-
# Save result locally
340+
# Save result on local machine
338341
suffix = f"-{round_id}" if round_id else ""
339342
filename = f"corrected-processed-swcs{suffix}.zip"
340343
path = os.path.join(self.output_dir, filename)
341344
self.graph.to_zipped_swcs(path)
342345
self.save_connections(round_id=round_id)
343346
self.write_metadata()
344347

345-
# Save result on s3
348+
# Save result on s3 (if applicable)
346349
filename = f"corrected-processed-swcs-s3.zip"
347350
path = os.path.join(self.output_dir, filename)
348351
self.graph.to_zipped_swcs(path, min_size=50)
@@ -373,7 +376,8 @@ def save_to_s3(self):
373376
# --- io ---
374377
def save_connections(self, round_id=None):
375378
"""
376-
Saves predicted connections between connected components in a txt file.
379+
Writes the accepted proposals from the graph to a text file. Each line
380+
contains the two swc ids as comma separated values.
377381
378382
Parameters
379383
----------
@@ -414,7 +418,7 @@ def write_metadata(self):
414418
"long_range_bool": self.graph_config.long_range_bool,
415419
"proposals_per_leaf": self.graph_config.proposals_per_leaf,
416420
"search_radius": f"{self.graph_config.search_radius}um",
417-
"confidence_threshold": self.ml_config.threshold,
421+
"accept_threshold": self.ml_config.threshold,
418422
"node_spacing": self.graph_config.node_spacing,
419423
"remove_doubles": self.graph_config.remove_doubles_bool,
420424
}
@@ -475,9 +479,9 @@ def __init__(
475479
model_path,
476480
model_type,
477481
radius,
482+
accept_threshold=0.7,
478483
anisotropy=[1.0, 1.0, 1.0],
479-
batch_size=BATCH_SIZE,
480-
confidence_threshold=CONFIDENCE_THRESHOLD,
484+
batch_size=2000,
481485
device=None,
482486
multiscale=1,
483487
labels_path=None,
@@ -490,22 +494,27 @@ def __init__(
490494
Parameters
491495
----------
492496
img_path : str
493-
Path to image stored in a GCS bucket.
497+
Path to image.
494498
model_path : str
495-
Path to machine learning model parameters.
499+
Path to machine learning model weights.
496500
model_type : str
497501
Type of machine learning model used to perform inference.
498502
radius : float
499503
Search radius used to generate proposals.
504+
accept_threshold : float, optional
505+
Threshold for accepting proposals, where proposals with predicted
506+
likelihood above this threshold are accepted. The default is 0.7.
507+
anisotropy : List[float], optional
508+
...
500509
batch_size : int, optional
501-
Number of proposals to generate features and classify per batch.
502-
The default is the global varaible "BATCH_SIZE".
503-
confidence_threshold : float, optional
504-
Threshold on acceptance probability for proposals. The default is
505-
the global variable "CONFIDENCE_THRESHOLD".
510+
Number of proposals to classify in each batch. The default is 2000.
506511
multiscale : int, optional
507512
Level in the image pyramid that voxel coordinates must index into.
508513
The default is 1.
514+
labels_path : str or None, optional
515+
...
516+
is_multimodal : bool, optional
517+
...
509518
510519
Returns
511520
-------
@@ -517,7 +526,7 @@ def __init__(
517526
self.device = "cpu" if device is None else device
518527
self.is_gnn = True if "Graph" in model_type else False
519528
self.radius = radius
520-
self.threshold = confidence_threshold
529+
self.threshold = accept_threshold
521530

522531
# Features
523532
self.feature_generator = FeatureGenerator(

0 commit comments

Comments (0)