Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/deep_neurographs/fragment_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
"""

from collections import defaultdict
from tqdm import tqdm

import networkx as nx
import numpy as np
from tqdm import tqdm

from deep_neurographs import geometry
from deep_neurographs.utils import util

QUERY_DIST = 15

Expand Down Expand Up @@ -46,15 +47,14 @@ def remove_curvy(fragments_graph, max_length, ratio=0.5):

"""
deleted_ids = set()
components = get_line_components(fragments_graph)
for nodes in tqdm(components, desc="Filter Curvy Fragments"):
for nodes in get_line_components(fragments_graph):
i, j = tuple(nodes)
length = fragments_graph.edges[i, j]["length"]
endpoint_dist = fragments_graph.dist(i, j)
if endpoint_dist / length < ratio and length < max_length:
deleted_ids.add(fragments_graph.edges[i, j]["swc_id"])
delete_fragment(fragments_graph, i, j)
return len(deleted_ids)
return util.reformat_number(len(deleted_ids))


# --- Doubles Removal ---
Expand Down Expand Up @@ -96,7 +96,7 @@ def remove_doubles(fragments_graph, max_length, node_spacing):
if check_doubles_criteria(hits, n_points):
delete_fragment(fragments_graph, i, j)
deleted_ids.add(swc_id)
return len(deleted_ids)
return util.reformat_number(len(deleted_ids))


def compute_projections(fragments_graph, kdtree, edge):
Expand Down
85 changes: 47 additions & 38 deletions src/deep_neurographs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@
from deep_neurographs.utils.gnn_util import toCPU
from deep_neurographs.utils.graph_util import GraphLoader

BATCH_SIZE = 2000
CONFIDENCE_THRESHOLD = 0.7


class InferencePipeline:
"""
Expand Down Expand Up @@ -132,9 +129,9 @@ def __init__(
self.model_path,
self.ml_config.model_type,
self.graph_config.search_radius,
accept_threshold=self.ml_config.threshold,
anisotropy=self.ml_config.anisotropy,
batch_size=self.ml_config.batch_size,
confidence_threshold=self.ml_config.threshold,
device=device,
multiscale=self.ml_config.multiscale,
labels_path=labels_path,
Expand Down Expand Up @@ -178,21 +175,27 @@ def run(self, fragments_pointer):
# Finish
self.report("Final Graph...")
self.report_graph()

t, unit = util.time_writer(time() - t0)
self.report(f"Total Runtime: {round(t, 4)} {unit}\n")

def run_schedule(self, fragments_pointer, radius_schedule):
t0 = time()
# Initializations
self.log_experiment()
self.write_metadata()
t0 = time()

# Main
self.build_graph(fragments_pointer)
for round_id, radius in enumerate(radius_schedule):
self.report(f"--- Round {round_id + 1}: Radius = {radius} ---")
round_id += 1
self.report(f"--- Round {round_id}: Radius = {radius} ---")
self.generate_proposals(radius)
self.run_inference()
self.save_results(round_id=round_id)

# Finish
self.report("Final Graph...")
self.report_graph()
t, unit = util.time_writer(time() - t0)
self.report(f"Total Runtime: {round(t, 4)} {unit}\n")

Expand All @@ -212,7 +215,7 @@ def build_graph(self, fragments_pointer):
None

"""
self.report("(1) Building FragmentGraph")
self.report("Step 1: Building FragmentGraph")
t0 = time()

# Initialize Graph
Expand All @@ -233,31 +236,27 @@ def build_graph(self, fragments_pointer):
self.graph.save_labels(labels_path)
self.report(f"# SWCs Saved: {n_saved}")

# Report runtime
# Report results
t, unit = util.time_writer(time() - t0)
self.report(f"Module Runtime: {round(t, 4)} {unit}")

# Report graph overview
self.report("\nInitial Graph...")
self.report_graph()

def filter_fragments(self):
# Filter curvy fragments
# Curvy fragments
n_curvy = fragment_filtering.remove_curvy(self.graph, 200)
n_curvy = util.reformat_number(n_curvy)

# Filter doubles
# Double fragments
if self.graph_config.remove_doubles_bool:
n_doubles = fragment_filtering.remove_doubles(
self.graph, 200, self.graph_config.node_spacing
)
n_doubles = util.reformat_number(n_doubles)
self.report(f"# Double Fragments Deleted: {n_doubles}")
self.report(f"# Curvy Fragments Deleted: {n_curvy}")

