Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions spine/io/parse/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

import numpy as np

from spine.math.cluster import DBSCAN
from spine.math.distance import METRICS
from spine.math.cluster import dbscan

from spine.data import Meta

Expand Down Expand Up @@ -190,8 +191,8 @@ def __init__(self, dtype, particle_event=None, add_particle_info=False,

# Initialize DBSCAN if the clusters are to be broken up
self.break_clusters = break_clusters
if break_clusters:
self.dbscan = DBSCAN(break_eps, min_samples=1, metric=break_metric)
self.break_eps = break_eps
self.break_metric_id = METRICS[break_metric]

# Initialize the sparse and particle parsers
self.sparse_parser = Sparse3DParser(dtype, sparse_event='dummy')
Expand Down Expand Up @@ -337,7 +338,9 @@ def process(self, cluster_event, particle_event=None,

# If requested, break cluster into detached pieces
if self.break_clusters:
frag_labels = self.dbscan.fit_predict(voxels)
frag_labels = dbscan(
voxels, eps=self.break_eps, min_samples=1,
metric_id=self.break_metric_id)
features[1] = id_offset + frag_labels
id_offset += max(frag_labels) + 1

Expand Down
5 changes: 4 additions & 1 deletion spine/post/reco/direction.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,7 @@ def process(self, data):

# Assign directions to the appropriate particles
for i, part_id in enumerate(part_ids):
setattr(data[k][part_id], attrs[i], dirs[i])
if attrs[i].startswith('start'):
setattr(data[k][part_id], attrs[i], dirs[i])
else:
setattr(data[k][part_id], attrs[i], -dirs[i])
2 changes: 1 addition & 1 deletion spine/utils/gnn/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ def _get_cluster_closest_primary_label(data: nb.float64[:,:],
# Minimize the point-cluster distances
dists = np.empty(len(group_index), dtype=data.dtype)
for i, c in enumerate(group_index):
dists[i] = np.min(sm.cdist(start_point, voxels[clusts[c]]))
dists[i] = np.min(sm.distance.cdist(start_point, voxels[clusts[c]]))

# Label the closest cluster as the only primary cluster
labels[group_index] = 0
Expand Down
Loading