Skip to content

Commit 88070df

Browse files
anna-grim authored
refactor: optimized split corr runtime (#657)
* refactor: optimized split corr runtime
* remove debug line

Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
1 parent da0a3e6 commit 88070df

File tree

5 files changed: +41 additions, −21 deletions

src/neuron_proofreader/merge_proofreading/merge_inference.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -103,7 +103,7 @@ def predict(self, x_nodes):
103103
numpy.ndarray
104104
Predicted merge site likelihoods.
105105
"""
106-
with torch.no_grad():
106+
with torch.inference_mode():
107107
x_nodes = x_nodes.to(self.device)
108108
y_nodes = sigmoid(self.model(x_nodes))
109109
return np.squeeze(ml_util.to_cpu(y_nodes, to_numpy=True), axis=1)

src/neuron_proofreader/proposal_graph.py

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -645,12 +645,11 @@ def edge_attr(self, i, key="xyz", ignore=False):
645645
return attrs
646646

647647
def edge_length(self, edge):
648-
length = 0
649-
for i in range(1, len(self.edges[edge]["xyz"])):
650-
length += geometry.dist(
651-
self.edges[edge]["xyz"][i], self.edges[edge]["xyz"][i - 1]
652-
)
653-
return length
648+
xyz = self.edges[edge]["xyz"]
649+
if len(xyz) < 2:
650+
return 0.0
651+
else:
652+
return np.linalg.norm(xyz[1:] - xyz[:-1], axis=1).sum()
654653

655654
def find_fragments_near_xyz(self, query_xyz, max_dist):
656655
hits = dict()

src/neuron_proofreader/split_proofreading/split_feature_extraction.py

Lines changed: 30 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,8 @@
1010
"""
1111

1212
from concurrent.futures import ThreadPoolExecutor, as_completed
13+
from skimage.transform import resize
14+
from time import time
1315
from torch_geometric.data import HeteroData
1416

1517
import numpy as np
@@ -264,6 +266,7 @@ def __call__(self, subgraph, features):
264266
for thread in as_completed(pending.keys()):
265267
proposal = pending.pop(thread)
266268
extractor = thread.result()
269+
267270
profiles[proposal] = extractor.get_intensity_profile()
268271
patches[proposal] = extractor.get_input_patch()
269272

@@ -413,7 +416,7 @@ def get_input_patch(self):
413416
raw image data and channel 1 contains segmentation data.
414417
"""
415418
img = img_util.resize(self.img, self.patch_shape)
416-
mask = img_util.resize(self.mask, self.patch_shape, True)
419+
mask = resize_segmentation(self.mask, self.patch_shape)
417420
return np.stack([img, mask], axis=0)
418421

419422
def get_intensity_profile(self):
@@ -954,3 +957,29 @@ def get_feature_dict():
954957
proposals.
955958
"""
956959
return {"branch": 2, "proposal": 70}
960+
961+
962+
def resize_segmentation(mask, new_shape):
963+
"""
964+
Resizes a segmentation mask to the given new shape.
965+
966+
Parameters
967+
----------
968+
mask : numpy.ndarray
969+
Segmentation mask to be resized.
970+
new_shape : Tuple[int]
971+
New shape of segmentation mask.
972+
973+
Returns
974+
-------
975+
mask : numpy.ndarray
976+
Resized segmentation mask.
977+
"""
978+
mask = resize(
979+
mask,
980+
new_shape,
981+
order=0,
982+
preserve_range=True,
983+
anti_aliasing=False,
984+
).astype(mask.dtype)
985+
return mask

src/neuron_proofreader/split_proofreading/split_inference.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,6 @@
2929
"""
3030

3131
from time import time
32-
from torch.nn.functional import sigmoid
3332
from tqdm import tqdm
3433

3534
import networkx as nx
@@ -327,10 +326,11 @@ def predict(self, data):
327326
Dictionary that maps proposal IDs to model predictions.
328327
"""
329328
# Generate predictions
330-
with torch.no_grad():
329+
with torch.inference_mode():
331330
device = self.config.ml.device
332331
x = data.get_inputs().to(device)
333-
hat_y = sigmoid(self.model(x))
332+
with torch.cuda.amp.autocast(enabled=True):
333+
hat_y = torch.sigmoid(self.model(x))
334334

335335
# Reformat predictions
336336
idx_to_id = data.idxs_proposals.idx_to_id

src/neuron_proofreader/utils/img_util.py

Lines changed: 2 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -730,7 +730,7 @@ def remove_small_segments(segmentation, min_size):
730730
return segmentation
731731

732732

733-
def resize(img, new_shape, is_segmentation=False):
733+
def resize(img, new_shape):
734734
"""
735735
Resize a 3D image to the specified new shape using linear interpolation.
736736
@@ -740,22 +740,14 @@ def resize(img, new_shape, is_segmentation=False):
740740
Input 3D image array with shape (depth, height, width).
741741
new_shape : Tuple[int]
742742
Desired output shape as (new_depth, new_height, new_width).
743-
is_segmentation : bool, optional
744-
Indication of whether the image represents a segmentation mask.
745743
746744
Returns
747745
-------
748746
numpy.ndarray
749747
Resized 3D image with shape equal to "new_shape".
750748
"""
751-
# Set parameters
752-
order = 0 if is_segmentation else 3
753-
multiplier = 4 if is_segmentation else 1
754749
zoom_factors = np.array(new_shape) / np.array(img.shape)
755-
756-
# Resize image
757-
img = zoom(multiplier * img, zoom_factors, order=order)
758-
return img / multiplier
750+
return zoom(img, zoom_factors, order=1, prefilter=False)
759751

760752

761753
def to_physical(voxel, anisotropy, offset=(0, 0, 0)):

0 commit comments

Comments (0)