def generate_proposals(self, radius=None):
"""
Generates proposals for the fragment graph based on the specified
Generates proposals for the fragments graph based on the specified
configuration.

Parameters
Expand All @@ -270,7 +269,7 @@ def generate_proposals(self, radius=None):

"""
# Initializations
self.report("(2) Generate Proposals")
self.report("Step 2: Generate Proposals")
if radius is None:
radius = self.graph_config.search_radius

Expand Down Expand Up @@ -307,17 +306,21 @@ def run_inference(self):
None

"""
self.report("(3) Run Inference")
# Initializations
self.report("Step 3: Run Inference")
proposals = self.graph.list_proposals()
n_proposals = max(len(proposals), 1)

# Main
t0 = time()
n_proposals = max(self.graph.n_proposals(), 1)
self.graph, accepts = self.inference_engine.run(
self.graph, self.graph.list_proposals()
)
self.graph, accepts = self.inference_engine.run(self.graph, proposals)
self.accepted_proposals.extend(accepts)
self.report(f"# Accepted: {util.reformat_number(len(accepts))}")
self.report(f"% Accepted: {round(len(accepts) / n_proposals, 4)}")

# Report results
t, unit = util.time_writer(time() - t0)
n_accepts = len(self.accepted_proposals)
self.report(f"# Accepted: {util.reformat_number(n_accepts)}")
self.report(f"% Accepted: {round(n_accepts / n_proposals, 4)}")
self.report(f"Module Runtime: {round(t, 4)} {unit}\n")

def save_results(self, round_id=None):
Expand All @@ -334,15 +337,15 @@ def save_results(self, round_id=None):
None

"""
# Save result locally
# Save result on local machine
suffix = f"-{round_id}" if round_id else ""
filename = f"corrected-processed-swcs{suffix}.zip"
path = os.path.join(self.output_dir, filename)
self.graph.to_zipped_swcs(path)
self.save_connections(round_id=round_id)
self.write_metadata()

# Save result on s3
# Save result on s3 (if applicable)
filename = f"corrected-processed-swcs-s3.zip"
path = os.path.join(self.output_dir, filename)
self.graph.to_zipped_swcs(path, min_size=50)
Expand Down Expand Up @@ -373,7 +376,8 @@ def save_to_s3(self):
# --- io ---
def save_connections(self, round_id=None):
"""
Saves predicted connections between connected components in a txt file.
Writes the accepted proposals from the graph to a text file. Each line
contains the two swc ids as comma separated values.

Parameters
----------
Expand Down Expand Up @@ -414,7 +418,7 @@ def write_metadata(self):
"long_range_bool": self.graph_config.long_range_bool,
"proposals_per_leaf": self.graph_config.proposals_per_leaf,
"search_radius": f"{self.graph_config.search_radius}um",
"confidence_threshold": self.ml_config.threshold,
"accept_threshold": self.ml_config.threshold,
"node_spacing": self.graph_config.node_spacing,
"remove_doubles": self.graph_config.remove_doubles_bool,
}
Expand Down Expand Up @@ -475,9 +479,9 @@ def __init__(
model_path,
model_type,
radius,
accept_threshold=0.7,
anisotropy=[1.0, 1.0, 1.0],
batch_size=BATCH_SIZE,
confidence_threshold=CONFIDENCE_THRESHOLD,
batch_size=2000,
device=None,
multiscale=1,
labels_path=None,
Expand All @@ -490,22 +494,27 @@ def __init__(
Parameters
----------
img_path : str
Path to image stored in a GCS bucket.
Path to image.
model_path : str
Path to machine learning model parameters.
Path to machine learning model weights.
model_type : str
Type of machine learning model used to perform inference.
radius : float
Search radius used to generate proposals.
accept_threshold : float, optional
Threshold for accepting proposals, where proposals with predicted
likelihood above this threshold are accepted. The default is 0.7.
anisotropy : List[float], optional
...
batch_size : int, optional
Number of proposals to generate features and classify per batch.
The default is the global varaible "BATCH_SIZE".
confidence_threshold : float, optional
Threshold on acceptance probability for proposals. The default is
the global variable "CONFIDENCE_THRESHOLD".
Number of proposals to classify in each batch.The default is 2000.
multiscale : int, optional
Level in the image pyramid that voxel coordinates must index into.
The default is 1.
labels_path : str or None, optional
...
is_multimodal : bool, optional
...

Returns
-------
Expand All @@ -517,7 +526,7 @@ def __init__(
self.device = "cpu" if device is None else device
self.is_gnn = True if "Graph" in model_type else False
self.radius = radius
self.threshold = confidence_threshold
self.threshold = accept_threshold

# Features
self.feature_generator = FeatureGenerator(
Expand Down
Loading
Loading