diff --git a/README.md b/README.md index 9708a4cb8..cadc36f91 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ [![Build Status](https://app.travis-ci.com/francois-drielsma/lartpc_mlreco3d.svg?token=WB4oxAv87vEXhuxUGH7e&branch=develop&status=passed)](https://app.travis-ci.com/github/francois-drielsma/lartpc_mlreco3d/logscans?serverType=git) [![Documentation Status](https://readthedocs.org/projects/lartpc-mlreco3d/badge/?version=latest)](https://lartpc-mlreco3d.readthedocs.io/en/latest/?badge=latest) -The Scalable Particle Imaging with Neural Embeddings (SPINE) package leverages state-of-the-art Machine Learning (ML) algorithms -- in particular Deep Neural Networks (DNNs) -- to reconstruct particle imagaging detector data. This package was primarily developed for Liquid Argon Time-Projection Chamber (LArTPC) data and relies on Convolutional Neural Networks (CNNs) for pixel-level feature extraction and Graph Neural Networks (GNNs) for superstructure formation. The schematic below breaks down the full end-to-end reconstruction flow. +The Scalable Particle Imaging with Neural Embeddings (SPINE) package leverages state-of-the-art Machine Learning (ML) algorithms -- in particular Deep Neural Networks (DNNs) -- to reconstruct particle imaging detector data. This package was primarily developed for Liquid Argon Time-Projection Chamber (LArTPC) data and relies on Convolutional Neural Networks (CNNs) for pixel-level feature extraction and Graph Neural Networks (GNNs) for superstructure formation. The schematic below breaks down the full end-to-end reconstruction flow. 
![Full chain](https://github.com/DeepLearnPhysics/spine/blob/develop/docs/source/_static/img/spine-chain-alpha.png) diff --git a/spine/ana/diag/track.py b/spine/ana/diag/track.py index 0bc721be0..8addae30f 100644 --- a/spine/ana/diag/track.py +++ b/spine/ana/diag/track.py @@ -1,13 +1,12 @@ """Module to evaluate diagnostic metrics on tracks.""" import numpy as np -from scipy.spatial.distance import cdist -from spine.ana.base import AnaBase +from spine.math.distance import cdist from spine.utils.globals import TRACK_SHP -from spine.utils.numba_local import principal_components +from spine.ana.base import AnaBase __all__ = ['TrackCompletenessAna'] @@ -142,10 +141,11 @@ def cluster_track_chunks(points, start_point, end_point, pixel_size): """ # Project and cluster on the projected axis direction = (end_point-start_point)/np.linalg.norm(end_point-start_point) + scale = pixel_size*np.max(direction) projs = np.dot(points - start_point, direction) perm = np.argsort(projs) seps = projs[perm][1:] - projs[perm][:-1] - breaks = np.where(seps > pixel_size*1.1)[0] + 1 + breaks = np.where(seps > scale*1.1)[0] + 1 cluster_labels = np.empty(len(projs), dtype=int) for i, index in enumerate(np.split(np.arange(len(projs)), breaks)): cluster_labels[perm[index]] = i diff --git a/spine/ana/metric/segment.py b/spine/ana/metric/segment.py index 218d797e5..ac6d20476 100644 --- a/spine/ana/metric/segment.py +++ b/spine/ana/metric/segment.py @@ -111,7 +111,7 @@ def process(self, data): if self.ghost: # If there are ghost, must combine the predictions full_seg_pred = np.full_like(seg_label, GHOST_SHP, dtype=np.int32) - deghost_mask = data['ghost'][:, 0] > data['ghost'][:, 1] + deghost_mask = np.argmax(data['ghost'], axis=1) == 0 full_seg_pred[deghost_mask] = seg_pred seg_pred = full_seg_pred diff --git a/spine/build/base.py b/spine/build/base.py index 16f2cbfe3..5015f88a3 100644 --- a/spine/build/base.py +++ b/spine/build/base.py @@ -38,13 +38,13 @@ class BuilderBase(ABC): # Necessary/optional 
data products to load a reconstructed object _load_reco_keys = ( - ('points', True), ('depositions', True), ('sources', False) + ('points', False), ('depositions', False), ('sources', False) ) # Necessary/optional data products to load a truth object _load_truth_keys = ( - ('points_label', True), ('points', False), ('points_g4', False), - ('depositions_label', True), ('depositions', False), + ('points_label', False), ('points', False), ('points_g4', False), + ('depositions_label', False), ('depositions', False), ('depositions_q_label', False), ('depositions_g4', False), ('sources_label', False), ('sources', False) ) diff --git a/spine/build/fragment.py b/spine/build/fragment.py index 93b2c1adc..2dca59747 100644 --- a/spine/build/fragment.py +++ b/spine/build/fragment.py @@ -152,7 +152,7 @@ def build_truth(self, data): """ return self._build_truth(**data) - def _build_truth(self, label_tensor, points_label, depositions_label, + def _build_truth(self, label_tensor, points_label, depositions_label, depositions_q_label=None, label_adapt_tensor=None, points=None, depositions=None, label_g4_tensor=None, points_g4=None, depositions_g4=None, sources_label=None, @@ -287,16 +287,17 @@ def load_reco(self, data): """ return self._load_reco(**data) - def _load_reco(self, reco_fragments, points, depositions, sources=None): + def _load_reco(self, reco_fragments, points=None, depositions=None, + sources=None): """Load :class:`RecoFragment` objects from their stored versions. 
Parameters ---------- reco_fragments : List[RecoFragment] (F) List of partial reconstructed fragments - points : np.ndarray + points : np.ndarray, optional (N, 3) Set of deposition coordinates in the image - depositions : np.ndarray + depositions : np.ndarray, optional (N) Set of deposition values sources : np.ndarray, optional (N, 2) Tensor which contains the module/tpc information @@ -313,10 +314,11 @@ def _load_reco(self, reco_fragments, points, depositions, sources=None): "The ordering of the stored fragments is wrong.") # Update the fragment with its long-form attributes - fragment.points = points[fragment.index] - fragment.depositions = depositions[fragment.index] - if sources is not None: - fragment.sources = sources[fragment.index] + if points is not None: + fragment.points = points[fragment.index] + fragment.depositions = depositions[fragment.index] + if sources is not None: + fragment.sources = sources[fragment.index] return reco_fragments @@ -335,20 +337,20 @@ def load_truth(self, data): """ return self._load_truth(**data) - def _load_truth(self, truth_fragments, points_label, depositions_label, - depositions_q_label=None, points=None, depositions=None, - points_g4=None, depositions_g4=None, sources_label=None, - sources=None): + def _load_truth(self, truth_fragments, points_label=None, + depositions_label=None, depositions_q_label=None, + points=None, depositions=None, points_g4=None, + depositions_g4=None, sources_label=None, sources=None): """Load :class:`TruthFragment` objects from their stored versions. 
Parameters ---------- truth_fragments : List[TruthFragment] (F) List of partial truth fragments - points_label : np.ndarray + points_label : np.ndarray, optional (N', 3) Set of deposition coordinates in the label image (identical for pixel TPCs, different if deghosting is involved) - depositions_label : np.ndarray + depositions_label : np.ndarray, optional (N') Set of true deposition values in MeV depositions_q_label : np.ndarray, optional (N') Set of true deposition values in ADC, if relevant @@ -377,12 +379,13 @@ def _load_truth(self, truth_fragments, points_label, depositions_label, "The ordering of the stored fragments is wrong.") # Update the fragment with its long-form attributes - fragment.points = points_label[fragment.index] - fragment.depositions = depositions_label[fragment.index] - if depositions_q_label is not None: - fragment.depositions_q = depositions_q_label[fragment.index] - if sources_label is not None: - fragment.sources = sources_label[fragment.index] + if points_label is not None: + fragment.points = points_label[fragment.index] + fragment.depositions = depositions_label[fragment.index] + if depositions_q_label is not None: + fragment.depositions_q = depositions_q_label[fragment.index] + if sources_label is not None: + fragment.sources = sources_label[fragment.index] if points is not None: fragment.points_adapt = points[fragment.index_adapt] diff --git a/spine/build/manager.py b/spine/build/manager.py index 55205f719..21c89e158 100644 --- a/spine/build/manager.py +++ b/spine/build/manager.py @@ -44,7 +44,7 @@ class BuildManager: ) def __init__(self, fragments, particles, interactions, - mode='both', units='cm', sources=None): + mode='both', units='cm', sources=None, lite=False): """Initializes the build manager. 
Parameters @@ -60,6 +60,9 @@ def __init__(self, fragments, particles, interactions, sources : Dict[str, str], optional Dictionary which maps the necessary data products onto a name in the input/output dictionary of the reconstruction chain. + lite : bool, default False + If `True`, the objects being loaded are lite and do not map + to long-form attributes. Simply load the matches. """ # Check on the mode, store it assert mode in self._run_modes, ( @@ -100,6 +103,9 @@ def __init__(self, fragments, particles, interactions, assert len(self.builders), ( "Do not call the builder unless it does anything.") + # Store whether to load the long-form attributes or not + self.lite = lite + def __call__(self, data): """Build the representations for one entry. @@ -115,7 +121,7 @@ def __call__(self, data): # If this is the first time the builders are called, build # the objects shared between fragments/particles/interactions load = True - if 'points' not in data and 'points_label' not in data: + if not self.lite and 'points' not in data and 'points_label' not in data: load = False if np.isscalar(data['index']): sources = self.build_sources(data) diff --git a/spine/build/particle.py b/spine/build/particle.py index ca7bb9464..0db66825f 100644 --- a/spine/build/particle.py +++ b/spine/build/particle.py @@ -311,16 +311,17 @@ def load_reco(self, data): """ return self._load_reco(**data) - def _load_reco(self, reco_particles, points, depositions, sources=None): + def _load_reco(self, reco_particles, points=None, depositions=None, + sources=None): """Construct :class:`RecoParticle` objects from their stored versions. 
Parameters ---------- reco_particles : List[RecoParticle] (P) List of partial reconstructed particles - points : np.ndarray + points : np.ndarray, optional (N, 3) Set of deposition coordinates in the image - depositions : np.ndarray + depositions : np.ndarray, optional (N) Set of deposition values sources : np.ndarray, optional (N, 2) Tensor which contains the module/tpc information @@ -337,10 +338,11 @@ def _load_reco(self, reco_particles, points, depositions, sources=None): "The ordering of the stored particles is wrong.") # Update the particle with its long-form attributes - particle.points = points[particle.index] - particle.depositions = depositions[particle.index] - if sources is not None: - particle.sources = sources[particle.index] + if points is not None: + particle.points = points[particle.index] + particle.depositions = depositions[particle.index] + if sources is not None: + particle.sources = sources[particle.index] return reco_particles @@ -354,20 +356,20 @@ def load_truth(self, data): """ return self._load_truth(**data) - def _load_truth(self, truth_particles, points_label, depositions_label, - depositions_q_label=None, points=None, depositions=None, - points_g4=None, depositions_g4=None, sources_label=None, - sources=None): + def _load_truth(self, truth_particles, points_label=None, + depositions_label=None, depositions_q_label=None, + points=None, depositions=None, points_g4=None, + depositions_g4=None, sources_label=None, sources=None): """Construct :class:`TruthParticle` objects from their stored versions. 
Parameters ---------- truth_particles : List[TruthParticle] (P) List of partial truth particles - points_label : np.ndarray + points_label : np.ndarray, optional (N', 3) Set of deposition coordinates in the label image (identical for pixel TPCs, different if deghosting is involved) - depositions_label : np.ndarray + depositions_label : np.ndarray, optional (N') Set of true deposition values in MeV depositions_q_label : np.ndarray, optional (N') Set of true deposition values in ADC, if relevant @@ -396,12 +398,13 @@ def _load_truth(self, truth_particles, points_label, depositions_label, "The ordering of the stored particles is wrong.") # Update the particle with its long-form attributes - particle.points = points_label[particle.index] - particle.depositions = depositions_label[particle.index] - if depositions_q_label is not None: - particle.depositions_q = depositions_q_label[particle.index] - if sources_label is not None: - particle.sources = sources_label[particle.index] + if points_label is not None: + particle.points = points_label[particle.index] + particle.depositions = depositions_label[particle.index] + if depositions_q_label is not None: + particle.depositions_q = depositions_q_label[particle.index] + if sources_label is not None: + particle.sources = sources_label[particle.index] if points is not None: particle.points_adapt = points[particle.index_adapt] diff --git a/spine/driver.py b/spine/driver.py index 2ff474b80..cc420bf01 100644 --- a/spine/driver.py +++ b/spine/driver.py @@ -23,9 +23,10 @@ from .io import loader_factory, reader_factory, writer_factory from .io.write import CSVWriter +from .math import seed as numba_seed + from .utils.logger import logger from .utils.cuda import set_visible_devices -from .utils.numba_local import seed as numba_seed from .utils.unwrap import Unwrapper from .utils.stopwatch import StopwatchManager @@ -367,7 +368,7 @@ def initialize_io(self, loader=None, reader=None, writer=None): # Fetch the list of previously run 
post-processors # TODO: this only works with two runs in a row, not 3 and above self.post_list = None - if self.reader.cfg is not None: + if self.reader.cfg is not None and 'post' in self.reader.cfg: self.post_list = tuple(self.reader.cfg['post']) # Fetch an appropriate common prefix for all input files diff --git a/spine/io/parse/cluster.py b/spine/io/parse/cluster.py index 047d5ee44..e699b16be 100644 --- a/spine/io/parse/cluster.py +++ b/spine/io/parse/cluster.py @@ -12,13 +12,14 @@ import numpy as np +from spine.math.cluster import DBSCAN + from spine.data import Meta from spine.utils.globals import DELTA_SHP, SHAPE_PREC from spine.utils.particles import process_particle_event from spine.utils.ppn import image_coordinates from spine.utils.conditional import larcv -from spine.utils.numba_local import dbscan from .base import ParserBase from .sparse import ( @@ -185,11 +186,13 @@ def __init__(self, dtype, particle_event=None, add_particle_info=False, self.type_include_mpr = type_include_mpr self.type_include_secondary = type_include_secondary self.primary_include_mpr = primary_include_mpr - self.break_clusters = break_clusters - self.break_eps = break_eps - self.break_metric = break_metric self.shape_precedence = shape_precedence + # Initialize DBSCAN if the clusters are to be broken up + self.break_clusters = break_clusters + if break_clusters: + self.dbscan = DBSCAN(break_eps, min_samples=1, metric=break_metric) + # Initialize the sparse and particle parsers self.sparse_parser = Sparse3DParser(dtype, sparse_event='dummy') @@ -334,8 +337,7 @@ def process(self, cluster_event, particle_event=None, # If requested, break cluster into detached pieces if self.break_clusters: - frag_labels = dbscan( - voxels, self.break_eps, self.break_metric) + frag_labels = self.dbscan.fit_predict(voxels) features[1] = id_offset + frag_labels id_offset += max(frag_labels) + 1 diff --git a/spine/io/parse/sparse.py b/spine/io/parse/sparse.py index 30f67329f..8adbe65b9 100644 --- 
a/spine/io/parse/sparse.py +++ b/spine/io/parse/sparse.py @@ -150,7 +150,7 @@ class Sparse3DParser(ParserBase): def __init__(self, dtype, sparse_event=None, sparse_event_list=None, num_features=None, hit_keys=None, nhits_idx=None, - feature_only=False): + feature_only=False, lexsort=False): """Initialize the parser. Parameters @@ -175,6 +175,9 @@ def __init__(self, dtype, sparse_event=None, sparse_event_list=None, (doublet vs triplet) should be inserted. feature_only : bool, default False If `True`, only return the feature vector without the coordinates + lexsort : bool, default False + When merging points from multiple sources (num_features is not + `None`), this allows to lexicographically sort coordinates """ # Initialize the parent class super().__init__( @@ -193,6 +196,10 @@ def __init__(self, dtype, sparse_event=None, sparse_event_list=None, raise ValueError("The argument nhits_idx needs to be specified if " "you want to compute the nhits feature.") + self.lexsort = lexsort + if self.num_features is None and lexsort: + raise ValueError + # Get the number of features in the output tensor assert (sparse_event is not None) ^ (sparse_event_list is not None), ( "Must provide either `sparse_event` or `sparse_event_list`.") @@ -265,7 +272,7 @@ def process(self, sparse_event=None, sparse_event_list=None): if num_points is None: num_points = sparse_event.as_vector().size() - if not self.feature_only: + if not self.feature_only or self.lexsort: np_voxels = np.empty((num_points, 3), dtype=self.itype) larcv.fill_3d_voxels(sparse_event, np_voxels) else: @@ -293,15 +300,27 @@ def process(self, sparse_event=None, sparse_event_list=None): np_features.insert(self.nhits_idx, nhits) # Append to the global list of voxel/features - if not self.feature_only: + if not self.feature_only or self.lexsort: all_voxels.append(np_voxels) all_features.append(np.hstack(np_features)) + # Stack coordinates/features + all_features = np.vstack(all_features) + if not self.feature_only or 
self.lexsort: + all_voxels = np.vstack(all_voxels) + + # Lexicographically sort coordinates/features, if requested + if self.lexsort: + perm = np.lexsort(all_voxels.T) + all_features = all_features[perm] + if not self.feature_only: + all_voxels = all_voxels[perm] + + # Return + if self.feature_only: - return np.vstack(all_features) + return all_features else: - return (np.vstack(all_voxels), np.vstack(all_features), - Meta.from_larcv(meta)) + return all_voxels, all_features, Meta.from_larcv(meta) class Sparse3DAggregateParser(Sparse3DParser): diff --git a/spine/math/__init__.py b/spine/math/__init__.py new file mode 100644 index 000000000..781097fe4 --- /dev/null +++ b/spine/math/__init__.py @@ -0,0 +1,20 @@ +"""Module with fast, Numba-accelerated, compiled math routines. + +This includes multiple submodules: +- `base.py` includes basic functions, as found in numpy or scipy.special +- `linalg.py` includes linear algebra routines, as found in numpy.linalg +- `distance.py` includes distance functions, as found in scipy.spatial.distance +- `graph.py` includes graph routines, as found in scipy.sparse.csgraph +- `cluster.py` includes cluster functions, as found in sklearn.cluster +""" + +# Expose all base functions directly +from .base import * + +# Expose submodules +from . import cluster +from . import decomposition +from . import distance +from . import graph +from . import linalg +from . import neighbors diff --git a/spine/math/base.py b/spine/math/base.py new file mode 100644 index 000000000..86e029dd4 --- /dev/null +++ b/spine/math/base.py @@ -0,0 +1,321 @@ +"""Numba JIT compiled implementation of basic functions. + +Most of these functions are implemented here because vanilla numba does not +support optional arguments, such as `axis` for most functions or +`return_counts` for the `unique` function. 
+""" + +import numpy as np +import numba as nb + +__all__ = ['seed', 'unique', 'mean', 'mode', 'argmax', 'argmin', 'amax', 'amin', + 'all', 'softmax', 'log_loss'] + + +@nb.njit(cache=True) +def seed(seed: nb.int64) -> None: + """Sets the numpy random seed for all Numba jitted functions. + + Note that setting the seed using `np.random.seed` outside a Numba jitted + function does *not* set the seed of Numba functions. + + Parameters + ---------- + seed : int + Random number generator seed + """ + np.random.seed(seed) + + +@nb.njit(cache=True) +def unique(x: nb.int64[:]) -> (nb.int64[:], nb.int64[:]): + """Numba implementation of `np.unique(x, return_counts=True)`. + + Parameters + ---------- + x : np.ndarray + (N) array of values + + Returns + ------- + np.ndarray + (U) array of unique values + np.ndarray + (U) array of counts of each unique value in the original array + """ + # Nothing to do if the input is empty + uniques = np.empty(len(x), dtype=x.dtype) + counts = np.empty(len(x), dtype=np.int64) + if len(x) == 0: + return uniques, counts + + # Build the list of unique values and counts + x = np.sort(x.flatten()) + uniques[0] = x[0] + idx = 1 + for i in range(len(x) - 1): + if x[i] != x[i+1]: + uniques[idx] = x[i+1] + counts[idx-1] = i + 1 + idx += 1 + + counts[idx-1] = len(x) + + # Narrow vectors down + uniques = uniques[:idx] + counts = counts[:idx] + + # Adjust counts + counts[1:] = counts[1:] - counts[:-1] + + return uniques, counts + + +@nb.njit(cache=True) +def mean(x: nb.float32[:,:], + axis: nb.int32) -> nb.float32[:]: + """Numba implementation of `np.mean(x, axis)`. 
+ + Parameters + ---------- + x : np.ndarray + (N,M) array of values + axis : int + Array axis ID + + Returns + ------- + np.ndarray + (N) or (M) array of `mean` values + """ + assert axis == 0 or axis == 1 + mean = np.empty(x.shape[1-axis], dtype=x.dtype) + if axis == 0: + for i in range(len(mean)): + mean[i] = np.mean(x[:,i]) + else: + for i in range(len(mean)): + mean[i] = np.mean(x[i]) + + return mean + + +@nb.njit(cache=True) +def mode(x: nb.int64[:]) -> nb.int64: + """Numba implementation of `scipy.stats.mode(x)`. + + Parameters + ---------- + x : np.ndarray + (N) array of values + + Returns + ------- + int + Most-probable value in the array + """ + values, counts = unique(x) + + return values[np.argmax(counts)] + + +@nb.njit(cache=True) +def argmin(x: nb.float32[:,:], + axis: nb.int32) -> nb.int32[:]: + """Numba implementation of `np.argmin(x, axis)`. + + Parameters + ---------- + x : np.ndarray + (N,M) array of values + axis : int + Array axis ID + + Returns + ------- + np.ndarray + (N) or (M) array of `argmin` values + """ + assert axis == 0 or axis == 1 + argmin = np.empty(x.shape[1-axis], dtype=np.int32) + if axis == 0: + for i in range(len(argmin)): + argmin[i] = np.argmin(x[:,i]) + else: + for i in range(len(argmin)): + argmin[i] = np.argmin(x[i]) + + return argmin + + +@nb.njit(cache=True) +def argmax(x: nb.float32[:,:], + axis: nb.int32) -> nb.int32[:]: + """Numba implementation of `np.argmax(x, axis)`. 
+ + Parameters + ---------- + x : np.ndarray + (N,M) array of values + axis : int + Array axis ID + + Returns + ------- + np.ndarray + (N) or (M) array of `argmax` values + """ + assert axis == 0 or axis == 1 + argmax = np.empty(x.shape[1-axis], dtype=np.int32) + if axis == 0: + for i in range(len(argmax)): + argmax[i] = np.argmax(x[:,i]) + + else: + for i in range(len(argmax)): + argmax[i] = np.argmax(x[i]) + + return argmax + + +@nb.njit(cache=True) +def amin(x: nb.float32[:,:], + axis: nb.int32) -> nb.float32[:]: + """Numba implementation of `np.amin(x, axis)`. + + Parameters + ---------- + x : np.ndarray + (N,M) array of values + axis : int + Array axis ID + + Returns + ------- + np.ndarray + (N) or (M) array of `min` values + """ + assert axis == 0 or axis == 1 + xmin = np.empty(x.shape[1-axis], dtype=np.int32) + if axis == 0: + for i in range(len(xmin)): + xmin[i] = np.min(x[:, i]) + + else: + for i in range(len(xmin)): + xmin[i] = np.min(x[i]) + + return xmin + + +@nb.njit(cache=True) +def amax(x: nb.float32[:,:], + axis: nb.int32) -> nb.float32[:]: + """Numba implementation of `np.amax(x, axis)`. + + Parameters + ---------- + x : np.ndarray + (N,M) array of values + axis : int + Array axis ID + + Returns + ------- + np.ndarray + (N) or (M) array of `max` values + """ + assert axis == 0 or axis == 1 + xmax = np.empty(x.shape[1-axis], dtype=np.int32) + if axis == 0: + for i in range(len(xmax)): + xmax[i] = np.max(x[:, i]) + + else: + for i in range(len(xmax)): + xmax[i] = np.max(x[i]) + + return xmax + + +@nb.njit(cache=True) +def all(x: nb.float32[:,:], + axis: nb.int32) -> nb.boolean[:]: + """Numba implementation of `np.all(x, axis)`. 
+ + Parameters + ---------- + x : np.ndarray + (N, M) Array of values + axis : int + Array axis ID + + Returns + ------- + np.ndarray + (N) or (M) array of `all` outputs + """ + assert axis == 0 or axis == 1 + all = np.empty(x.shape[1-axis], dtype=np.bool_) + if axis == 0: + for i in range(len(all)): + all[i] = np.all(x[:,i]) + + else: + for i in range(len(all)): + all[i] = np.all(x[i]) + + return all + + +@nb.njit(cache=True) +def softmax(x: nb.float32[:,:], + axis: nb.int32) -> nb.float32[:,:]: + """ + Numba implementation of `scipy.special.softmax(x, axis)`. + + Parameters + ---------- + x : np.ndarray + (N,M) array of values + axis : int + Array axis ID + + Returns + ------- + np.ndarray + (N,M) Array of softmax scores + """ + assert axis == 0 or axis == 1 + if axis == 0: + xmax = amax(x, axis=0) + logsumexp = np.log(np.sum(np.exp(x-xmax), axis=0)) + xmax + return np.exp(x - logsumexp) + else: + xmax = amax(x, axis=1).reshape(-1,1) + logsumexp = np.log(np.sum(np.exp(x-xmax), axis=1)).reshape(-1,1) + xmax + return np.exp(x - logsumexp) + + +@nb.njit(cache=True) +def log_loss(label: nb.boolean[:], + pred: nb.float32[:]) -> nb.float32: + """Numba implementation of cross-entropy loss. + + Parameters + ---------- + label : np.ndarray + (N) array of boolean labels (0 or 1) + pred : np.ndarray + (N) array of float scores (between 0 and 1) + + Returns + ------- + float + Cross-entropy loss + """ + if len(label) > 0: + return -(np.sum(np.log(pred[label])) + + np.sum(np.log(1.-pred[~label])))/len(label) + else: + return 0. 
diff --git a/spine/math/cluster.py b/spine/math/cluster.py new file mode 100644 index 000000000..fdda68244 --- /dev/null +++ b/spine/math/cluster.py @@ -0,0 +1,129 @@ +"""Numba JIT compiled implementation of clustering routines.""" + +import numba as nb +import numpy as np + +from .distance import METRICS, get_metric_id +from .graph import radius_graph, connected_components + + +DBSCAN_DTYPE = ( + ('eps', nb.float32), + ('min_samples', nb.int64), + ('metric_id', nb.int64), + ('p', nb.float32) +) + + +@nb.experimental.jitclass(DBSCAN_DTYPE) +class DBSCAN: + """Class-version of the Numba-accelerated :func:`dbscan` function. + + Attributes + ---------- + eps : float + Distance scale (determines neighborhood) + min_samples : int + Minimum number of neighbors (including oneself) to be considered + a core point + metric : str + Distance metric to be used to establish neighborhood + """ + + def __init__(self, + eps: nb.float32, + min_samples: nb.int64 = 1, + metric: nb.types.string = 'euclidean', + p: nb.int64 = 2.): + """Initialize the DBSCAN parameters. + + Parameters + ---------- + eps : float + Distance scale (determines neighborhood) + min_samples : int + Minimum number of neighbors (including oneself) to be considered + a core point + metric : str + Distance metric to be used to establish neighborhood + p : float, default 2. + p-norm factor for the Minkowski metric, if used + """ + # For Euclidean, save time by using squared Euclidean + if metric == 'euclidean': + metric = 'sqeuclidean' + eps = eps*eps + + # Store parameters + self.eps = eps + self.min_samples = min_samples + self.metric_id = get_metric_id(metric, p) + self.p = p + + def fit_predict(self, x): + """Runs DBSCAN on 3D points and returns the group assignments. + + + Notes + ----- + The traditional 'min_samples' is always set to 1 here. 
+ + Parameters + ---------- + x : np.ndarray + (N, 3) array of point coordinates + eps : float + Distance below which two points are considered neighbors + min_samples : int + Minimum number of neighbors for a point to be a core point + metric : str, default 'euclidean' + Distance metric used to compute pdist + + Returns + ------- + np.ndarray + (N) Group assignments + """ + # Produce a radius graph + edge_index = radius_graph(x, self.eps, self.metric_id, self.p) + + # Build groups + return connected_components( + edge_index, len(x), self.min_samples, directed=False) + + +@nb.njit(cache=True) +def dbscan(x: nb.float32[:, :], + eps: nb.float32, + min_samples: nb.int64 = 1, + metric_id: nb.int64 = METRICS['euclidean'], + p: nb.float32 = 2.) -> nb.float32[:]: + """Runs DBSCAN on 3D points and returns the group assignments. + + Notes + ----- + The traditional 'min_samples' is always set to 1 here. + + Parameters + ---------- + x : np.ndarray + (N, 3) array of point coordinates + eps : float + Distance below which two points are considered neighbors + min_samples : int + Minimum number of neighbors for a point to be a core point + metric : str, default 'euclidean' + Distance metric used to compute pdist + p : float, default 2. 
+ p-norm factor for the Minkowski metric, if used + + Returns + ------- + np.ndarray + (N) Group assignments + """ + # Produce a radius graph + edge_index = radius_graph(x, eps, metric_id, p) + + # Build groups + return connected_components(edge_index, len(x), min_samples, directed=False) diff --git a/spine/math/decomposition.py b/spine/math/decomposition.py new file mode 100644 index 000000000..e50dd8ca6 --- /dev/null +++ b/spine/math/decomposition.py @@ -0,0 +1,31 @@ +"""Numba JIT compiled implementation of decomposition routines.""" + +import numba as nb +import numpy as np + +__all__ = ['principal_components'] + + +@nb.njit(cache=True) +def principal_components(x: nb.float32[:,:]) -> nb.float32[:,:]: + """Computes the principal components of a point cloud by computing the + eigenvectors of the centered covariance matrix. + + Parameters + ---------- + x : np.ndarray + (N, d) Coordinates in d dimensions + + Returns + ------- + np.ndarray + (d, d) List of principal components (row-ordered) + """ + # Get covariance matrix + A = np.cov(x.T, ddof = len(x) - 1).astype(x.dtype) # Casting needed... + + # Get eigenvectors + _, v = np.linalg.eigh(A) + v = np.ascontiguousarray(np.fliplr(v).T) + + return v diff --git a/spine/math/distance.py b/spine/math/distance.py new file mode 100644 index 000000000..db67e7fe0 --- /dev/null +++ b/spine/math/distance.py @@ -0,0 +1,517 @@ +"""Numba JIT compiled implementation of distance computation routines. + +This module is entirely dedicated to 3D points, which is the core representation +of objects targetted by this software package. 
+""" + +import numpy as np +import numba as nb + +from .base import mean, argmin + +__all__ = ['cityblock', 'euclidean', 'sqeuclidean', 'minkowski', 'chebyshev', + 'pdist', 'cdist', 'farthest_pair', 'closest_pair'] + +# Available distance metrics (casting is important for numba optimization) +METRICS = { + 'minkowski': np.int64(0), + 'cityblock': np.int64(1), + 'euclidean': np.int64(2), + 'sqeuclidean': np.int64(3), + 'chebyshev': np.int64(4) +} + + +@nb.njit(cache=True) +def get_metric_id(metric: nb.types.string, + p: nb.float32) -> nb.int64: + """Checks on the metric name, returns an enumerated form of the metric. + + Parameters + ---------- + metric : str, default 'euclidean' + Distance metric + p : float + p-norm factor for the Minkowski metric, if used + + Returns + ------- + int + Enumerated form of the distance metric + """ + if metric == 'minkowski': + if p == 1.: + return np.int64(1) + elif p == 2.: + return np.int64(2) + else: + return np.int64(0) + elif metric == 'cityblock': + return np.int64(1) + elif metric == 'euclidean': + return np.int64(2) + elif metric == 'sqeuclidean': + return np.int64(3) + elif metric == 'chebyshev': + return np.int64(4) + else: + raise ValueError(f"Distance metric not recognized: {metric}") + + +@nb.njit(cache=True) +def cityblock(x: nb.float32[:], + y: nb.float32[:]) -> nb.float32: + """Compute the cityblock distance (L1) between to 3D points. + + Parameters + ---------- + x : np.ndarray + (3) Coorinates of the first point + y : np.ndarray + (3) Coorinates of the second point + + Returns + ------- + float + Cityblock distance + """ + return abs(y[0] - x[0]) + abs(y[1] - x[1]) + abs(y[2] - x[2]) + + +@nb.njit(cache=True) +def euclidean(x: nb.float32[:], + y: nb.float32[:]) -> nb.float32: + """Compute the Euclidean distance (L2) between two 3D points. 
+ + Parameters + ---------- + x : np.ndarray + (3) Coorinates of the first point + y : np.ndarray + (3) Coorinates of the second point + + Returns + ------- + float + Euclidean distance + """ + return np.sqrt((y[0] - x[0])**2 + (y[1] - x[1])**2 + (y[2] - x[2])**2) + + +@nb.njit(cache=True) +def sqeuclidean(x: nb.float32[:], + y: nb.float32[:]) -> nb.float32: + """Compute the squared Euclidean distance (L2) between two 3D points. + + Parameters + ---------- + x : np.ndarray + (3) Coorinates of the first point + y : np.ndarray + (3) Coorinates of the second point + + Returns + ------- + float + Squared Euclidean distance + """ + return (y[0] - x[0])**2 + (y[1] - x[1])**2 + (y[2] - x[2])**2 + + +@nb.njit(cache=True) +def chebyshev(x: nb.float32[:], + y: nb.float32[:]) -> nb.float32: + """Compute the Chebyshev distance (Linf) between to 3D points. + + Parameters + ---------- + x : np.ndarray + (3) Coorinates of the first point + y : np.ndarray + (3) Coorinates of the second point + + Returns + ------- + float + Chebyshev distance + """ + return max(abs(y[0] - x[0]), abs(y[1] - x[1]), abs(y[2] - x[2])) + + +@nb.njit(cache=True) +def minkowski(x: nb.float32[:], + y: nb.float32[:], + p: nb.float32) -> nb.float32: + """Compute the Minkowski distance (Lp) between two 3D points. + + Parameters + ---------- + x : np.ndarray + (3) Coorinates of the first point + y : np.ndarray + (3) Coorinates of the second point + + Returns + ------- + float + Minkowski distance + """ + return pow(abs(y[0] - x[0])**p + abs(y[1] - x[1])**p + abs(y[2] - x[2])**p, 1./p) + + +@nb.njit(cache=True) +def pdist(x: nb.float32[:,:], + metric_id: nb.int64 = METRICS['euclidean'], + p: nb.float32 = 2.) -> nb.float32[:,:]: + """Numba implementation of + `scipy.spatial.distance.pdist(x, metric=metric, p=p)` in 3D. + + Parameters + ---------- + x : np.ndarray + (N, 3) array of point coordinates in the set + metric_id : int, default 2 (Euclidean) + Distance metric enumerator + p : float, default 2. 
+ p-norm factor for the Minkowski metric, if used + + Returns + ------- + np.ndarray + (N, N) array of pair-wise Euclidean distances + """ + # Check on the input + assert x.shape[1] == 3, "Only supports 3D points for now." + + # Dispatch (faster this way than dipatching at each distance call) + if metric_id == np.int64(0): + return _pdist_minkowski(x, p) + elif metric_id == np.int64(1): + return _pdist_cityblock(x) + elif metric_id == np.int64(2): + return _pdist_euclidean(x) + elif metric_id == np.int64(3): + return _pdist_sqeuclidean(x) + elif metric_id == np.int64(4): + return _pdist_chebyshev(x) + else: + raise ValueError("Distance metric not recognized.") + +@nb.njit(cache=True) +def _pdist_cityblock(x: nb.float32[:,:]) -> nb.float32[:,:]: + res = np.empty((len(x), len(x)), dtype=x.dtype) + for i in range(len(x)): + res[i, i] = 0. + for j in range(i+1, len(x)): + res[i, j] = res[j, i] = cityblock(x[i], x[j]) + + return res + +@nb.njit(cache=True) +def _pdist_euclidean(x: nb.float32[:,:]) -> nb.float32[:,:]: + res = np.empty((len(x), len(x)), dtype=x.dtype) + for i in range(len(x)): + res[i, i] = 0. + for j in range(i+1, len(x)): + res[i, j] = res[j, i] = euclidean(x[i], x[j]) + + return res + +@nb.njit(cache=True) +def _pdist_sqeuclidean(x: nb.float32[:,:]) -> nb.float32[:,:]: + res = np.empty((len(x), len(x)), dtype=x.dtype) + for i in range(len(x)): + res[i, i] = 0. + for j in range(i+1, len(x)): + res[i, j] = res[j, i] = sqeuclidean(x[i], x[j]) + + return res + +@nb.njit(cache=True) +def _pdist_chebyshev(x: nb.float32[:,:]) -> nb.float32[:,:]: + res = np.empty((len(x), len(x)), dtype=x.dtype) + for i in range(len(x)): + res[i, i] = 0. + for j in range(i+1, len(x)): + res[i, j] = res[j, i] = chebyshev(x[i], x[j]) + + return res + +@nb.njit(cache=True) +def _pdist_minkowski(x: nb.float32[:,:], + p: nb.float32) -> nb.float32[:,:]: + res = np.empty((len(x), len(x)), dtype=x.dtype) + for i in range(len(x)): + res[i, i] = 0. 
+ for j in range(i+1, len(x)): + res[i, j] = res[j, i] = minkowski(x[i], x[j], p) + + return res + + +@nb.njit(cache=True) +def cdist(x1: nb.float32[:,:], + x2: nb.float32[:,:], + metric_id: nb.int64 = METRICS['euclidean'], + p: nb.float32 = 2.) -> nb.float32[:,:]: + """Numba implementation of Euclidean + `scipy.spatial.distance.cdist(x, metric=p=2)` in 3D. + + Parameters + ---------- + x1 : np.ndarray + (N, 3) array of point coordinates in the first set + x2 : np.ndarray + (M, 3) array of point coordinates in the second set + metric_id : int, default 2 (Euclidean) + Distance metric enumerator + p : float, default 2. + p-norm factor for the Minkowski metric, if used + + Returns + ------- + np.ndarray + (N, M) array of pair-wise Euclidean distances + """ + # Check on the input + assert x1.shape[1] == 3 and x2.shape[1] == 3, ( + "Only supports 3D points for now.") + + # Dispatch (faster this way than dipatching at each distance call) + if metric_id == np.int64(0): + return _cdist_minkowski(x1, x2, p) + elif metric_id == np.int64(1): + return _cdist_cityblock(x1, x2) + elif metric_id == np.int64(2): + return _cdist_euclidean(x1, x2) + elif metric_id == np.int64(3): + return _cdist_sqeuclidean(x1, x2) + elif metric_id == np.int64(4): + return _cdist_chebyshev(x1, x2) + else: + raise ValueError("Distance metric not recognized.") + +@nb.njit(cache=True) +def _cdist_cityblock(x1: nb.float32[:,:], + x2: nb.float32[:,:]) -> nb.float32[:,:]: + res = np.empty((len(x1), len(x2)), dtype=x1.dtype) + for i1 in range(len(x1)): + for i2 in range(len(x2)): + res[i1, i2] = cityblock(x1[i1], x2[i2]) + + return res + +@nb.njit(cache=True) +def _cdist_euclidean(x1: nb.float32[:,:], + x2: nb.float32[:,:]) -> nb.float32[:,:]: + res = np.empty((len(x1), len(x2)), dtype=x1.dtype) + for i1 in range(len(x1)): + for i2 in range(len(x2)): + res[i1, i2] = euclidean(x1[i1], x2[i2]) + + return res + +@nb.njit(cache=True) +def _cdist_sqeuclidean(x1: nb.float32[:,:], + x2: nb.float32[:,:]) -> 
nb.float32[:,:]: + res = np.empty((len(x1), len(x2)), dtype=x1.dtype) + for i1 in range(len(x1)): + for i2 in range(len(x2)): + res[i1, i2] = sqeuclidean(x1[i1], x2[i2]) + + return res + +@nb.njit(cache=True) +def _cdist_chebyshev(x1: nb.float32[:,:], + x2: nb.float32[:,:]) -> nb.float32[:,:]: + res = np.empty((len(x1), len(x2)), dtype=x1.dtype) + for i1 in range(len(x1)): + for i2 in range(len(x2)): + res[i1, i2] = chebyshev(x1[i1], x2[i2]) + + return res + +@nb.njit(cache=True) +def _cdist_minkowski(x1: nb.float32[:,:], + x2: nb.float32[:,:], + p: nb.float32) -> nb.float32[:,:]: + res = np.empty((len(x1), len(x2)), dtype=x1.dtype) + for i1 in range(len(x1)): + for i2 in range(len(x2)): + res[i1, i2] = minkowski(x1[i1], x2[i2], p) + + return res + + +@nb.njit(cache=True) +def farthest_pair(x: nb.float32[:,:], + iterative: nb.boolean = False, + metric_id: nb.int64 = METRICS['euclidean'], + p: nb.float32 = 2.) -> (nb.int64, nb.int64, nb.float32): + """Algorithm which finds the two points which are farthest from each other + in a set, in the Euclidean sense. + + Two algorithms on offer: + - `brute`: compute pdist, use argmax (exact) + - `iterative`: Start with the first point in one set, find the farthest + point in the other, move to that point, repeat. This + algorithm is *not* exact, but a good and very quick proxy. 
+ + Parameters + ---------- + x : np.ndarray + (N, 3) array of point coordinates + iterative : bool + If `True`, uses an iterative, fast approximation + metric_id : int, default 2 (Euclidean) + Distance metric enumerator + p : float + p-norm factor for the Minkowski metric, if used + + Returns + ------- + int + ID of the first point that makes up the pair + int + ID of the second point that makes up the pair + float + Distance between the two points + """ + # To save time, if Euclidean distance is used, use its square + euclidean = False + if metric_id == np.int64(2): + euclidean = True + metric_id = np.int64(3) + + # Dispatch + if not iterative: + # Find the distance between every pair of points + dist_mat = pdist(x, metric_id, p) + + # Select the pair with the farthest distance, fetch indexes + index = np.argmax(dist_mat) + i, j = index//len(x), index%len(x) + + # Record farthest distance + dist = dist_mat[i, j] + + else: + # Seed the search with the point farthest from the centroid + centroid = mean(x, 0) + start_idx = np.argmax(cdist(centroid[None, :], x, metric_id, p)) + + # Jump to the farthest point until convergence + idxs, subidx, dist, tempdist = [start_idx, start_idx], 0, 0., -1. + while dist > tempdist: + tempdist = dist + dists = cdist(x[idxs[subidx]][None, :], x, metric_id, p).flatten() + idxs[~subidx] = np.argmax(dists) + dist = dists[idxs[~subidx]] + subidx = ~subidx + + # Unroll index + i, j = idxs + + # If needed, take the square root of the distance + if euclidean: + dist = np.sqrt(dist) + + return i, j, dist + + +@nb.njit(cache=True) +def closest_pair(x1: nb.float32[:,:], + x2: nb.float32[:,:], + iterative: nb.boolean = False, + seed: nb.boolean = True, + metric_id: nb.int64 = METRICS['euclidean'], + p: nb.float32 = 2.) -> (nb.int64, nb.int64, nb.float32): + """Algorithm which finds the two points which are closest to each other + from two separate sets. 
+
+    Two algorithms on offer:
+      - `brute`: compute cdist, use argmin
+      - `iterative`: Start with one point in one set, find the closest
+                     point in the other set, move to that point, repeat. This
+                     algorithm is *not* exact, but a good and very quick proxy.
+
+    Parameters
+    ----------
+    x1 : np.ndarray
+        (Nx3) array of point coordinates in the first set
+    x2 : np.ndarray
+        (Nx3) array of point coordinates in the second set
+    iterative : bool
+        If `True`, uses an iterative, fast approximation
+    seed : bool
+        Whether or not to use the two farthest points in one of the two sets
+        to seed the iterative algorithm
+    metric_id : int, default 2 (Euclidean)
+        Distance metric enumerator
+    p : float, default 2.
+        p-norm factor for the Minkowski metric, if used
+
+    Returns
+    -------
+    int
+        ID of the first point that makes up the pair
+    int
+        ID of the second point that makes up the pair
+    float
+        Distance between the two points
+    """
+    # To save time, if Euclidean distance is used, use its square
+    euclidean = False
+    if metric_id == np.int64(2):
+        euclidean = True
+        metric_id = np.int64(3)
+
+    # Find the two points in two sets of points that are closest to each other
+    if not iterative:
+        # Compute every pair-wise distances between the two sets
+        dist_mat = cdist(x1, x2, metric_id, p)
+
+        # Select the closest pair of point, fetch indexes
+        index = np.argmin(dist_mat)
+        i, j = index//len(x2), index%len(x2)
+
+        # Record closest distance
+        dist = dist_mat[i, j]
+
+    else:
+        # Pick the point to start iterating from
+        xarr = [x1, x2]
+        idxs, set_id, dist, tempdist = [0, 0], 0, 1e9, 1e9+1.
+        if seed:
+            # Find the end points of the two sets
+            for i, x in enumerate(xarr):
+                seed_idxs = np.array(farthest_pair(xarr[i], True)[:2])
+                seed_dists = cdist(xarr[i][seed_idxs], xarr[~i], metric_id, p)
+                seed_argmins = argmin(seed_dists, axis=1)
+                seed_mins = np.array([seed_dists[0][seed_argmins[0]],
+                                      seed_dists[1][seed_argmins[1]]])
+                if np.min(seed_mins) < dist:
+                    set_id = ~i
+                    seed_choice = np.argmin(seed_mins)
+                    idxs[int(~set_id)] = seed_idxs[seed_choice]
+                    idxs[int(set_id)] = seed_argmins[seed_choice]
+                    dist = seed_mins[seed_choice]
+
+        # Find the closest point in the other set, repeat until convergence
+        while dist < tempdist:
+            tempdist = dist
+            dists = cdist(
+                    xarr[set_id][idxs[set_id]][None, :], xarr[~set_id],
+                    metric_id, p).flatten()
+            idxs[~set_id] = np.argmin(dists)
+            dist = dists[idxs[~set_id]]
+            set_id = ~set_id  # Toggle sets so successive jumps alternate
+
+    # Unroll index
+    i, j = idxs
+
+    # If needed, take the square root of the distance
+    if euclidean:
+        dist = np.sqrt(dist)
+
+    return i, j, dist
diff --git a/spine/math/graph.py b/spine/math/graph.py
new file mode 100644
index 000000000..39bc12f33
--- /dev/null
+++ b/spine/math/graph.py
@@ -0,0 +1,378 @@
+"""Numba JIT compiled implementation of graph routines.
+
+In particular, this module supports the CSR data structure and derived methods,
+which tremendously speeds up graph construction and computation in Numba.
+"""
+
+import numba as nb
+import numpy as np
+
+from .distance import METRICS, cdist, minkowski, cityblock, sqeuclidean, chebyshev
+
+
+CSR_DTYPE = (
+    ('num_nodes', nb.int64),
+    ('neighbors', nb.int64[:]),
+    ('offsets', nb.int64[:])
+)
+
+
+@nb.experimental.jitclass(CSR_DTYPE)
+class CSRGraph:
+    """Numba-enabled compressed Sparse Row (CSR) representation of a sparse matrix.
+ + Attributes + ---------- + neighbors : np.ndarray + (E) List of node neighbors in a compressed array + offsets : np.ndarray + (N+1) Per-node slicing boundaries to query each node neighborhood + num_nodes : int + Number of nodes in the graph, N + """ + + def __init__(self, + neighbors: nb.int64[:], + offsets: nb.int64[:], + num_nodes: nb.int64): + """Construct the Compressed Sparse Row (CSR) representation of a + sparse matrix based on a list of nodes and edges. + + Parameters + ---------- + neighbors : np.ndarray + (E) List of node neighbors in a compressed array + offsets : np.ndarray + (N+1) Per-node slicing boundaries to query each node neighborhood + num_nodes : int + Number of nodes in the graph, N + """ + self.neighbors = neighbors + self.offsets = offsets + self.num_nodes = num_nodes + + def __getitem__(self, + node_id: nb.int64): + """Get the list of neighbors associated with a node. + + Parameters + ---------- + node_id : int + Node index i + + Returns + ------- + np.ndarray + List of neighbors associated with node i + """ + start, end = self.offsets[node_id], self.offsets[node_id + 1] + return self.neighbors[start:end] + + def num_neighbors(self, + node_id: nb.int64): + """Returns the number of neighbors of a node. + + Parameters + ---------- + node_id : int + Node index i + + Returns + ------- + int + Number of neighbors of node i + """ + start, end = self.offsets[node_id], self.offsets[node_id + 1] + return end - start + + +@nb.njit +def csr_graph(edge_index: nb.int64[:,:], + num_nodes: nb.int64, + directed: nb.boolean = True) -> CSR_DTYPE: + """Construct the Compressed Sparse Row (CSR) representation of a sparse + matrix based on a list of nodes and edges. 
+ + Parameters + ---------- + edge_index : np.ndarray + (E, 2) List of active edge indices in the graph + num_nodes : int + Number of nodes in the graph, N + directed : bool + Whether the input graph is directed or not + """ + # Count the number of connections per node + counts = np.zeros(num_nodes, dtype=np.int64) + for s, t in edge_index: + counts[s] += 1 + if not directed: + counts[t] += 1 + + # Build the offsets array + offsets = np.empty(num_nodes + 1, dtype=np.int64) + offsets[0] = 0 + for i in range(num_nodes): + offsets[i + 1] = offsets[i] + counts[i] + + # Build the neighbors array + neighbors = np.empty(offsets[-1], dtype=np.int64) + fill = np.zeros(num_nodes, dtype=np.int64) + for s, t in edge_index: + idx = offsets[s] + fill[s] + neighbors[idx] = t + fill[s] += 1 + if not directed: + idx = offsets[t] + fill[t] + neighbors[idx] = s + fill[t] += 1 + + # Initialize the CSR graph + return CSRGraph(neighbors, offsets, num_nodes) + + +@nb.njit(cache=True) +def connected_components(edge_index: nb.int64[:,:], + num_nodes: nb.int64, + min_samples: nb.int64 = 1, + directed: nb.boolean = True) -> nb.int64[:]: + """Find connected components. 
+ + Parameters + ---------- + edge_index : np.ndarray + (E, 2) List of active edge indices in the graph + num_nodes : int + Number of nodes in the graph, N + directed : bool, default True + Whether the input graph is directed or not + + Returns + ------- + np.ndarray + (N) Cluster label associated with each node + """ + # Initialize the CSR data structure + graph = csr_graph(edge_index, num_nodes, directed) + + # Initialize output + labels = np.arange(graph.num_nodes) + visited = np.zeros(graph.num_nodes, dtype=nb.boolean) + component = np.empty(graph.num_nodes, dtype=nb.int64) + comp_idx = np.empty(1, dtype=nb.int64) # Acts as pointer + + # Loop through all nodes and start DFS from unvisited nodes + label = 0 + min_neighbors = min_samples - 1 + for node in range(graph.num_nodes): + if not visited[node]: + if graph.num_neighbors(node) > min_neighbors: + # Perform DFS and collect all nodes in this connected component + comp_idx[0] = 0 + dfs(graph, visited, node, component, comp_idx) + + # Collect all nodes that belong to the same connected component + for i in range(comp_idx[0]): + labels[component[i]] = label + + else: + # Relabel solitary nodes to maintain ordering + labels[node] = label + + # Increment label + label += 1 + + return labels + + +@nb.njit(cache=True) +def dfs(graph: CSR_DTYPE, + visited: nb.boolean[:], + node: nb.int64, + component: nb.int64[:], + comp_idx: nb.int64[:]): + """Does a depth-first search and builds a connected component. + + Parameters + ---------- + graph : CSRGraph + CSR representation of a graph + visitied : np.ndarray + (N) Boolean array which specified weather a node has been visited or not. 
+ node : int + Current node index + component : np.ndarray + (N) Current component (padded) + comp_idx : np.ndarray + Current component index (pointer) + """ + # Mark the node as visited, incremant pointer + visited[node] = True + component[comp_idx[0]] = node + comp_idx[0] += 1 + + # Traverse all the neighbors of the node + for neighbor in graph[node]: + if not visited[neighbor]: + dfs(graph, visited, neighbor, component, comp_idx) + + +@nb.njit(cache=True) +def radius_graph(x: nb.float32[:, :], + radius: nb.float32, + metric_id: nb.int64 = METRICS['euclidean'], + p: nb.float32 = 2.) -> nb.int64[:, :]: + """Builds an undirected radius graph. + + This function generates a list of edges in a graph which connects all nodes + which live within some radius R of each other. + + Parameters + ---------- + x : np.ndarray + (N, 3) array of node coordinates + radius : float + Radius within which to build connections in the graph + metric_id : int, default 2 (Euclidean) + Distance metric enumerator + p : float, default 2. + p-norm factor for the Minkowski metric, if used + + Returns + ------- + np.ndarray + (E, 2) array of edges in the radius graph + """ + # Determine the distance function to use. 
If the metric is Euclidean, it + # is cheaper to square the radius and use the squared Euclidean metric + if metric_id == np.int64(0): + return _radius_graph_minkowski(x, radius, p) + elif metric_id == np.int64(1): + return _radius_graph_cityblock(x, radius) + elif metric_id == np.int64(2): + radius = radius*radius + return _radius_graph_sqeuclidean(x, radius) + elif metric_id == np.int64(3): + return _radius_graph_sqeuclidean(x, radius) + elif metric_id == np.int64(4): + return _radius_graph_chebyshev(x, radius) + else: + raise ValueError("Distance metric not recognized.") + +@nb.njit(cache=True) +def _radius_graph_minkowski(x: nb.float32[:, :], + radius: nb.float32, + p: nb.float32) -> nb.float32[:, :]: + # Initialize a data structure to hold edges + num_nodes = len(x) + max_edges = num_nodes*(num_nodes - 1)//2 + edge_index = np.empty((max_edges, 2), dtype=np.int64) + + # Loop over pairs of nodes, ass edges if the distance fits the bill + edge_count = 0 + for i in range(num_nodes): + for j in range(i + 1, num_nodes): + if minkowski(x[i], x[j], p) <= radius: + edge_index[edge_count, 0], edge_index[edge_count, 1] = i, j + edge_count += 1 + + return edge_index[:edge_count] + +@nb.njit(cache=True) +def _radius_graph_cityblock(x: nb.float32[:, :], + radius: nb.float32) -> nb.float32[:, :]: + # Initialize a data structure to hold edges + num_nodes = len(x) + max_edges = num_nodes*(num_nodes - 1)//2 + edge_index = np.empty((max_edges, 2), dtype=np.int64) + + # Loop over pairs of nodes, ass edges if the distance fits the bill + edge_count = 0 + for i in range(num_nodes): + for j in range(i + 1, num_nodes): + if cityblock(x[i], x[j]) <= radius: + edge_index[edge_count, 0], edge_index[edge_count, 1] = i, j + edge_count += 1 + + return edge_index[:edge_count] + +@nb.njit(cache=True) +def _radius_graph_sqeuclidean(x: nb.float32[:, :], + radius: nb.float32) -> nb.float32[:, :]: + # Initialize a data structure to hold edges + num_nodes = len(x) + max_edges = 
num_nodes*(num_nodes - 1)//2 + edge_index = np.empty((max_edges, 2), dtype=np.int64) + + # Loop over pairs of nodes, ass edges if the distance fits the bill + edge_count = 0 + for i in range(num_nodes): + for j in range(i + 1, num_nodes): + if sqeuclidean(x[i], x[j]) <= radius: + edge_index[edge_count, 0], edge_index[edge_count, 1] = i, j + edge_count += 1 + + return edge_index[:edge_count] + + +@nb.njit(cache=True) +def _radius_graph_chebyshev(x: nb.float32[:, :], + radius: nb.float32) -> nb.float32[:, :]: + # Initialize a data structure to hold edges + num_nodes = len(x) + max_edges = num_nodes*(num_nodes - 1)//2 + edge_index = np.empty((max_edges, 2), dtype=np.int64) + + # Loop over pairs of nodes, ass edges if the distance fits the bill + edge_count = 0 + for i in range(num_nodes): + for j in range(i + 1, num_nodes): + if chebyshev(x[i], x[j]) <= radius: + edge_index[edge_count, 0], edge_index[edge_count, 1] = i, j + edge_count += 1 + + return edge_index[:edge_count] + + +@nb.njit(cache=True) +def union_find(edge_index: nb.int64[:,:], + count: nb.int64, + return_inverse: bool = True) -> nb.int64[:]: + """Numba implementation of the Union-Find algorithm. + + This function assigns a group to each node in a graph, provided + a set of edges connecting the nodes together. 
+
+    Parameters
+    ----------
+    edge_index : np.ndarray
+        (E, 2) List of edges (sparse adjacency matrix)
+    count : int
+        Number of nodes in the graph, C
+    return_inverse : bool, default True
+        Make sure the group IDs range from 0 to N_groups-1
+
+    Returns
+    -------
+    np.ndarray
+        (C) Group assignments for each of the nodes in the graph
+    Dict[int, np.ndarray]
+        Dictionary which maps groups to indexes
+    """
+    labels = np.arange(count)
+    groups = {i: np.array([i]) for i in labels}
+    for e in edge_index:
+        li, lj = labels[e[0]], labels[e[1]]
+        if li != lj:
+            labels[groups[lj]] = li
+            groups[li] = np.concatenate((groups[li], groups[lj]))
+            del groups[lj]
+
+    if return_inverse:
+        mask = np.zeros(count, dtype=np.bool_)
+        mask[labels] = True
+        mapping = np.empty(count, dtype=labels.dtype)
+        mapping[mask] = np.arange(np.sum(mask))
+        labels = mapping[labels]
+
+    return labels, groups
diff --git a/spine/math/linalg.py b/spine/math/linalg.py
new file mode 100644
index 000000000..0c246571a
--- /dev/null
+++ b/spine/math/linalg.py
@@ -0,0 +1,100 @@
+"""Numba JIT compiled implementation of linear algebra routines."""
+
+import numpy as np
+import numba as nb
+
+__all__ = ['norm', 'submatrix']
+
+
+@nb.njit(cache=True)
+def norm(x: nb.float32[:,:],
+         axis: nb.int32) -> nb.float32[:]:
+    """Numba implementation of `np.linalg.norm(x, axis)`.
+
+    Parameters
+    ----------
+    x : np.ndarray
+        (N,M) array of values
+    axis : int
+        Array axis ID
+
+    Returns
+    -------
+    np.ndarray
+        (N) or (M) array of `norm` values
+    """
+    assert axis == 0 or axis == 1
+    xnorm = np.empty(x.shape[1-axis], dtype=x.dtype)  # np.int32 would truncate norms
+    if axis == 0:
+        for i in range(len(xnorm)):
+            xnorm[i] = np.linalg.norm(x[:,i])
+    else:
+        for i in range(len(xnorm)):
+            xnorm[i] = np.linalg.norm(x[i])
+
+    return xnorm
+
+
+@nb.njit(cache=True)
+def submatrix(x: nb.float32[:,:],
+              index1: nb.int32[:],
+              index2: nb.int32[:]) -> nb.float32[:,:]:
+    """Numba implementation of matrix subsampling.
+ + Parameters + ---------- + x : np.ndarray + (N,M) array of values + index1 : np.ndarray + (N') array of indices along axis 0 in the input matrix + index2 : np.ndarray + (M') array of indices along axis 1 in the input matrix + + Returns + ------- + np.ndarray + (N',M') array of values from the original matrix + """ + subx = np.empty((len(index1), len(index2)), dtype=x.dtype) + for i, i1 in enumerate(index1): + for j, i2 in enumerate(index2): + subx[i, j] = x[i1, i2] + + return subx + + +@nb.njit(cache=True) +def contingency_table(x: nb.int32[:], + y: nb.int32[:], + nx: nb.int32 = None, + ny: nb.int32 = None) -> nb.int64[:, :]: + """Build a contingency table for two sets of labels. + + Parameters + ---------- + x : np.ndarray + (N) Array of integrer values + y : np.ndarray + (M) Array of integrer values + nx : int, optional + Number of integer values allowed in `x`, N + ny : int, optional + Number of integer values allowd in `y`, M + + Returns + ------- + np.ndarray + (N, M) Contingency table + """ + # If not provided, assume that the max label is the max of the label array + if not nx: + nx = np.max(x) + 1 + if not ny: + ny = np.max(y) + 1 + + # Bin the table + table = np.zeros((nx, ny), dtype=np.int64) + for i, j in zip(x, y): + table[i, j] += 1 + + return table diff --git a/spine/math/neighbors.py b/spine/math/neighbors.py new file mode 100644 index 000000000..aca89543e --- /dev/null +++ b/spine/math/neighbors.py @@ -0,0 +1,246 @@ +"""Numba JIT compiled implementation of neighbor query routines. 
+ +In particular, this module supports: +- Radius-based neighbor classification +- kNN-based neighbor classification +""" + +import numba as nb +import numpy as np + +from .base import mode +from .distance import METRICS, get_metric_id, cdist + +__all__ = ['RadiusNeighborsClassifier', 'KNeighborsClassifier'] + + +RNC_DTYPE = ( + ('radius', nb.float32), + ('metric_id', nb.int64), + ('p', nb.float32), + ('iterate', nb.boolean) +) + + +KNC_DTYPE = ( + ('k', nb.int64), + ('metric_id', nb.int64), + ('p', nb.float32) +) + + +@nb.experimental.jitclass(RNC_DTYPE) +class RadiusNeighborsClassifier: + """Class which assigns labels to points based on radial neighborhood + majority vote. + + More specifically, for each point that is to be labeled: + - Find all labeled points within some radius R; + - Label the point based on majority vote. + + If there are no labeled points in the neighborhood of a query point, a + label of -1 is assigned to the query point. + + Currently this is bruteforced with cdist, but in the future this is + intended to be used with a KDTree backend for quicker query. + + Attributes + ---------- + radius : float + Radius around which to check + metric_id : int + Distance metric enumerator + p : float + p-norm factor for the Minkowski metric, if used + iterate : bool + Whether to recurse the search until no new labels are assigned + """ + + def __init__(self, + radius: nb.float32, + metric: nb.types.string = 'euclidean', + p: nb.float32 = 2., + iterate: nb.boolean = True): + """Initialize the RadiusNeighborsClassifier parameters. + + Parameters + ---------- + radius : float + Radius around which to check + metric : str, default 'euclidean' + Distance metric + p : float, default 2. 
+ p-norm factor for the Minkowski metric, if used + iterate : bool, default True + Whether to recurse the search until no new labels are assigned + """ + # For Euclidean, save time by using squared Euclidean + if metric == 'euclidean': + metric = 'sqeuclidean' + radius = radius*radius + + # Store parameters + self.radius = radius + self.metric_id = get_metric_id(metric, p) + self.p = p + self.iterate = iterate + + def fit_predict(self, + X: nb.float32[:,:], + y: nb.float32[:], + Xq: nb.float32[:,:]): + """Assign labels to a set of points given a set of reference points. + + Parameters + ---------- + X : np.ndarray + (N, 3) Set of reference points + y : np.ndarray + (N) Labels of reference points + Xq : nb.ndarray + (M, 3) Set of query points + + Returns + ------- + np.ndarray + (M) Labels assigned to the query points + np.ndarray + Index of points which have not been sucessfully assigned + """ + # Loop over query points until no new labels can be assigned + num_query = len(Xq) + labels = np.empty(num_query, dtype=np.int64) + orphan_index = np.arange(num_query, dtype=np.int64) + while num_query > 0: + # Start by computing the distance between the query and reference + dists = cdist(Xq, X, metric_id=self.metric_id, p=self.p) + + # Fetch the mask of reference points closer than some radius + mask = dists < self.radius + + # Loop over query points + assigned = np.zeros(num_query, dtype=nb.boolean) + for i in range(num_query): + # Find the set of points within the predefined radius + index = np.where(mask[i])[0] + + # Use the mode to define the label + if len(index): + labels[orphan_index[i]] = mode(y[index]) + assigned[i] = True + else: + labels[orphan_index[i]] = -1 + + # If the number of orphans is unchanged, break + orphan_update = np.where(~assigned)[0] + if len(orphan_update) == 0 or len(orphan_update) == num_query: + orphan_index = orphan_index[orphan_update] + break + + # If no recursion is required, abort loop + if not self.iterate: + orphan_index = 
orphan_index[orphan_update] + break + + # Update the reference and query points + label_update = np.where(assigned)[0] + X = Xq[label_update] + Xq = Xq[orphan_update] + y = labels[orphan_index[label_update]] + + # Update orphan list + orphan_index = orphan_index[orphan_update] + num_query = len(orphan_index) + + return labels, orphan_index + + +@nb.experimental.jitclass(KNC_DTYPE) +class KNeighborsClassifier: + """Class which assigns labels to points based on a nearest neighbor + majority vote. + + More specifically, for each point that is to be labeled: + - Find the k closest labeled points; + - Label the point based on majority vote. + + If there are no labeled points in the neighborhood of a query point, a + label of -1 is assigned to the query point. + + Currently this is bruteforced with cdist, but in the future this is + intended to be used with a KDTree backend for quicker query. + + Attributes + ---------- + k : int + Number of neighbors to query + metric_id : int + Distance metric enumerator + p : float + p-norm factor for the Minkowski metric, if used + """ + + def __init__(self, + k: nb.int64, + metric: nb.types.string = 'euclidean', + p: nb.float32 = 2.): + """Initialize the RadiusNeighborsClassifier parameters. + + Parameters + ---------- + k : int + Number of neighbors to query + metric : str, default 'euclidean' + Distance metric + p : float, default 2. + p-norm factor for the Minkowski metric, if used + """ + # For Euclidean, save time by using squared Euclidean + if metric == 'euclidean': + metric = 'sqeuclidean' + + # Store parameters + self.k = k + self.metric_id = get_metric_id(metric, p) + self.p = p + + def fit_predict(self, + X: nb.float32[:,:], + y: nb.float32[:], + Xq: nb.float32[:,:]): + """Assign labels to a set of points given a set of reference points. 
+ + Parameters + ---------- + X : np.ndarray + (N, 3) Set of reference points + y : np.ndarray + (N) Labels of reference points + Xq : nb.ndarray + (M, 3) Set of query points + + Returns + ------- + np.ndarray + (M) Labels assigned to the query points + np.ndarray + Index of points which have not been sucessfully assigned + """ + # If there are no labeled points provided, nothing to do + if len(X) == 0: + return (np.full(len(Xq), -1, dtype=np.int64), + np.arange(len(Xq), dtype=np.int64)) + + # Start by computing the distance between the query and reference + dists = cdist(Xq, X, metric_id=self.metric_id, p=self.p) + + # Loop over query poins + labels = np.empty(len(Xq), dtype=np.int64) + for i in range(len(Xq)): + # Find the list k closest labels + index = np.argsort(dists[i])[:self.k] + + # Use the mode to define the label + labels[i] = mode(y[index]) + + return labels, np.empty(0, dtype=np.int64) diff --git a/spine/model/graph_spice.py b/spine/model/graph_spice.py index 9c0e02183..654aa4bc3 100644 --- a/spine/model/graph_spice.py +++ b/spine/model/graph_spice.py @@ -217,6 +217,7 @@ def forward(self, data, seg_label, clust_label=None): else: features = result['hypergraph_features'] + coords = TensorBatch(coords.data[:, coords.coord_cols], coords.counts) graph = self.constructor(coords, features, seg_label, clust_label) # If requested, convert edge predictions to node predictions diff --git a/spine/model/layer/cnn/ppn.py b/spine/model/layer/cnn/ppn.py index 72f9f0e5a..b0e5b83b4 100644 --- a/spine/model/layer/cnn/ppn.py +++ b/spine/model/layer/cnn/ppn.py @@ -11,7 +11,7 @@ from .blocks import ResNetBlock, SPP, ASPP from spine.data import TensorBatch -from spine.utils.torch_local import local_cdist +from spine.utils.torch_local import cdist_fast from spine.utils.logger import logger from spine.utils.globals import ( COORD_COLS, VALUE_COL, PART_COL, SHAPE_COL, TRACK_SHP, GHOST_SHP, @@ -537,7 +537,7 @@ def get_ppn_positives(coords: torch.Tensor, # Compute the pairwise 
distance between the particle voxels and its # label points - dist_mat = local_cdist(coords[index], points) + dist_mat = cdist_fast(coords[index], points) # Generate a positive mask for all particle voxels within some # distance of its label points @@ -632,7 +632,7 @@ def forward(self, ppn_label, ppn_points, ppn_masks, ppn_layers, ppn_coords, else: # Compute the pairwise distances between each label point # and all the voxels in the image. - dist_mat = local_cdist(points_entry, points_label[:, COORD_COLS]) + dist_mat = cdist_fast(points_entry, points_label[:, COORD_COLS]) min_return = torch.min(dist_mat, dim=1) closest = offset + min_return.indices diff --git a/spine/model/layer/gnn/encode/geometric.py b/spine/model/layer/gnn/encode/geometric.py index 7ec03a729..9c1d57f51 100644 --- a/spine/model/layer/gnn/encode/geometric.py +++ b/spine/model/layer/gnn/encode/geometric.py @@ -3,7 +3,7 @@ from spine.data import TensorBatch -from spine.utils.torch_local import local_cdist +from spine.utils.torch_local import cdist_fast from spine.utils.globals import COORD_COLS, VALUE_COL, SHAPE_COL from spine.utils.gnn.cluster import ( get_cluster_features_batch, get_cluster_points_label_batch, @@ -401,7 +401,7 @@ def get_base_features(self, data, clusts, edge_index, closest_index=None): # Find the closest set point in each cluster if closest_index is None: - d12 = local_cdist(x1, x2) + d12 = cdist_fast(x1, x2) imin = torch.argmin(d12) else: imin = closest_index[e[0], e[1]] diff --git a/spine/model/layer/gnn/graph/base.py b/spine/model/layer/gnn/graph/base.py index 265bac5c2..11b9442c1 100644 --- a/spine/model/layer/gnn/graph/base.py +++ b/spine/model/layer/gnn/graph/base.py @@ -17,6 +17,12 @@ class GraphBase: # Name of the graph constructor (as specified in the configuration) name = None + # List of recognized distance methods + _dist_methods = ('voxel', 'centroid') + + # List of recognized distance algorithms + _dist_algorithms = ('brute', 'iterative', 'recursive') + def 
__init__(self, directed=False, max_length=None, classes=None, max_count=None, dist_method='voxel', dist_algorithm='brute'): """Initializes attributes shared accross all graph constructors. @@ -37,14 +43,21 @@ def __init__(self, directed=False, max_length=None, classes=None, dist_method : str, default 'voxel' Method used to compute inter-node distance ('voxel' or 'centroid') dist_algorithm : str, default 'brute' - Algorithm used to comppute inter-node distance - ('brute' or 'recursive') + Algorithm used to compute inter-node distance ('brute', 'iterative' or 'recursive') """ + # Check on enumerated strings + assert dist_method in self._dist_methods, ( + f"Distance computation method not recognized: {dist_method}. " + f"Must be one of {self._dist_methods}.") + assert dist_algorithm in self._dist_algorithms, ( + f"Distance computation algorithm not recognized: {dist_algorithm}. " + f"Must be one of {self._dist_algorithms}.") + # Store attributes self.directed = directed self.max_count = max_count - self.dist_method = dist_method - self.dist_algorithm = dist_algorithm + self.dist_centroid = dist_method == 'centroid' + self.dist_iterative = dist_algorithm != 'brute' # Convert `max_length` to a matrix, if provided as a `triu` self.max_length = max_length @@ -98,8 +111,8 @@ def __call__(self, data, clusts, classes=None, groups=None): if self.compute_dist: dist_mat, closest_index = inter_cluster_distance( data.tensor[:, COORD_COLS], clusts.index_list, - clusts.counts, method=self.dist_method, - algorithm=self.dist_algorithm, return_index=True) + clusts.counts, centroid=self.dist_centroid, + iterative=self.dist_iterative, return_index=True) # Generate the edge index edge_index, edge_counts = self.generate( diff --git a/spine/model/layer/gnn/graph/knn.py b/spine/model/layer/gnn/graph/knn.py index 25dcc5cc0..137c44785 100644 --- a/spine/model/layer/gnn/graph/knn.py +++ b/spine/model/layer/gnn/graph/knn.py @@ -3,7 +3,7 @@ import numpy as np import numba as nb -import 
spine.utils.numba_local as nbl +from spine.math.linalg import submatrix from .base import GraphBase @@ -75,7 +75,7 @@ def _generate(batch_ids: nb.int64[:], clust_ids = np.where(batch_ids == b)[0] if len(clust_ids) > 1: subk = min(k+1, len(clust_ids)) - submat = nbl.submatrix(dist_mat, clust_ids, clust_ids) + submat = submatrix(dist_mat, clust_ids, clust_ids) for i in range(len(submat)): idxs = np.argsort(submat[i])[1:subk] edges = np.empty((subk-1,2), dtype=np.int64) diff --git a/spine/model/layer/gnn/graph/mst.py b/spine/model/layer/gnn/graph/mst.py index 65d8479f7..72e84e8da 100644 --- a/spine/model/layer/gnn/graph/mst.py +++ b/spine/model/layer/gnn/graph/mst.py @@ -5,7 +5,7 @@ from scipy.sparse.csgraph import minimum_spanning_tree -import spine.utils.numba_local as nbl +from spine.math.linalg import submatrix from .base import GraphBase @@ -58,7 +58,7 @@ def _generate(batch_ids: nb.int64[:], for b in np.unique(batch_ids): clust_ids = np.where(batch_ids == b)[0] if len(clust_ids) > 1: - submat = np.triu(nbl.submatrix(dist_mat, clust_ids, clust_ids)) + submat = np.triu(submatrix(dist_mat, clust_ids, clust_ids)) # Suboptimal. Ideally want to reimplement in Numba, tall order. 
with nb.objmode(mst_mat = 'float32[:,:]'): mst_mat = minimum_spanning_tree(submat) diff --git a/spine/model/uresnet.py b/spine/model/uresnet.py index 96fb99993..3559d04f2 100644 --- a/spine/model/uresnet.py +++ b/spine/model/uresnet.py @@ -10,7 +10,7 @@ from spine.data import TensorBatch from spine.utils.globals import BATCH_COL, COORD_COLS, VALUE_COL, GHOST_SHP from spine.utils.logger import logger -from spine.utils.torch_local import local_cdist +from spine.utils.torch_local import cdist_fast from .layer.factories import loss_fn_factory @@ -391,7 +391,7 @@ def get_distance_weights(self, seg_label, point_label): continue # Compute the minimal distance to any point in this entry - dist_mat = local_cdist(voxels_b, points_b) + dist_mat = cdist_fast(voxels_b, points_b) dists_b = torch.min(dist_mat, dim=1).values # Record information in the batch-wise tensor diff --git a/spine/post/optical/barycenter.py b/spine/post/optical/barycenter.py index cab47a34d..ffb1a65d9 100644 --- a/spine/post/optical/barycenter.py +++ b/spine/post/optical/barycenter.py @@ -2,7 +2,7 @@ import numpy as np -from spine.utils.numba_local import cdist +from spine.math.distance import cdist class BarycenterFlashMatcher: diff --git a/spine/post/reco/cathode_cross.py b/spine/post/reco/cathode_cross.py index c375a482f..c0d3c2e31 100644 --- a/spine/post/reco/cathode_cross.py +++ b/spine/post/reco/cathode_cross.py @@ -1,13 +1,13 @@ """Cathode crossing identification + merging module.""" import numpy as np -from scipy.spatial.distance import cdist from spine.data import RecoInteraction, TruthInteraction +from spine.math.distance import cdist, farthest_pair + from spine.utils.globals import TRACK_SHP from spine.utils.geo import Geometry -from spine.utils.numba_local import farthest_pair from spine.utils.gnn.cluster import cluster_direction from spine.post.base import PostBase diff --git a/spine/post/reco/shower.py b/spine/post/reco/shower.py index 1e1852b0d..11fcb3671 100644 --- 
a/spine/post/reco/shower.py +++ b/spine/post/reco/shower.py @@ -1,10 +1,10 @@ """Shower reconstruction module.""" import numpy as np -from scipy.stats import pearsonr + +from spine.math.distance import cdist from spine.utils.globals import SHOWR_SHP, TRACK_SHP, PROT_PID, PION_PID -from spine.utils.numba_local import cdist from spine.utils.gnn.cluster import cluster_direction, cluster_dedx from spine.data import ObjectList, RecoParticle diff --git a/spine/utils/cluster/ccc.py b/spine/utils/cluster/ccc.py index 7ff4806ae..99bb844dd 100644 --- a/spine/utils/cluster/ccc.py +++ b/spine/utils/cluster/ccc.py @@ -1,11 +1,12 @@ """Connected component clustering module.""" import numpy as np -from scipy.sparse import coo_matrix, csgraph import torch from spine.data import TensorBatch +from spine.math.graph import connected_components + from .orphan import OrphanAssigner __all__ = ['ConnectedComponentClusterer'] @@ -131,7 +132,7 @@ def fit_predict_entry(self, node_coords, edge_index, edge_assn, if len(nindex): node_pred[nindex] = node_pred_s offset = int(node_pred.max()) + 1 - + return node_pred def fit_predict_one(self, node_coords, edge_index, edge_assn, offset, @@ -159,17 +160,11 @@ def fit_predict_one(self, node_coords, edge_index, edge_assn, offset, """ # Narrow down the list of edges to those turned on assert edge_index.shape[1] == 2, ( - "The edge index must be of shape (E, 2)") - edges = edge_index[edge_assn == 1] - - # Convert the set of edges to a coordinate-format sparse adjacency matrix - num_nodes = len(node_coords) - edge_assn = np.ones(len(edges), dtype=int) - adj = coo_matrix((edge_assn, tuple(edges.T)), (num_nodes, num_nodes)) + "The edge index must be of shape (E, 2).") + edges = edge_index[np.where(edge_assn)[0]] - # Find connected components, allow for unidirectional connections - _, node_pred = csgraph.connected_components(adj, connection='weak') - node_pred = node_pred.astype(np.int64) + # Find connected components + node_pred = 
connected_components(edges, len(node_coords)) # If min_size is set, downselect the points considered labeled min_size = min_size if min_size is not None else self.min_size diff --git a/spine/utils/cluster/label.py b/spine/utils/cluster/label.py index bfded18ac..45adb679e 100644 --- a/spine/utils/cluster/label.py +++ b/spine/utils/cluster/label.py @@ -2,12 +2,13 @@ import numpy as np import torch -from torch_cluster import knn -from scipy.spatial.distance import cdist from spine.data import TensorBatch -from spine.utils.gnn.cluster import form_clusters, break_clusters +from spine.math.distance import METRICS, get_metric_id, cdist + +from spine.utils.torch_local import cdist_fast +from spine.utils.gnn.cluster import break_clusters from spine.utils.globals import ( COORD_COLS, VALUE_COL, CLUST_COL, SHAPE_COL, SHOWR_SHP, TRACK_SHP, MICHL_SHP, DELTA_SHP, GHOST_SHP) @@ -19,23 +20,19 @@ class ClusterLabelAdapter: """Adapts the cluster labels to account for the predicted semantics. Points wrongly predicted get the cluster label of the closest touching - cluster, if there is one. Points that are predicted as ghosts get invalid - (-1) cluster labels everywhere. + compatible cluster, if there is one. Points that are predicted as ghosts + get invalid (-1) cluster labels everywhere. Instances that have been broken up by the deghosting or semantic segmentation process get assigned distinct cluster labels for each - effective fragment, provided they appearing in the `break_classes` list. + effective fragment, provided they appear in the `break_classes` list. Notes ----- This class supports both Numpy arrays and Torch tensors. - - It uses the GPU implementation from `torch_cluster.knn` to speed up the - label adaptation computation (instead of cdist). 
- """ - def __init__(self, break_eps=1.1, break_metric='chebyshev', + def __init__(self, break_eps=1.1, break_metric='chebyshev', break_p=2., break_classes=[SHOWR_SHP, TRACK_SHP, MICHL_SHP, DELTA_SHP]): """Initialize the adapter class. @@ -47,13 +44,16 @@ def __init__(self, break_eps=1.1, break_metric='chebyshev', Distance scale used in the break up procedure break_metric : str, default 'chebyshev' Distance metric used in the break up produce + p : float, default 2. + p-norm factor for the Minkowski metric, if used break_classes : List[int], default [SHOWR_SHP, TRACK_SHP, MICHL_SHP, DELTA_SHP] Classes to run DBSCAN on to break up """ # Store relevant parameters self.break_eps = break_eps - self.break_metric = break_metric + self.break_metric_id = get_metric_id(break_metric, break_p) + self.break_p = break_p self.break_classes = break_classes # Attributes used to fetch the correct functions @@ -198,11 +198,9 @@ def _process(self, clust_label, seg_label, seg_pred, ghost_pred=None): X_pred = coords[bad_index] tagged_voxels_count = 1 while tagged_voxels_count > 0 and len(X_pred) > 0: - # Find the nearest neighbor to each predicted point - closest_ids = self._compute_neighbor(X_pred, X_true) - - # Compute Chebyshev distance between predicted and closest true. 
- distances = self._compute_distances(X_pred, X_true[closest_ids]) + # Compute Chebyshev distance between predicted and closest true + distances = self._compute_distances(X_pred, X_true) + distances, closest_ids = self._min(distances, 1) # Label unlabeled voxels that touch a compatible true voxel select_mask = distances < 1.1 @@ -253,7 +251,7 @@ def _process(self, clust_label, seg_label, seg_pred, ghost_pred=None): # Now if an instance was broken up, assign it different cluster IDs new_label[:, CLUST_COL] = break_clusters( - new_label, clusts, self.break_eps, self.break_metric) + new_label, clusts, self.break_eps, self.break_metric_id, self.break_p) return new_label @@ -281,6 +279,12 @@ def _eye(self, x): else: return np.eye(x, dtype=bool) + def _min(self, x, axis): + if self.torch: + return torch.min(x, axis) + else: + return np.min(x, axis), np.argmin(x, axis) + def _unique(self, x): if self.torch: return torch.unique(x).long() @@ -293,14 +297,10 @@ def _to_long(self, x): else: return x.astype(int64) - def _compute_neighbor(self, x, y): - if self.torch: - return knn(y[:, COORD_COLS], x[:, COORD_COLS], 1)[1] - else: - return cdist(x[:, COORD_COLS], y[:, COORD_COLS]).argmin(axis=1) - def _compute_distances(self, x, y): if self.torch: - return torch.amax(torch.abs(y[:, COORD_COLS] - x[:, COORD_COLS]), dim=1) + return cdist_fast(x[:, COORD_COLS], y[:, COORD_COLS], + metric='chebyshev') else: - return np.amax(np.abs(x[:, COORD_COLS] - y[:, COORD_COLS]), axis=1) + return cdist(x[:, COORD_COLS], y[:, COORD_COLS], + metric_id=METRICS['chebyshev']) diff --git a/spine/utils/cluster/orphan.py b/spine/utils/cluster/orphan.py index c4fc4df3c..217e12cb6 100644 --- a/spine/utils/cluster/orphan.py +++ b/spine/utils/cluster/orphan.py @@ -1,8 +1,9 @@ """Defines class used to assign orphaned points to a sensible cluster.""" import numpy as np -from sklearn.neighbors import KNeighborsClassifier, RadiusNeighborsClassifier -from sklearn.cluster import DBSCAN + +from spine.math.cluster 
import DBSCAN +from spine.math.neighbors import RadiusNeighborsClassifier, KNeighborsClassifier __all__ = ['OrphanAssigner'] @@ -13,20 +14,18 @@ class OrphanAssigner: This class takes care of finding the best match cluster ID for points that have not found a suitable group in the upstream clustering. - This is a wrapper class for two `scikit-learn` classes: + This is a wrapper class for two classes: - :class:`KNeighborsClassifier` - :class:`RadiusNeighborsClassifier` """ - def __init__(self, mode, iterate=True, assign_all=True, **kwargs): + def __init__(self, mode, assign_all=True, **kwargs): """Initialize the orphan assigner. Parameters ---------- mode : str Orphan assignment mode, one of 'knn' or 'radius' - iterate : bool, default True - Iterate the process until no additional orphans can be assigned assign_all : bool, default True If `True`, force assign all orphans to a cluster. In the 'knn' mode, this is guaranteed, provided there is at least one labeled point. @@ -39,22 +38,21 @@ def __init__(self, mode, iterate=True, assign_all=True, **kwargs): if mode == 'knn': self.classifier = KNeighborsClassifier(**kwargs) elif mode == 'radius': - self.classifier = RadiusNeighborsClassifier( - outlier_label=-1, **kwargs) + self.classifier = RadiusNeighborsClassifier(**kwargs) else: raise ValueError( "The orphan assignment mode must be one of 'knn' or " - f"'radius', got '{mode}' instead.") + f"'radius'. Got '{mode}' instead.") - # Store the extra parameters - self.iterate = iterate + # Store the extra parameter self.assign_all = assign_all # If needed, initialize DBSCAN if mode == 'radius' and assign_all: + radius = kwargs.get('radius') + metric = kwargs.get('metric', 'euclidean') self.dbscan = DBSCAN( - eps=self.classifier.radius, min_samples=1, - metric=self.classifier.metric, p=self.classifier.p) + eps=radius, min_samples=1, metric=metric, p=self.classifier.p) def __call__(self, X, y): """Place-holder for a function which assigns labels to orphan points. 
@@ -71,44 +69,23 @@ def __call__(self, X, y): np.ndarray (M) Labels assigned to each of the orphans """ - # Create a mask for orphaned points, throw if there are only orphans - orphan_index = np.where(y == -1)[0] - num_orphans = len(orphan_index) - if (self.mode == 'knn' or not self.assign_all) and len(y) == num_orphans: - raise RuntimeError( - "Cannot assign orphans without any valid labels.") - - # Loop until all there is no more orphans to assign - y_updated = y.copy() - while num_orphans: - # Fit the classifier with the labeled points - valid_index = np.where(y_updated > -1)[0] - if not len(valid_index): - break + # Create a mask to identify labeled and orphaned points + orphan_mask = y == -1 + orphan_index = np.where(orphan_mask)[0] + valid_index = np.where(~orphan_mask)[0] - self.classifier.fit(X[valid_index], y_updated[valid_index]) + # Assign orphan points using the neighbor classifier + labels, orphan_update = self.classifier.fit_predict( + X[valid_index], y[valid_index], X[orphan_index]) - # Get the assignment for each of the orphaned points - update = self.classifier.predict(X[orphan_index]) - - # Update the labels accordingly - y_updated[orphan_index] = update - - # If iterating is not required, break (iterating on kNN does nothing) - if not self.iterate or self.mode == 'knn': - break - - # If the number of orphans has not changed, no point in proceeding - orphan_index = orphan_index[update < 0] - if len(orphan_index) == num_orphans: - break - - num_orphans = len(orphan_index) + y_updated = y.copy() + y_updated[orphan_index] = labels + orphan_index = orphan_index[orphan_update] # If required, assign stragglers using DBSCAN - if num_orphans and self.mode == 'radius' and self.assign_all: + if len(orphan_index) and self.mode == 'radius' and self.assign_all: # Get the assignment for each of the orphaned points - update = self.dbscan.fit(X[orphan_index]).labels_ + update = self.dbscan.fit_predict(X[orphan_index]) # Update the labels accordingly offset = 
np.max(y_updated) + 1 diff --git a/spine/utils/dbscan.py b/spine/utils/dbscan.py deleted file mode 100644 index 213fe8e09..000000000 --- a/spine/utils/dbscan.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Simple wrapper for sklearn's DBSCAN to turn its label output into -a list of clusters in the form of a point index list.""" - -import numpy as np -from typing import List -from sklearn.cluster import DBSCAN - - -def dbscan_points(coordinates, eps=1.999, min_samples=1, metric='euclidean'): - """Runs DBSCAN on an input point cloud. - - Returns the clusters as a list of indexes. - - Parameters - ---------- - coordinates : np.ndarray - Set of point coordinates - eps : float, default 1.999 - Distance parameter of DBSCAN - min_samples : int, default 1 - Minimum number of points in a cluster to be valid - metric : str, default 'euclidean' - Metric used to compute distances - - Returns - ------- - List[np.ndarray] - List of cluster indexes - """ - # Initialize DBSCAN - dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric=metric) - - # Build clusters - labels = dbscan.fit(coordinates).labels_ - clusters = [] - for c in np.unique(labels): - if c > -1: - clusters.append(np.where(labels == c)[0]) - - return clusters diff --git a/spine/utils/gnn/cluster.py b/spine/utils/gnn/cluster.py index 05f3bd54d..5b7098b13 100644 --- a/spine/utils/gnn/cluster.py +++ b/spine/utils/gnn/cluster.py @@ -9,13 +9,14 @@ import torch from typing import List +import spine.math as sm + from spine.data import TensorBatch, IndexBatch from spine.utils.decorators import numbafy from spine.utils.globals import ( BATCH_COL, COORD_COLS, VALUE_COL, CLUST_COL, PART_COL, GROUP_COL, MOM_COL, SHAPE_COL, COORD_START_COLS, COORD_END_COLS, COORD_TIME_COL) -import spine.utils.numba_local as nbl def form_clusters_batch(data, min_size=-1, column=CLUST_COL, shapes=None, @@ -298,51 +299,75 @@ def form_clusters(data, min_size=-1, column=CLUST_COL, shapes=None): ------- List[Union[np.ndarray, torch.Tensor]] (C) List of arrays 
of voxel indexes in each cluster - List[int] + np.ndarray (C) Number of pixels in the mask for each cluster """ - # Fetch the right functions depending on input type + # Dispatch to the right functions based on input type if isinstance(data, torch.Tensor): - zeros = lambda x: torch.zeros(x, dtype=torch.bool, device=data.device) - where, unique = torch.where, torch.unique + return _form_clusters_torch(data, min_size, column, shapes) else: - zeros = lambda x: np.zeros(x, dtype=bool) - where, unique = np.where, np.unique + return _form_clusters_np(data, min_size, column, shapes) + +def _form_clusters_torch(data, min_size, column, shapes): + # If requested, restrict data to a specific set of semantic classes + if shapes is not None: + shapes = torch.as_tensor(shapes, dtype=data.dtype, device=data.device) + shape_index = torch.where(torch.any(data[:, SHAPE_COL] == shapes[:, None], 0))[0] + data = data[shape_index] + + # Get the list of unique clusters in this entry, order indices + clust_ids = data[:, column] + uniques, counts = torch.unique(clust_ids, return_counts=True) + full_index = torch.argsort(clust_ids, stable=True) + if shapes is not None: + full_index = shape_index[full_index] + # Build valid index + valid_index = torch.where( + (counts >= min_size) & (uniques > -1))[0].detach().cpu().numpy() + + # Build index list, restrict to valid clusters + counts = counts.detach().cpu().numpy() + breaks = tuple(np.cumsum(counts)[:-1]) + clusts = torch.tensor_split(full_index, breaks) + + # Restrict to valid clusters + clusts = [clusts[i] for i in valid_index] + counts = counts[valid_index] + + return clusts, counts + +def _form_clusters_np(data, min_size, column, shapes): # If requested, restrict data to a specific set of semantic classes if shapes is not None: - mask = zeros(len(data)) - for s in shapes: - mask |= (data[:, SHAPE_COL] == s) - mask = where(mask)[0] - data = data[mask] + shapes = np.array(shapes, dtype=data.dtype) + shape_index = np.where(np.any(data[:, 
SHAPE_COL] == shapes[:, None], 0))[0] + data = data[shape_index] # Get the clusters in this entry clust_ids = data[:, column] - clusts, counts = [], [] - for c in unique(clust_ids): - # Skip if the cluster ID is invalid - if c < 0: - continue - clust = where(clust_ids == c)[0] + uniques, counts = np.unique(clust_ids, return_counts=True) + full_index = np.argsort(clust_ids, stable=True) + if shapes is not None: + full_index = shape_index[full_index] - # Skip if the cluster size is below threshold - if len(clust) < min_size: - continue + # Build valid index + valid_index = np.where((counts >= min_size) & (uniques > -1))[0] - # If a mask was applied, get the appropriate IDs - if shapes is not None: - clust = mask[clust] + # Build index list, restrict to valid clusters + breaks = tuple(np.cumsum(counts)[:-1]) + clusts = list(np.split(full_index, breaks)) - clusts.append(clust) - counts.append(len(clust)) + # Restrict to valid clusters + clusts = [clusts[i] for i in valid_index] + counts = counts[valid_index] return clusts, counts @numbafy(cast_args=['data'], list_args=['clusts'], keep_torch=True, ref_arg='data') -def break_clusters(data, clusts, eps, metric): +def break_clusters(data, clusts, eps, metric_id, p): """Runs DBSCAN on each invididual cluster to segment them further if needed. 
Parameters @@ -353,8 +378,10 @@ def break_clusters(data, clusts, eps, metric): (C) List of cluster indexes eps : float DBSCAN clustering distance scale - metric : str - DBSCAN clustering distance metric + metric_id : int + DBSCAN clustering distance metric enumerator + p : float + p-norm factor for the Minkowski metric, if used Returns ------- @@ -365,7 +392,7 @@ def break_clusters(data, clusts, eps, metric): return np.copy(data[:, CLUST_COL]) # Break labels - break_labels = _break_clusters(data, clusts, eps, metric) + break_labels = _break_clusters(data, clusts, eps, metric_id, p) # Offset individual broken labels to prevent overlap labels = np.copy(data[:, CLUST_COL]) @@ -379,10 +406,11 @@ def break_clusters(data, clusts, eps, metric): return labels @nb.njit(cache=True, parallel=True, nogil=True) -def _break_clusters(data: nb.float64[:,:], +def _break_clusters(data: nb.float32[:,:], clusts: nb.types.List(nb.int64[:]), eps: nb.float64, - metric: str) -> nb.float64[:]: + metric_id: nb.int64, + p: nb.float64) -> nb.int64[:]: # Loop over clusters to break, run DBSCAN break_labels = np.full(len(data), -1, dtype=data.dtype) points = data[:, COORD_COLS] @@ -392,7 +420,8 @@ def _break_clusters(data: nb.float64[:,:], points_c = points[clust] # Run DBSCAN on the cluster, update labels - clust_ids = nbl.dbscan(points_c, eps=eps, metric=metric) + clust_ids = sm.cluster.dbscan( + points_c, eps=eps, metric_id=metric_id, p=p) # Store the breaking IDs break_labels[clust] = clust_ids @@ -431,7 +460,7 @@ def _get_cluster_label(data: nb.float64[:,:], labels = np.empty(len(clusts), dtype=data.dtype) for i, c in enumerate(clusts): - v, cts = nbl.unique(data[c, column]) + v, cts = sm.unique(data[c, column]) labels[i] = v[np.argmax(cts)] return labels @@ -491,7 +520,7 @@ def _get_cluster_closest_label(data: nb.float64[:,:], # Minimize the point-cluster distances dists = np.empty(len(group_index), dtype=data.dtype) for i, c in enumerate(group_index): - dists[i] = 
np.min(nbl.cdist(start_point, voxels[clusts[c]])) + dists[i] = np.min(sm.distance.cdist(start_point, voxels[clusts[c]])) # Label the closest cluster as the original label only, assign default # values ot the other clusters in the group @@ -545,10 +574,10 @@ def _get_cluster_primary_label(data: nb.float64[:,:], primary_mask = np.where(part_ids == group_ids[i])[0] if len(primary_mask): # Only use the primary component to label the cluster - v, cts = nbl.unique(data[clusts[i][primary_mask], column]) + v, cts = sm.unique(data[clusts[i][primary_mask], column]) else: # If there is no primary contribution, use the whole cluster - v, cts = nbl.unique(data[clusts[i], column]) + v, cts = sm.unique(data[clusts[i], column]) labels[i] = v[np.argmax(cts)] return labels @@ -606,7 +635,7 @@ def _get_cluster_closest_primary_label(data: nb.float64[:,:], # Minimize the point-cluster distances dists = np.empty(len(group_index), dtype=data.dtype) for i, c in enumerate(group_index): - dists[i] = np.min(nbl.cdist(start_point, voxels[clusts[c]])) + dists[i] = np.min(sm.cdist(start_point, voxels[clusts[c]])) # Label the closest cluster as the only primary cluster labels[group_index] = 0 @@ -798,7 +827,7 @@ def _get_cluster_features_base(data: nb.float64[:,:], x = data[clust][:, COORD_COLS] # Get cluster center - center = nbl.mean(x, 0) + center = sm.mean(x, 0) # Get orientation matrix A = np.cov(x.T, ddof = len(x) - 1).astype(x.dtype) @@ -900,7 +929,7 @@ def _get_cluster_features_extended(data: nb.float64[:,:], # Get the cluster semantic class, if requested if add_shape: - types, cnts = nbl.unique(data[clust, SHAPE_COL]) + types, cnts = sm.unique(data[clust, SHAPE_COL]) major_sem_type = types[np.argmax(cnts)] feats[k, -1] = major_sem_type @@ -961,8 +990,8 @@ def _get_cluster_points_label(data: nb.float64[:,:], # Bring the start points to the closest point in the corresponding cluster for i, c in enumerate(clusts): - dist_mat = nbl.cdist(points[i].reshape(-1,3), data[c][:, COORD_COLS]) - 
argmins = nbl.argmin(dist_mat, axis=1) + dist_mat = sm.distance.cdist(points[i].reshape(-1,3), data[c][:, COORD_COLS]) + argmins = sm.argmin(dist_mat, axis=1) points[i] = data[c][argmins][:, COORD_COLS].reshape(-1) return points @@ -1057,20 +1086,20 @@ def cluster_direction(voxels: nb.float64[:,:], "The shape of the input is not compatible with voxel coordinates.") if max_dist > 0: - dist_mat = nbl.cdist(start.reshape(1,-1), voxels).flatten() + dist_mat = sm.distance.cdist(start.reshape(1,-1), voxels).flatten() voxels = voxels[dist_mat <= max(max_dist, np.min(dist_mat))] # If optimize is set, select the radius by minimizing the transverse spread if optimize and len(voxels) > 2: # Order the cluster points by increasing distance to the start point - dist_mat = nbl.cdist(start.reshape(1,-1), voxels).flatten() + dist_mat = sm.distance.cdist(start.reshape(1,-1), voxels).flatten() order = np.argsort(dist_mat) voxels = voxels[order] dist_mat = dist_mat[order] # Find the PCA relative secondary spread for each point labels = -np.ones(len(voxels), dtype=voxels.dtype) - meank = nbl.mean(voxels[:3], 0) + meank = sm.mean(voxels[:3], 0) covk = (np.transpose(voxels[:3] - meank) @ (voxels[:3] - meank))/3 for i in range(2, len(voxels)): # Get the eigenvalues, compute relative transverse spread @@ -1102,7 +1131,7 @@ def cluster_direction(voxels: nb.float64[:,:], for i in range(len(voxels)): rel_voxels[i] = voxels[i] - start - mean = nbl.mean(rel_voxels, 0) + mean = sm.mean(rel_voxels, 0) norm = np.sqrt(np.sum(mean**2)) if norm: return mean/norm @@ -1190,12 +1219,12 @@ def cluster_dedx(voxels: nb.float64[:,:], # If necessary, anchor start point to the closest cluster point if anchor: - dists = nbl.cdist(start.reshape(1, -1), voxels).flatten() + dists = sm.distance.cdist(start.reshape(1, -1), voxels).flatten() start = voxels[np.argmin(dists)].astype(start.dtype) # Dirty # If max_dist is set, limit the set of voxels to those within a sphere of # radius max_dist around the start point - 
dists = nbl.cdist(start.reshape(1, -1), voxels).flatten() + dists = sm.distance.cdist(start.reshape(1, -1), voxels).flatten() if max_dist > 0.: index = np.where(dists <= max_dist)[0] if len(index) < 2: @@ -1252,12 +1281,12 @@ def cluster_dedx_dir(voxels: nb.float64[:,:], # If necessary, anchor start point to the closest cluster point if anchor: - dists = nbl.cdist(start.reshape(1, -1), voxels).flatten() + dists = sm.distance.cdist(start.reshape(1, -1), voxels).flatten() start = voxels[np.argmin(dists)].astype(start.dtype) # Dirty # If max_dist is set, limit the set of voxels to those within a sphere of # radius max_dist around the start point - dists = nbl.cdist(start.reshape(1, -1), voxels).flatten() + dists = sm.distance.cdist(start.reshape(1, -1), voxels).flatten() if max_dist > 0.: index = np.where(dists <= max_dist)[0] if len(index) < 2: @@ -1347,7 +1376,7 @@ def cluster_end_points(voxels: nb.float64[:,:]) -> ( Index of the end voxel """ # Get the axis of maximum spread - axis = nbl.principal_components(voxels)[0] + axis = sm.decomposition.principal_components(voxels)[0] # Compute coord values along that axis coords = np.empty(len(voxels)) @@ -1390,7 +1419,7 @@ def umbrella_curv(voxels: nb.float64[:,:], # Find the mean direction from that point refvox = voxels[vox_id] diffs = voxels - refvox - axis = nbl.mean(voxels - refvox, axis=0) + axis = sm.mean(voxels - refvox, axis=0) axis /= np.linalg.norm(axis) # Compute the dot product of every displacement vector w.r.t. 
the axis diff --git a/spine/utils/gnn/evaluation.py b/spine/utils/gnn/evaluation.py index d76db8fd8..2ba42e37d 100644 --- a/spine/utils/gnn/evaluation.py +++ b/spine/utils/gnn/evaluation.py @@ -11,9 +11,10 @@ from scipy.sparse import csr_array from scipy.sparse.csgraph import minimum_spanning_tree +import spine.math as sm + from spine.data import TensorBatch, IndexBatch, EdgeIndexBatch -import spine.utils.numba_local as nbl from spine.utils.metrics import sbd, ami, ari, pur_eff int_array = nb.int64[:] @@ -370,7 +371,7 @@ def edge_assignment_forest(edge_index, edge_pred, group_ids): # Convert the sparse incidence matrix scores to a CSR matrix n = len(group_ids) - off_scores = nbl.softmax(edge_pred, axis=1)[:, 0] + off_scores = sm.softmax(edge_pred, axis=1)[:, 0] score_mat = csr_array((off_scores, edge_index.T), shape=(n,n)) # Build the MST graph to minimize off scores @@ -414,7 +415,7 @@ def node_assignment(edge_index: nb.int64[:,:], # Loop over on edges, reset the group IDs of connected node on_edges = edge_index[np.where(edge_pred[:, 1] > edge_pred[:, 0])[0]] - return nbl.union_find(on_edges, num_nodes, return_inverse=True)[0] + return sm.graph.union_find(on_edges, num_nodes, return_inverse=True)[0] @nb.njit(cache=True) @@ -476,10 +477,10 @@ def primary_assignment(node_pred: nb.float32[:,:], (C) Primary labels """ if group_ids is None: - return nbl.argmax(node_pred, axis=1).astype(np.bool_) + return sm.argmax(node_pred, axis=1).astype(np.bool_) primary_ids = np.zeros(len(node_pred), dtype=np.bool_) - node_pred = nbl.softmax(node_pred, axis=1) + node_pred = sm.softmax(node_pred, axis=1) for g in np.unique(group_ids): mask = np.where(group_ids == g)[0] idx = np.argmax(node_pred[mask][:,1]) @@ -538,7 +539,7 @@ def grouping_loss(pred_mat: nb.float32[:], Graph grouping loss """ if loss == 'ce': - return nbl.log_loss(target_mat, pred_mat) + return sm.log_loss(target_mat, pred_mat) elif loss == 'l1': return np.mean(np.absolute(pred_mat-target_mat)) elif loss == 'l2': @@ 
-587,7 +588,7 @@ def edge_assignment_score(edge_index: nb.int64[:,:], adj_mat = adjacency_matrix(edge_index, num_nodes) # Interpret the softmax score as a dense adjacency matrix probability - edge_scores = nbl.softmax(edge_pred, axis=1) + edge_scores = sm.softmax(edge_pred, axis=1) pred_adj = np.eye(num_nodes, dtype=edge_pred.dtype) for k, (i, j) in enumerate(edge_index): pred_adj[i, j] = edge_scores[k, 1] @@ -632,8 +633,8 @@ def edge_assignment_score(edge_index: nb.int64[:,:], # the two candidate groups node_mask = np.where( (best_groups == group_a) | (best_groups == group_b))[0] - sub_pred = nbl.submatrix(pred_adj, node_mask, node_mask).flatten() - sub_adj = nbl.submatrix(adj_mat, node_mask, node_mask).flatten() + sub_pred = sm.linalg.submatrix(pred_adj, node_mask, node_mask).flatten() + sub_adj = sm.linalg.submatrix(adj_mat, node_mask, node_mask).flatten() # Compute the current adjacency matrix between the two groups current_adj = (best_groups[node_mask] == diff --git a/spine/utils/gnn/network.py b/spine/utils/gnn/network.py index 9efa2bf50..3c4d6d092 100644 --- a/spine/utils/gnn/network.py +++ b/spine/utils/gnn/network.py @@ -3,15 +3,16 @@ import numpy as np import numba as nb +import spine.math as sm + from spine.data import TensorBatch from spine.utils.decorators import numbafy from spine.utils.globals import COORD_COLS -import spine.utils.numba_local as nbl def get_cluster_edge_features_batch(data, clusts, edge_index, - closest_index=True, algorithm='brute'): + closest_index=None, iterative=False): """Batched version of :func:`get_cluster_edge_features`. 
Parameters @@ -24,8 +25,8 @@ def get_cluster_edge_features_batch(data, clusts, edge_index, (2, E) Sparse incidence matrix closest_index : Union[np.ndarray, torch.Tensor], optional (C, C) : Combined index of the closest pair of voxels per edge - algorithm : str, default 'brute' - Method used to compute the inter-cluster distance + iterative : bool, default False + If `True`, uses an iterative, fast approximation for distance computations Returns ------- @@ -36,7 +37,7 @@ def get_cluster_edge_features_batch(data, clusts, edge_index, index = edge_index.index_t if directed else edge_index.directed_index_t counts = edge_index.counts if directed else edge_index.directed_counts feats = get_cluster_edge_features( - data.tensor, clusts.index_list, index, closest_index, algorithm) + data.tensor, clusts.index_list, index, closest_index, iterative) return TensorBatch(feats, counts) @@ -67,7 +68,7 @@ def get_voxel_edge_features_batch(data, edge_index, max_dist=5.0): @numbafy(cast_args=['data'], list_args=['clusts'], keep_torch=True, ref_arg='data') def get_cluster_edge_features(data, clusts, edge_index, - closest_index=None, algorithm='brute'): + closest_index=None, iterative=False): """Returns a tensor of edge features for each edge connecting point clusters in the graph. 
@@ -88,8 +89,8 @@ def get_cluster_edge_features(data, clusts, edge_index,
         (2, E) Incidence map between voxels
     closest_index : Union[np.ndarray, torch.Tensor], optional
         (C, C) : Combined index of the closest pair of voxels per edge
-    algorithm : str, default 'brute'
-        Method used to compute the inter-cluster distance
+    iterative : bool, default False
+        If `True`, uses an iterative, fast approximation for distance computations
 
     Returns
     -------
@@ -100,16 +101,16 @@ def get_cluster_edge_features(data, clusts, edge_index,
         return np.empty((0, 19), dtype=data.dtype) # Cannot type empty list
 
     return _get_cluster_edge_features(
-            data, clusts, edge_index, closest_index, algorithm)
+            data, clusts, edge_index, closest_index, iterative)
     # return _get_cluster_edge_features_vec(
-    #         data, clusts, edge_index, closest_index, algorithm)
+    #         data, clusts, edge_index, closest_index, iterative)
 
 
 @nb.njit(parallel=True, cache=True)
 def _get_cluster_edge_features(data: nb.float32[:,:],
                                clusts: nb.types.List(nb.int64[:]),
                                edge_index: nb.int64[:,:],
                                closest_index: nb.int64[:] = None,
-                               algorithm: str = 'brute') -> (
+                               iterative: nb.boolean = False) -> (
                                    nb.float32[:,:]):
 
     feats = np.empty((len(edge_index), 19), dtype=data.dtype)
@@ -124,7 +125,7 @@ def _get_cluster_edge_features(data: nb.float32[:,:],
             imin = closest_index[c1, c2]
             i1, i2 = imin//len(x2), imin%len(x2)
         else:
-            i1, j2, _ = nbl.closest_pair(x1, x2, algorithm)
+            i1, i2, _ = sm.distance.closest_pair(x1, x2, iterative)
 
         v1 = x1[i1,:]
         v2 = x2[i2,:]
@@ -148,13 +149,13 @@ def _get_cluster_edge_features_vec(data: nb.float32[:,:],
                                    clusts: nb.types.List(nb.int64[:]),
                                    edge_index: nb.int64[:,:],
                                    closest_index: nb.int64[:] = None,
-                                   algorithm: str = 'brute') -> (
+                                   iterative: nb.boolean = False) -> (
                                        nb.float32[:,:]):
 
     # Get the closest points of approach IDs for each edge
     if closest_index is None:
         lend, idxs1, idxs2 = _get_edge_distances(
-                data[:,COORD_COLS], clusts, edge_index, algorithm)
+                data[:,COORD_COLS], clusts, edge_index, iterative)
     else:
         idxs1, idxs2 = closest_index[(edge_index[0], edge_index[1])]
 
@@ -240,7 +241,7 @@ def _get_voxel_edge_features(data: nb.float32[:,:],
 
 
 @numbafy(cast_args=['voxels'], list_args=['clusts'])
-def get_edge_distances(voxels, clusts, edge_index, algorithm='brute'):
+def get_edge_distances(voxels, clusts, edge_index, iterative=False):
     """For each edge, finds the closest points of approach (CPAs) between
     the two voxel clusters it connects, and the distance that separates them.
 
@@ -256,6 +257,8 @@ def get_edge_distances(voxels, clusts, edge_index, algorithm='brute'):
         (C) List of arrays of voxel IDs in each cluster
     edge_index : Union[np.ndarray, torch.Tensor]
         (2, E) Incidence matrix
+    iterative : bool, default False
+        If `True`, uses an iterative, fast approximation for distance computations
 
     Returns
     -------
@@ -266,13 +269,13 @@ def get_edge_distances(voxels, clusts, edge_index, algorithm='brute'):
     np.ndarray
         (E) List of voxel IDs corresponding to the second edge cluster CPA
     """
-    return _get_edge_distances(voxels, clusts, edge_index, algorithm)
+    return _get_edge_distances(voxels, clusts, edge_index, iterative)
 
 
 @nb.njit(parallel=True, cache=True)
 def _get_edge_distances(voxels: nb.float32[:,:],
                         clusts: nb.types.List(nb.int64[:]),
                         edge_index: nb.int64[:,:],
-                        algorithm: str = 'brute') -> (
+                        iterative: nb.boolean = False) -> (
                             nb.float32[:], nb.int64[:], nb.int64[:]):
 
     # Loop over the provided edges
@@ -286,8 +289,8 @@ def _get_edge_distances(voxels: nb.float32[:,:],
             ii = jj = 0
             dist = 0.
else: - ii, jj, dist = nbl.closest_pair( - voxels[clusts[i]], voxels[clusts[j]], algorithm) + ii, jj, dist = sm.distance.closest_pair( + voxels[clusts[i]], voxels[clusts[j]], iterative) lend[k] = dist resi[k] = clusts[i][ii] @@ -297,8 +300,8 @@ def _get_edge_distances(voxels: nb.float32[:,:], @numbafy(cast_args=['voxels'], list_args=['clusts']) -def inter_cluster_distance(voxels, clusts, counts=None, method='voxel', - algorithm='brute', return_index=False): +def inter_cluster_distance(voxels, clusts, counts=None, centroid=False, + iterative=False, return_index=False): """Finds the inter-cluster distance between every pair of clusters within each batch, returned as a block-diagonal matrix. @@ -310,15 +313,13 @@ def inter_cluster_distance(voxels, clusts, counts=None, method='voxel', (C) List of cluster indexes counts : np.ndarray, optional (B) Number of clusters in each entry of the batch - method : str, default 'voxel' - Either the closest voxel distance ('voxel') of the cluster centroid - distance ('centroid') - algorithm : str, default 'brute' - Algorithm used to compute the 'voxel' distance. The 'brute' method - is exact but slow, 'recursive' uses a fast but approximate method. 
+    centroid : bool, default False
+        If `True`, use the centroid distance as a fast, approximate proxy
+    iterative : bool, default False
+        If `True`, uses an iterative, fast approximation to compute voxel distance
     return_index : bool, default True
         Returns a combined index of the closest pair of voxels for each
-        cluster, if the 'voxel' distance method is used
+        cluster, if the 'centroid' distance method is not used
 
     Returns
     -------
@@ -337,49 +338,47 @@ def inter_cluster_distance(voxels, clusts, counts=None, method='voxel',
             return np.empty((0, 0), dtype=voxels.dtype)
 
         return _inter_cluster_distance(
-                voxels, clusts, counts, method, algorithm)
+                voxels, clusts, counts, centroid, iterative)
 
     else:
         # If there are no clusters, return empty
-        assert method == 'voxel', "Cannot return index for centroid method."
+        assert not centroid, "Cannot return index for centroid method."
         if len(clusts) == 0:
             return (np.empty((0, 0), dtype=voxels.dtype),
                     np.empty((0, 0), dtype=np.int64))
 
         return _inter_cluster_distance_index(
-                voxels, clusts, counts, algorithm)
+                voxels, clusts, counts, iterative)
 
 
 @nb.njit(parallel=True, cache=True)
 def _inter_cluster_distance(voxels: nb.float32[:,:],
                             clusts: nb.types.List(nb.int64[:]),
                             counts: nb.int64[:],
-                            method: str = 'voxel',
-                            algorithm: str = 'brute') -> nb.float32[:,:]:
+                            centroid: nb.boolean = False,
+                            iterative: nb.boolean = False) -> nb.float32[:,:]:
 
     # Loop over the upper diagonal elements of each block on the diagonal
     dist_mat = np.zeros((len(clusts), len(clusts)), dtype=voxels.dtype)
     indxi, indxj = complete_graph(counts)
-    if method == 'voxel':
+    if not centroid:
         for k in nb.prange(len(indxi)):
             # Identifiy the two voxels closest to each other in each cluster
             i, j = indxi[k], indxj[k]
-            dist_mat[i, j] = dist_mat[j, i] = nbl.closest_pair(
-                    voxels[clusts[i]], voxels[clusts[j]], algorithm)[-1]
+            dist_mat[i, j] = dist_mat[j, i] = sm.distance.closest_pair(
+                    voxels[clusts[i]], voxels[clusts[j]], iterative)[-1]
 
-    elif method == 'centroid':
+    else:
# Compute the centroid of each cluster dtype = voxels.dtype centroids = np.empty((len(clusts), voxels.shape[1]), dtype=dtype) for i in nb.prange(len(clusts)): - centroids[i] = nbl.mean(voxels[clusts[i]], axis=0) + centroids[i] = sm.mean(voxels[clusts[i]], axis=0) # Measure the distance between cluster centroids for k in nb.prange(len(indxi)): i, j = indxi[k], indxj[k] dist_mat[i,j] = dist_mat[j,i] = np.sqrt( np.sum((centroids[j]-centroids[i])**2)) - else: - raise ValueError("Inter-cluster distance method not supported.") return dist_mat @@ -387,7 +386,7 @@ def _inter_cluster_distance(voxels: nb.float32[:,:], def _inter_cluster_distance_index(voxels: nb.float32[:,:], clusts: nb.types.List(nb.int64[:]), counts: nb.int64[:], - algorithm: str = 'brute') -> ( + iterative: nb.boolean = False) -> ( nb.float32[:,:], nb.int64[:,:]): # Loop over the upper diagonal elements of each block on the diagonal @@ -397,8 +396,8 @@ def _inter_cluster_distance_index(voxels: nb.float32[:,:], for k in nb.prange(len(indxi)): # Identify the two voxels closest to each other in each cluster i, j = indxi[k], indxj[k] - ii, jj, dist = nbl.closest_pair( - voxels[clusts[i]], voxels[clusts[j]], algorithm) + ii, jj, dist = sm.distance.closest_pair( + voxels[clusts[i]], voxels[clusts[j]], iterative) index = ii*len(clusts[j]) + jj # Store the index and the distance in a matrix diff --git a/spine/utils/gnn/voxels.py b/spine/utils/gnn/voxels.py index 35b35eff5..da3936909 100644 --- a/spine/utils/gnn/voxels.py +++ b/spine/utils/gnn/voxels.py @@ -4,11 +4,12 @@ import numpy as np import numba as nb +import spine.math as sm + from spine.data import TensorBatch from spine.utils.globals import COORD_COLS from spine.utils.decorators import numbafy -import spine.utils.numba_local as nbl def get_voxel_features_batch(data, max_dist=5.0): @@ -72,7 +73,7 @@ def _get_voxel_features(data: nb.float32[:,:], max_dist=5.0): # Compute intervoxel distance matrix voxels = data[:, COORD_COLS] - dist_mat = nbl.cdist(voxels, 
voxels) + dist_mat = sm.distance.cdist(voxels, voxels) # Get local geometrical features for each voxel feats = np.empty((len(voxels), 16), dtype=data.dtype) diff --git a/spine/utils/match.py b/spine/utils/match.py index 8ba744788..2a86a2881 100644 --- a/spine/utils/match.py +++ b/spine/utils/match.py @@ -3,7 +3,7 @@ import numpy as np import numba as nb -from .numba_local import cdist +from spine.math.distance import cdist __all__ = ['overlap_counts', 'overlap_iou', 'overlap_weighted_iou', 'overlap_dice', 'overlap_weighted_dice', 'overlap_chamfer'] diff --git a/spine/utils/metrics.py b/spine/utils/metrics.py index 436ab061d..9df73343b 100644 --- a/spine/utils/metrics.py +++ b/spine/utils/metrics.py @@ -5,7 +5,7 @@ from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score -from .numba_local import contingency_table +from spine.math.linalg import contingency_table __all__ = ['pur', 'eff', 'pur_eff', 'ari', 'ami', 'sbd'] diff --git a/spine/utils/numba_local.py b/spine/utils/numba_local.py deleted file mode 100644 index 70b0dd533..000000000 --- a/spine/utils/numba_local.py +++ /dev/null @@ -1,767 +0,0 @@ -"""Extensions to the basic Numba package.""" - -import numpy as np -import numba as nb - - -@nb.njit -def seed(seed: int) -> None: - """Sets the numpy random seed for all Numba jitted functions. - - Note that setting the seed using `np.random.seed` outside a Numba jitted - function does *not* set the seed of Numba functions. - - Parameters - ---------- - seed : int - Random number generator seed - """ - np.random.seed(seed) - - -@nb.njit(cache=True) -def submatrix(x: nb.float32[:,:], - index1: nb.int32[:], - index2: nb.int32[:]) -> nb.float32[:,:]: - """Numba implementation of matrix subsampling. 
- - Parameters - ---------- - x : np.ndarray - (N,M) array of values - index1 : np.ndarray - (N') array of indices along axis 0 in the input matrix - index2 : np.ndarray - (M') array of indices along axis 1 in the input matrix - - Returns - ------- - np.ndarray - (N',M') array of values from the original matrix - """ - subx = np.empty((len(index1), len(index2)), dtype=x.dtype) - for i, i1 in enumerate(index1): - for j, i2 in enumerate(index2): - subx[i,j] = x[i1,i2] - return subx - - -@nb.njit(cache=True) -def unique(x: nb.int32[:]) -> (nb.int32[:], nb.int64[:]): - """Numba implementation of `np.unique(x, return_counts=True)`. - - Parameters - ---------- - x : np.ndarray - (N) array of values - - Returns - ------- - np.ndarray - (U) array of unique values - np.ndarray - (U) array of counts of each unique value in the original array - """ - b = np.sort(x.flatten()) - unique = list(b[:1]) - counts = [1 for _ in unique] - for v in b[1:]: - if v != unique[-1]: - unique.append(v) - counts.append(1) - else: - counts[-1] += 1 - - unique_np = np.empty(len(unique), dtype=x.dtype) - counts_np = np.empty(len(counts), dtype=np.int32) - for i in range(len(unique)): - unique_np[i] = unique[i] - counts_np[i] = counts[i] - - return unique_np, counts_np - - -@nb.njit(cache=True) -def mean(x: nb.float32[:,:], - axis: nb.int32) -> nb.float32[:]: - """Numba implementation of `np.mean(x, axis)`. - - Parameters - ---------- - x : np.ndarray - (N,M) array of values - axis : int - Array axis ID - - Returns - ------- - np.ndarray - (N) or (M) array of `mean` values - """ - assert axis == 0 or axis == 1 - mean = np.empty(x.shape[1-axis], dtype=x.dtype) - if axis == 0: - for i in range(len(mean)): - mean[i] = np.mean(x[:,i]) - else: - for i in range(len(mean)): - mean[i] = np.mean(x[i]) - return mean - - -@nb.njit(cache=True) -def norm(x: nb.float32[:,:], - axis: nb.int32) -> nb.float32[:]: - """Numba implementation of `np.linalg.norm(x, axis)`. 
- - Parameters - ---------- - x : np.ndarray - (N,M) array of values - axis : int - Array axis ID - - Returns - ------- - np.ndarray - (N) or (M) array of `norm` values - """ - assert axis == 0 or axis == 1 - xnorm = np.empty(x.shape[1-axis], dtype=np.int32) - if axis == 0: - for i in range(len(xnorm)): - xnorm[i] = np.linalg.norm(x[:,i]) - else: - for i in range(len(xnorm)): - xnorm[i] = np.linalg.norm(x[i]) - return xnorm - - -@nb.njit(cache=True) -def argmin(x: nb.float32[:,:], - axis: nb.int32) -> nb.int32[:]: - """Numba implementation of `np.argmin(x, axis)`. - - Parameters - ---------- - x : np.ndarray - (N,M) array of values - axis : int - Array axis ID - - Returns - ------- - np.ndarray - (N) or (M) array of `argmin` values - """ - assert axis == 0 or axis == 1 - argmin = np.empty(x.shape[1-axis], dtype=np.int32) - if axis == 0: - for i in range(len(argmin)): - argmin[i] = np.argmin(x[:,i]) - else: - for i in range(len(argmin)): - argmin[i] = np.argmin(x[i]) - return argmin - - -@nb.njit(cache=True) -def argmax(x: nb.float32[:,:], - axis: nb.int32) -> nb.int32[:]: - """Numba implementation of `np.argmax(x, axis)`. - - Parameters - ---------- - x : np.ndarray - (N,M) array of values - axis : int - Array axis ID - - Returns - ------- - np.ndarray - (N) or (M) array of `argmax` values - """ - assert axis == 0 or axis == 1 - argmax = np.empty(x.shape[1-axis], dtype=np.int32) - if axis == 0: - for i in range(len(argmax)): - argmax[i] = np.argmax(x[:,i]) - - else: - for i in range(len(argmax)): - argmax[i] = np.argmax(x[i]) - - return argmax - - -@nb.njit(cache=True) -def amin(x: nb.float32[:,:], - axis: nb.int32) -> nb.float32[:]: - """Numba implementation of `np.amin(x, axis)`. 
- - Parameters - ---------- - x : np.ndarray - (N,M) array of values - axis : int - Array axis ID - - Returns - ------- - np.ndarray - (N) or (M) array of `min` values - """ - assert axis == 0 or axis == 1 - xmin = np.empty(x.shape[1-axis], dtype=np.int32) - if axis == 0: - for i in range(len(xmin)): - xmin[i] = np.min(x[:, i]) - - else: - for i in range(len(xmin)): - xmin[i] = np.min(x[i]) - - return xmin - - -@nb.njit(cache=True) -def amax(x: nb.float32[:,:], - axis: nb.int32) -> nb.float32[:]: - """Numba implementation of `np.amax(x, axis)`. - - Parameters - ---------- - x : np.ndarray - (N,M) array of values - axis : int - Array axis ID - - Returns - ------- - np.ndarray - (N) or (M) array of `max` values - """ - assert axis == 0 or axis == 1 - xmax = np.empty(x.shape[1-axis], dtype=np.int32) - if axis == 0: - for i in range(len(xmax)): - xmax[i] = np.max(x[:, i]) - - else: - for i in range(len(xmax)): - xmax[i] = np.max(x[i]) - - return xmax - - -@nb.njit(cache=True) -def all(x: nb.float32[:,:], - axis: nb.int32) -> nb.boolean[:]: - """Numba implementation of `np.all(x, axis)`. - - Parameters - ---------- - x : np.ndarray - (N, M) Array of values - axis : int - Array axis ID - - Returns - ------- - np.ndarray - (N) or (M) array of `all` outputs - """ - assert axis == 0 or axis == 1 - all = np.empty(x.shape[1-axis], dtype=np.bool_) - if axis == 0: - for i in range(len(all)): - all[i] = np.all(x[:,i]) - - else: - for i in range(len(all)): - all[i] = np.all(x[i]) - - return all - - -@nb.njit(cache=True) -def contingency_table(x: nb.int32[:], - y: nb.int32[:], - nx: nb.int32=None, - ny: nb.int32=None) -> nb.int64[:, :]: - """Build a contingency table for two sets of labels. 
- - Parameters - ---------- - x : np.ndarray - (N) Array of integrer values - y : np.ndarray - (M) Array of integrer values - nx : int, optional - Number of integer values allowed in `x`, N - ny : int, optional - Number of integer values allowd in `y`, M - - Returns - ------- - np.ndarray - (N, M) Contingency table - """ - # If not provided, assume that the max label is the max of the label array - if not nx: - nx = np.max(x) + 1 - if not ny: - ny = np.max(y) + 1 - - # Bin the table - table = np.zeros((nx, ny), dtype=np.int64) - for i, j in zip(x, y): - table[i, j] += 1 - - return table - - -@nb.njit(cache=True) -def softmax(x: nb.float32[:,:], - axis: nb.int32) -> nb.float32[:,:]: - """ - Numba implementation of `scipy.special.softmax(x, axis)`. - - Parameters - ---------- - x : np.ndarray - (N,M) array of values - axis : int - Array axis ID - - Returns - ------- - np.ndarray - (N,M) Array of softmax scores - """ - assert axis == 0 or axis == 1 - if axis == 0: - xmax = amax(x, axis=0) - logsumexp = np.log(np.sum(np.exp(x-xmax), axis=0)) + xmax - return np.exp(x - logsumexp) - else: - xmax = amax(x, axis=1).reshape(-1,1) - logsumexp = np.log(np.sum(np.exp(x-xmax), axis=1)).reshape(-1,1) + xmax - return np.exp(x - logsumexp) - - -@nb.njit(cache=True) -def log_loss(label: nb.boolean[:], - pred: nb.float32[:]) -> nb.float32: - """Numba implementation of cross-entropy loss. - - Parameters - ---------- - label : np.ndarray - (N) array of boolean labels (0 or 1) - pred : np.ndarray - (N) array of float scores (between 0 and 1) - - Returns - ------- - float - Cross-entropy loss - """ - if len(label) > 0: - return -(np.sum(np.log(pred[label])) + np.sum(np.log(1.-pred[~label])))/len(label) - else: - return 0. - - -@nb.njit(cache=True) -def pdist(x: nb.float32[:,:], - metric: str = 'euclidean') -> nb.float32[:,:]: - """Numba implementation of - `scipy.spatial.distance.pdist(x, p=2)` in 3D. 
- - Parameters - ---------- - x : np.ndarray - (N, 3) array of point coordinates in the set - metric : str, default 'euclidean' - Distance metric - - Returns - ------- - np.ndarray - (N, N) array of pair-wise Euclidean distances - """ - # Initialize the return matrix - assert x.shape[1] == 3, "Only supports 3D points for now." - res = np.empty((len(x), len(x)), dtype=x.dtype) - - if metric == 'euclidean': - for i in range(x.shape[0]): - res[i, i] = 0. - for j in range(i+1, x.shape[0]): - res[i, j] = res[j, i] = np.sqrt( - (x[i][0] - x[j][0])**2 + - (x[i][1] - x[j][1])**2 + - (x[i][2] - x[j][2])**2) - - elif metric == 'cityblock': - for i in range(x.shape[0]): - res[i, i] = 0. - for j in range(i+1, x.shape[0]): - res[i, j] = res[j, i] = ( - abs(x[i][0] - x[j][0]) + - abs(x[i][1] - x[j][1]) + - abs(x[i][2] - x[j][2])) - - elif metric == 'chebyshev': - for i in range(x.shape[0]): - res[i, i] = 0. - for j in range(i+1, x.shape[0]): - res[i, j] = res[j, i] = max( - abs(x[i][0] - x[j][0]), - abs(x[i][1] - x[j][1]), - abs(x[i][2] - x[j][2])) - - else: - raise ValueError("Distance metric not recognized.") - - return res - - -@nb.njit(cache=True) -def cdist(x1: nb.float32[:,:], - x2: nb.float32[:,:]) -> nb.float32[:,:]: - """Numba implementation of Euclidean - `scipy.spatial.distance.cdist(x, p=2)` in 1D, 2D or 3D. 
- - Parameters - ---------- - x1 : np.ndarray - (N,d) array of point coordinates in the first set - x2 : np.ndarray - (M,d) array of point coordinates in the second set - - Returns - ------- - np.ndarray - (N,M) array of pair-wise Euclidean distances - """ - dim = x1.shape[1] - assert dim and dim < 4, 'Only supports point dimensions up to 3' - res = np.empty((x1.shape[0], x2.shape[0]), dtype=x1.dtype) - if dim == 1: - for i1 in range(x1.shape[0]): - for i2 in range(x2.shape[0]): - res[i1,i2] = abs(x1[i1][0] - x2[i2][0]) - - elif dim == 2: - for i1 in range(x1.shape[0]): - for i2 in range(x2.shape[0]): - res[i1,i2] = np.sqrt( - (x1[i1][0] - x2[i2][0])**2 + - (x1[i1][1] - x2[i2][1])**2) - - elif dim == 3: - for i1 in range(x1.shape[0]): - for i2 in range(x2.shape[0]): - res[i1,i2] = np.sqrt( - (x1[i1][0]-x2[i2][0])**2 + - (x1[i1][1]-x2[i2][1])**2 + - (x1[i1][2]-x2[i2][2])**2) - - return res - - -@nb.njit(cache=True) -def radius_graph(x: nb.float32[:,:], - radius: nb.float32, - metric: str = 'euclidean') -> nb.float32[:,:]: - """Numba implementation of a radius-graph construction. - - This function generates a list of edges in a graph which connects all nodes - within some radius R of each other. - - Parameters - ---------- - x : np.ndarray - (N, 3) array of node coordinates - radius : float - Radius within which to build connections in the graph - metric : str, default 'euclidean' - Distance metric - - Returns - ------- - np.ndarray - (E, 2) array of edges in the radius graph - """ - # Initialize an empty list of edges to add to the graph - assert x.shape[1] == 3, "Only supports 3D points for now." 
- edges = nb.typed.List.empty_list(nb.int64[:]) - - if metric == 'euclidean': - for i in range(x.shape[0]): - for j in range(i+1, x.shape[0]): - dist = np.sqrt( - (x[i][0] - x[j][0])**2 + - (x[i][1] - x[j][1])**2 + - (x[i][2] - x[j][2])**2) - if dist <= radius: - edges.append(np.array([i, j])) - - elif metric == 'cityblock': - for i in range(x.shape[0]): - for j in range(i+1, x.shape[0]): - dist = ( - abs(x[i][0] - x[j][0]) + - abs(x[i][1] - x[j][1]) + - abs(x[i][2] - x[j][2])) - if dist <= radius: - edges.append(np.array([i, j])) - - elif metric == 'chebyshev': - for i in range(x.shape[0]): - for j in range(i+1, x.shape[0]): - dist = max( - abs(x[i][0] - x[j][0]), - abs(x[i][1] - x[j][1]), - abs(x[i][2] - x[j][2])) - if dist <= radius: - edges.append(np.array([i, j])) - - else: - raise ValueError("Distance metric not recognized.") - - edge_index = np.empty((len(edges), 2), dtype=np.int64) - for i, e in enumerate(edges): - edge_index[i] = e - - return edge_index - - -@nb.njit(cache=True) -def union_find(edge_index: nb.int64[:,:], - count: nb.int64, - return_inverse: bool = True) -> nb.int64[:]: - """Numba implementation of the Union-Find algorithm. - - This function assigns a group to each node in a graph, provided - a set of edges connecting the nodes together. 
- - Parameters - ---------- - edge_index : np.ndarray - (E, 2) List of edges (sparse adjacency matrix) - count : int - Number of nodes in the graph, C - return_inverse : bool, default True - Make sure the group IDs range from 0 to N_groups-1 - - Returns - ------- - np.ndarray - (C) Group assignments for each of the nodes in the graph - Dict[int, np.ndarray] - Dictionary which maps groups to indexes - """ - labels = np.arange(count) - groups = {i: np.array([i]) for i in labels} - for e in edge_index: - li, lj = labels[e[0]], labels[e[1]] - if li != lj: - labels[groups[lj]] = li - groups[li] = np.concatenate((groups[li], groups[lj])) - del groups[lj] - - if return_inverse: - mask = np.zeros(count, dtype=np.bool_) - mask[labels] = True - mapping = np.empty(count, dtype=labels.dtype) - mapping[mask] = np.arange(np.sum(mask)) - labels = mapping[labels] - - return labels, groups - - -@nb.njit(cache=True) -def dbscan(x: nb.float32[:, :], - eps: nb.float32, - metric: str = 'euclidean') -> nb.float32[:]: - """Runs DBSCAN on 3D points and returns the group assignments. - - Notes - ----- - The traditional 'min_samples' is always set to 1 here. - - Parameters - ---------- - x : np.ndarray - (N, 3) array of point coordinates - eps : float - Distance below which two points are considered neighbors - metric : str, default 'euclidean' - Distance metric used to compute pdist - - Returns - ------- - np.ndarray - (N) Group assignments - """ - # Produce a sparse adjacency matrix (edge index) - edge_index = radius_graph(x, eps, metric) - - # Build groups - return union_find(edge_index, len(x), return_inverse=True)[0] - - -@nb.njit(cache=True) -def principal_components(x: nb.float32[:,:]) -> nb.float32[:,:]: - """Computes the principal components of a point cloud by computing the - eigenvectors of the centered covariance matrix. 
- - Parameters - ---------- - x : np.ndarray - (N, d) Coordinates in d dimensions - - Returns - ------- - np.ndarray - (d, d) List of principal components (row-ordered) - """ - # Get covariance matrix - A = np.cov(x.T, ddof = len(x) - 1).astype(x.dtype) # Casting needed... - - # Get eigenvectors - _, v = np.linalg.eigh(A) - v = np.ascontiguousarray(np.fliplr(v).T) - - return v - - -@nb.njit(cache=True) -def farthest_pair(x: nb.float32[:,:], - algorithm: str = 'brute') -> (nb.int32, nb.int32, nb.float32): - """Algorithm which finds the two points which are farthest from each other - in a set. - - Two algorithms: - - `brute`: compute pdist, use argmax - - `recursive`: Start with the first point in one set, find the farthest - point in the other, move to that point, repeat. This - algorithm is *not* exact, but a good and very quick proxy. - - Parameters - ---------- - x : np.ndarray - (N, 3) array of point coordinates - algorithm : str - Name of the algorithm to use: `brute` or `recursive` - - Returns - ------- - int - ID of the first point that makes up the pair - int - ID of the second point that makes up the pair - float - Distance between the two points - """ - if algorithm == 'brute': - dist_mat = pdist(x) - index = np.argmax(dist_mat) - idxs = [index//x.shape[0], index%x.shape[0]] - dist = dist_mat[idxs[0], idxs[1]] - - elif algorithm == 'recursive': - centroid = mean(x, 0) - start_idx = np.argmax(cdist(centroid.reshape(1, -1), x)) - idxs, subidx, dist, tempdist = [start_idx, start_idx], 0, 0., -1. 
- while dist > tempdist: - tempdist = dist - dists = cdist(np.ascontiguousarray(x[idxs[subidx]]).reshape(1,-1), x).flatten() - idxs[~subidx] = np.argmax(dists) - dist = dists[idxs[~subidx]] - subidx = ~subidx - - else: - raise ValueError("Algorithm not supported") - - return idxs[0], idxs[1], dist - - -@nb.njit(cache=True) -def closest_pair(x1: nb.float32[:,:], - x2: nb.float32[:,:], - algorithm: bool = 'brute', - seed: bool = True) -> (nb.int32, nb.int32, nb.float32): - """Algorithm which finds the two points which are closest to each other - from two separate sets. - - Two algorithms: - - `brute`: compute cdist, use argmin - - `recursive`: Start with one point in one set, find the closest - point in the other set, move to theat point, repeat. This - algorithm is *not* exact, but a good and very quick proxy. - - Parameters - ---------- - x1 : np.ndarray - (Nx3) array of point coordinates in the first set - x1 : np.ndarray - (Nx3) array of point coordinates in the second set - algorithm : str - Name of the algorithm to use: `brute` or `recursive` - seed : bool - Whether or not to use the two farthest points in one set to seed the recursion - - Returns - ------- - int - ID of the first point that makes up the pair - int - ID of the second point that makes up the pair - float - Distance between the two points - """ - # Find the two points in two sets of points that are closest to each other - if algorithm == 'brute': - # Compute every pair-wise distances between the two sets - dist_mat = cdist(x1, x2) - - # Select the closest pair of point - index = np.argmin(dist_mat) - idxs = [index//dist_mat.shape[1], index%dist_mat.shape[1]] - dist = dist_mat[idxs[0], idxs[1]] - - elif algorithm == 'recursive': - # Pick the point to start iterating from - xarr = [x1, x2] - idxs, set_id, dist, tempdist = [0, 0], 0, 1e9, 1e9+1. 
- if seed: - # Find the end points of the two sets - for i, x in enumerate(xarr): - seed_idxs = np.array(farthest_pair(xarr[i], 'recursive')[:2]) - seed_dists = cdist(xarr[i][seed_idxs], xarr[~i]) - seed_argmins = argmin(seed_dists, axis=1) - seed_mins = np.array([seed_dists[0][seed_argmins[0]], - seed_dists[1][seed_argmins[1]]]) - if np.min(seed_mins) < dist: - set_id = ~i - seed_choice = np.argmin(seed_mins) - idxs[int(~set_id)] = seed_idxs[seed_choice] - idxs[int(set_id)] = seed_argmins[seed_choice] - dist = seed_mins[seed_choice] - - # Find the closest point in the other set, repeat until convergence - while dist < tempdist: - tempdist = dist - dists = cdist(np.ascontiguousarray(xarr[set_id][idxs[set_id]]).reshape(1,-1), xarr[~set_id]).flatten() - idxs[~set_id] = np.argmin(dists) - dist = dists[idxs[~set_id]] - subidx = ~set_id - else: - raise ValueError("Algorithm not supported") - - return idxs[0], idxs[1], dist diff --git a/spine/utils/ppn.py b/spine/utils/ppn.py index 697fa2ecf..8704e2892 100644 --- a/spine/utils/ppn.py +++ b/spine/utils/ppn.py @@ -8,17 +8,13 @@ import numba as nb import torch from warnings import warn -from typing import Union, List -from scipy.special import softmax as softmax_sp -from scipy.spatial.distance import cdist as cdist_sp +import spine.math as sm from spine.data import TensorBatch -from . 
import numba_local as nbl from .decorators import numbafy -from .dbscan import dbscan_points -from .torch_local import local_cdist +from .torch_local import cdist_fast from .globals import ( BATCH_COL, COORD_COLS, PPN_ROFF_COLS, PPN_RTYPE_COLS, PPN_RPOS_COLS, PPN_SCORE_COLS, PPN_OCC_COL, PPN_CLASS_COLS, PPN_SHAPE_COL, @@ -194,7 +190,7 @@ def process_single(self, ppn_raw, ppn_coords, ppn_mask, ppn_ends=None, dtype, device = ppn_raw.dtype, ppn_raw.device cat, unique, argmax = torch.cat, torch.unique, torch.argmax where, mean, softmax = torch.where, torch.mean, torch.softmax - cdist = local_cdist + cdist = cdist_fast empty = lambda x: torch.empty(x, dtype=dtype, device=device) zeros = lambda x: torch.zeros(x, dtype=dtype, device=device) pool_fn = getattr(torch, self.pool_score_fn) @@ -203,8 +199,8 @@ def process_single(self, ppn_raw, ppn_coords, ppn_mask, ppn_ends=None, else: cat, unique, argmax = np.concatenate, np.unique, np.argmax - where, mean, softmax = np.where, np.mean, softmax_sp - cdist = cdist_sp + where, mean = np.where, np.mean + softmax, cdist = sm.softmax, sm.distance.cdist empty = lambda x: np.empty(x, dtype=ppn_raw.dtype) zeros = lambda x: np.zeros(x, dtype=ppn_raw.dtype) pool_fn = getattr(np, self.pool_score_fn) @@ -283,11 +279,9 @@ def process_single(self, ppn_raw, ppn_coords, ppn_mask, ppn_ends=None, # Cluster nearby points together if torch.is_tensor(coords): - clusts = dbscan_points( - coords.detach().cpu().numpy(), eps=self.pool_dist, - min_samples=1) + clusts = self.dbscan_points(coords.detach().cpu().numpy()) else: - clusts = dbscan_points(coords, eps=self.pool_dist, min_samples=1) + clusts = self.dbscan_points(coords) ppn_pred = empty((len(clusts), 13 + 2*(ppn_ends is not None))) for i, c in enumerate(clusts): @@ -304,6 +298,30 @@ def process_single(self, ppn_raw, ppn_coords, ppn_mask, ppn_ends=None, return ppn_pred + def dbscan_points(self, coordinates): + """Form clusters of predited points based on proximity. 
+
+        Parameters
+        ----------
+        coordinates : np.ndarray
+            Coordinates of the points to cluster
+
+        Returns
+        -------
+        List[np.ndarray]
+            List of proposed point cluster indexes
+        """
+        # Assign cluster labels to all proposed points
+        labels = sm.cluster.dbscan(
+                coordinates, eps=self.pool_dist, min_samples=1)
+
+        # Convert the list of labels into a list of cluster indexes
+        clusts = []
+        for c in np.unique(labels):
+            clusts.append(np.where(labels == c)[0])
+
+        return clusts
+
 
 class ParticlePointPredictor:
     """Produces start/end points given a list of particles and PPN predictions.
@@ -342,9 +360,7 @@ def __init__(self, use_numpy=True, contained_first=True, anchor_points=True,
         self.contained_first = contained_first
         self.anchor_points = anchor_points
         self.enhance_track_points = enhance_track_points
-        self.farthest_pair_algo = 'brute'
-        if approx_farthest_points:
-            self.farthest_pair_algo = 'recursive'
+        self.approx_farthest_points = approx_farthest_points
 
     def __call__(self, data, clusts, clust_shapes, ppn_points):
         """Assign start/end points to one batch of events.
@@ -418,7 +434,7 @@ def get_end_points_torch(self, points, clusts, clusts_seg, ppn_points): # For tracks, find the two poins farthest away from each other if clusts_seg[i] == TRACK_SHP: # Get the two most separated points in the cluster - idx = torch.argmax(local_cdist(points_c, points_c)) + idx = torch.argmax(cdist_fast(points_c, points_c)) idxs = sorted([int(idx//len(points_c)), int(idx%len(points_c))]) track_points = points_c[idxs] @@ -432,7 +448,7 @@ def get_end_points_torch(self, points, clusts, clusts_seg, ppn_points): # If needed, anchor the track endpoints to the track cluster if self.anchor_points: - dist_mat = local_cdist(track_points, points_c) + dist_mat = cdist_fast(track_points, points_c) track_points = points_c[torch.argmin(dist_mat, 1)] # Store @@ -459,7 +475,7 @@ def get_end_points_torch(self, points, clusts, clusts_seg, ppn_points): # If needed, anchor the shower start point to the shower cluster if self.anchor_points: - dists = local_cdist(start_point[None, :], points_c) + dists = cdist_fast(start_point[None, :], points_c) start_point = points_c[torch.argmin(dists)] # Store twice to preserve the feature vector length @@ -491,7 +507,7 @@ def get_end_points_numpy(self, points, clusts, clust_shapes, ppn_points): return self._get_end_points_numpy( points, clusts, clust_shapes, ppn_points, self.contained_first, self.anchor_points, - self.enhance_track_points, self.farthest_pair_algo) + self.enhance_track_points, self.approx_farthest_points) @staticmethod @nb.njit(cache=True, parallel=True, nogil=True) @@ -502,7 +518,7 @@ def _get_end_points_numpy(points: nb.float32[:,:], contained_first: nb.boolean, anchor_points: nb.boolean, enhance_track_points: nb.boolean, - farthest_pair_algo: str): + approx_farthest_pair: nb.boolean): # Loop over the relevant clusters end_points = np.empty((len(clusts), 6), dtype=points.dtype) for k in nb.prange(len(clusts)): @@ -515,7 +531,7 @@ def _get_end_points_numpy(points: nb.float32[:,:], if clust_shapes[k] == TRACK_SHP: # 
Get the two most separated points in the cluster idxs = np.sort(np.array( - nbl.farthest_pair(points_c, farthest_pair_algo)[:2])) + sm.distance.farthest_pair(points_c, approx_farthest_pair)[:2])) track_points = points_c[idxs] # If requested, enhance using the PPN predictions. Only consider @@ -529,7 +545,7 @@ def _get_end_points_numpy(points: nb.float32[:,:], # If needed, anchor the track endpoints to the track cluster if anchor_points: - dist_mat = nbl.cdist(track_points, points_c) + dist_mat = sm.distance.cdist(track_points, points_c) track_points = points_c[np.argmin(dist_mat, 1)] # Store @@ -539,12 +555,12 @@ def _get_end_points_numpy(points: nb.float32[:,:], else: # Only use positive voxels and give precedence to predictions # that are contained within the voxel making the prediction. - ppn_scores = nbl.softmax(ppn_points_c[:, PPN_RPOS_COLS], 1)[:, -1] + ppn_scores = sm.softmax(ppn_points_c[:, PPN_RPOS_COLS], 1)[:, -1] if contained_first: dists = np.abs(ppn_points_c[:, PPN_ROFF_COLS]) val_index = np.where( - (ppn_scores > 0.5) & nbl.all(dists < 1., 1))[0] + (ppn_scores > 0.5) & sm.all(dists < 1., 1))[0] if len(val_index): best_id = val_index[np.argmax(ppn_scores[val_index])] else: @@ -558,7 +574,7 @@ def _get_end_points_numpy(points: nb.float32[:,:], # If needed, anchor the shower start point to the shower cluster if anchor_points: - dists = nbl.cdist(start_point[None, :], points_c) + dists = sm.distance.cdist(start_point[None, :], points_c) start_point = points_c[np.argmin(dists)] # Store twice to preserve the feature vector length @@ -599,7 +615,7 @@ def check_track_orientation_ppn(start_point, end_point, ppn_candidates): # Compute the distance between the track end points and the PPN candidates end_points = np.vstack([start_point, end_point]) - dist_mat = nbl.cdist(end_points, ppn_points) + dist_mat = sm.distance.cdist(end_points, ppn_points) # If both track end points are closest to the same PPN point, the start # point must be closest to it if the score 
is high, farthest otherwise diff --git a/spine/utils/torch_local.py b/spine/utils/torch_local.py index e5b4504a9..4039e164e 100644 --- a/spine/utils/torch_local.py +++ b/spine/utils/torch_local.py @@ -3,7 +3,7 @@ import torch -def local_cdist(v1, v2): +def cdist_fast(v1, v2, metric='euclidean'): """Computes the pairwise distances between two `torch.Tensor` objects. This is necessary because the torch.cdist implementation is either @@ -17,6 +17,8 @@ def local_cdist(v1, v2): (N, D) tensor of coordinates v2 : torch.Tensor (M, D) tensor of coordinates + metric : str + Distance metric Returns ------- @@ -25,33 +27,9 @@ def local_cdist(v1, v2): """ v1_2 = v1.unsqueeze(1).expand(v1.size(0), v2.size(0), v1.size(1)) v2_2 = v2.unsqueeze(0).expand(v1.size(0), v2.size(0), v1.size(1)) - return torch.sqrt(torch.pow(v2_2 - v1_2, 2).sum(2)) - - -def unique_index(x, dim=None): - """Returns the list of unique indexes in the tensor and their first index. - - This is a temporary implementation until PyTorch adds support for the - `return_index` argument in their `torch.unique` function. 
- - Parameters - ---------- - x : torch.Tensor - (N) Tensor of values - - Returns - ------- - unique : torch.Tensor - (U) List of unique values in the input tensor - index : torch.Tensor - (U) List of the first index of each unique values - """ - unique, inverse = torch.unique( - x, sorted=True, return_inverse=True, dim=dim) - perm = torch.arange(inverse.size(0), dtype=inverse.dtype, - device=inverse.device) - inverse, perm = inverse.flip([0]), perm.flip([0]) - - index = inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm) - - return unique.long(), index + if metric == 'cityblock': + return torch.abs(v2_2 - v1_2).sum(2) + elif metric == 'euclidean': + return torch.sqrt(torch.pow(v2_2 - v1_2, 2).sum(2)) + elif metric == 'chebyshev': + return torch.abs(v2_2 - v1_2).amax(2) diff --git a/spine/utils/tracking.py b/spine/utils/tracking.py index dc77f8966..817bb9141 100644 --- a/spine/utils/tracking.py +++ b/spine/utils/tracking.py @@ -3,7 +3,7 @@ from scipy.interpolate import UnivariateSpline -from . 
import numba_local as nbl +import spine.math as sm def get_track_length(coordinates: nb.float32[:,:], @@ -42,7 +42,7 @@ def get_track_length(coordinates: nb.float32[:,:], """ if method == 'displacement': # Project points along the principal component, compute displacement - track_dir = nbl.principal_components(coordinates)[0] + track_dir = sm.decomposition.principal_components(coordinates)[0] pcoordinates = np.dot(coordinates, track_dir) return np.max(pcoordinates) - np.min(pcoordinates) @@ -111,8 +111,8 @@ def check_track_orientation(coordinates: nb.float32[:,:], # If requested, anchor the end points to the closest track points end_points = np.vstack((start_point, end_point)) if anchor_points: - dist_mat = nbl.cdist(end_points, coordinates) - end_ids = nbl.argmin(dist_mat, axis=1) + dist_mat = sm.distance.cdist(end_points, coordinates) + end_ids = sm.argmin(dist_mat, axis=1) end_points = coordinates[end_ids] # Compute the local dE/dx around each end, pick the end with the lowest @@ -322,12 +322,12 @@ def get_track_segments(coordinates: nb.float32[:,:], if point is not None: start_point = point if anchor_point: - start_id = np.argmin(nbl.cdist(np.atleast_2d(point), - coordinates)) + start_id = np.argmin(sm.distance.cdist( + np.atleast_2d(point), coordinates)) start_point = coordinates[start_id] else: # If not specified, pick a random end point of the track - start_id = nbl.farthest_pair(coordinates)[0] + start_id = sm.distance.farthest_pair(coordinates)[0] start_point = coordinates[start_id] # Step through the track iteratively @@ -340,8 +340,8 @@ def get_track_segments(coordinates: nb.float32[:,:], while len(left_index): # Compute distances from the segment start point to the all # the leftover points - dists = nbl.cdist(np.atleast_2d(seg_start), - coordinates[left_index]).flatten() + dists = sm.distance.cdist( + np.atleast_2d(seg_start), coordinates[left_index]).flatten() # Select the points that belong to this segment dist_mask = dists <= segment_length @@ 
-361,11 +361,11 @@ def get_track_segments(coordinates: nb.float32[:,:], # Estimate the direction of the segment seg_coords = coordinates[seg_index] - if method == 'step' \ - and len(seg_index) > min_count \ - and np.max(dists[pass_index]) > 0.: + if (method == 'step' + and len(seg_index) > min_count + and np.max(dists[pass_index]) > 0.): # Estimate direction w.r.t. the segment start point ('step') - direction = nbl.mean(seg_coords - seg_start, axis=0) + direction = sm.mean(seg_coords - seg_start, axis=0) elif len(fail_index): # Take direction as the vector joining the next closest point # ('step_next'). Also apply this method is the `min_count` @@ -414,7 +414,7 @@ def get_track_segments(coordinates: nb.float32[:,:], elif method == 'bin_pca': # Find the principal component of the whole track - track_dir = nbl.principal_components(coordinates)[0] + track_dir = sm.decomposition.principal_components(coordinates)[0] # If a track end point is provided, check which end the track end point # is on and flip the principal axis coordinates, if needed @@ -452,7 +452,7 @@ def get_track_segments(coordinates: nb.float32[:,:], # Compute principal component of the segment, use it as direction if len(seg) > min_count: - direction = nbl.principal_components(coordinates[seg])[0] + direction = sm.decomposition.principal_components(coordinates[seg])[0] if np.dot(direction, track_dir) < 0.: direction = -direction else: @@ -510,7 +510,7 @@ def get_track_spline(coordinates, segment_length, s=None): The estimate of the total length of the curve """ # Compute the principal component along which to segment the track - track_dir = nbl.principal_components(coordinates)[0] + track_dir = sm.decomposition.principal_components(coordinates)[0] pcoords = np.dot(coordinates, track_dir) perm = np.argsort(pcoords.squeeze()) u = pcoords[perm] diff --git a/spine/utils/vertex.py b/spine/utils/vertex.py index 2524513f5..079820631 100644 --- a/spine/utils/vertex.py +++ b/spine/utils/vertex.py @@ -1,7 +1,8 @@ 
import numpy as np import numba as nb -from . import numba_local as nbl +import spine.math as sm + from .globals import TRACK_SHP, INTER_COL, PRINT_COL, VTX_COLS @@ -171,7 +172,7 @@ def get_confluence_points(start_points: nb.float32[:,:], for j, (sj, ej) in enumerate(zip(start_points, end_points)): if j > i: pointsj = np.vstack((sj, ej)) - submat = nbl.cdist(pointsi, pointsj) + submat = sm.distance.cdist(pointsi, pointsj) mini, minj = np.argmin(submat)//2, np.argmin(submat)%2 dist_mat[i,j] = submat[mini, minj] end_mat[i,j], end_mat[j,i] = mini, minj @@ -187,7 +188,7 @@ def get_confluence_points(start_points: nb.float32[:,:], # Find cycles to build particle groups and confluence points (vertices) leftover = np.ones(n_part, dtype=np.bool_) - max_walks = nbl.amax(walk_mat, axis=1) + max_walks = sm.amax(walk_mat, axis=1) vertices = nb.typed.List.empty_list(np.empty(0, dtype=start_points.dtype)) while np.any(leftover): # Find the longest available cycle (must be at least 2 particles) @@ -203,7 +204,7 @@ def get_confluence_points(start_points: nb.float32[:,:], # Take the barycenter of the touching particle ends as the vertex if end_points is None: - vertices.append(nbl.mean(start_points[group], axis=0)) + vertices.append(sm.mean(start_points[group], axis=0)) else: vertex = np.zeros(3, dtype=start_points.dtype) for i, t in enumerate(group): diff --git a/spine/vis/__init__.py b/spine/vis/__init__.py index fc6bb7357..cc20c9b1d 100644 --- a/spine/vis/__init__.py +++ b/spine/vis/__init__.py @@ -1,13 +1,13 @@ """Module which centralizes all tools used to visualize data.""" -from .out import Drawer -from .geo import GeoDrawer -from .train import TrainDrawer -from .point import scatter_points -from .arrow import scatter_arrows -from .cluster import scatter_clusters -from .box import scatter_boxes -from .particle import scatter_particles -from .network import network_topology, network_schematic -from .evaluation import heatmap, annotate_heatmap -from .layout import layout3d, 
dual_figure3d +from .out import * +from .geo import * +from .train import * +from .point import * +from .arrow import * +from .cluster import * +from .box import * +from .particle import * +from .network import * +from .evaluation import * +from .layout import * diff --git a/spine/vis/arrow.py b/spine/vis/arrow.py index 52fe63d3b..db5afa959 100644 --- a/spine/vis/arrow.py +++ b/spine/vis/arrow.py @@ -7,6 +7,8 @@ from .point import scatter_points +__all__ = ['scatter_arrows'] + def scatter_arrows(points, directions, length=10.0, tip_ratio=0.25, color=None, hovertext=None, line=None, linewidth=5, name=None): diff --git a/spine/vis/box.py b/spine/vis/box.py index 7e8fdfed5..64bd96091 100644 --- a/spine/vis/box.py +++ b/spine/vis/box.py @@ -14,6 +14,8 @@ import numpy as np import plotly.graph_objs as go +__all__ = ['scatter_boxes'] + def box_trace(lower, upper, draw_faces=False, line=None, linewidth=None, color=None, cmin=None, cmax=None, colorscale=None, diff --git a/spine/vis/cluster.py b/spine/vis/cluster.py index 4a6b0f389..19c60f500 100644 --- a/spine/vis/cluster.py +++ b/spine/vis/cluster.py @@ -9,6 +9,8 @@ from .cone import cone_trace from .hull import hull_trace +__all__ = ['scatter_clusters'] + def scatter_clusters(points, clusts, color=None, hovertext=None, single_trace=False, name=None, mode='scatter', diff --git a/spine/vis/cone.py b/spine/vis/cone.py index 143f4242f..45c626f8d 100644 --- a/spine/vis/cone.py +++ b/spine/vis/cone.py @@ -3,7 +3,9 @@ import numpy as np from plotly import graph_objs as go -from spine.utils.numba_local import principal_components +from spine.math.decomposition import principal_components + +__all__ = ['cone_trace'] def cone_trace(points, fraction=0.5, num_samples=10, color=None, hovertext=None, diff --git a/spine/vis/ellipsoid.py b/spine/vis/ellipsoid.py index 424101e59..15eada35a 100644 --- a/spine/vis/ellipsoid.py +++ b/spine/vis/ellipsoid.py @@ -6,6 +6,8 @@ from scipy.special import gammaincinv # pylint: disable=E0611 import 
plotly.graph_objs as go +__all__ = ['ellipsoid_trace'] + def ellipsoid_trace(points=None, centroid=None, covmat=None, contour=0.5, num_samples=10, color=None, intensity=None, hovertext=None, diff --git a/spine/vis/evaluation.py b/spine/vis/evaluation.py index 67912b466..37ae8dd04 100644 --- a/spine/vis/evaluation.py +++ b/spine/vis/evaluation.py @@ -4,6 +4,8 @@ import matplotlib.pyplot as plt from matplotlib.ticker import Formatter +__all__ = ['heatmap', 'annotate_heatmap'] + class UncertaintyFormatter(Formatter): """Use a new-style format string (as used by `str.format`) to format the tick. diff --git a/spine/vis/geo.py b/spine/vis/geo.py index b1475ed67..eab826ff1 100644 --- a/spine/vis/geo.py +++ b/spine/vis/geo.py @@ -10,6 +10,8 @@ from .box import box_traces from .ellipsoid import ellipsoid_traces +__all__ = ['GeoDrawer'] + class GeoDrawer: """Handles drawing all things related to the detector geometry. diff --git a/spine/vis/hull.py b/spine/vis/hull.py index 20a124fb3..a7f4c93e0 100644 --- a/spine/vis/hull.py +++ b/spine/vis/hull.py @@ -3,6 +3,8 @@ import numpy as np import plotly.graph_objs as go +__all__ = ['hull_trace'] + def hull_trace(points, color=None, intensity=None, hovertext=None, showscale=False, alphahull=0, **kwargs): diff --git a/spine/vis/layout.py b/spine/vis/layout.py index f6069094c..b60e6f759 100644 --- a/spine/vis/layout.py +++ b/spine/vis/layout.py @@ -19,6 +19,8 @@ HIGH_CONTRAST_COLORS = np.concatenate( [colors.qualitative.Dark24, colors.qualitative.Light24]) +__all__ = ['layout3d', 'dual_figure3d'] + def layout3d(ranges=None, meta=None, detector=None, titles=None, detector_coords=False, backgroundcolor='white', diff --git a/spine/vis/network.py b/spine/vis/network.py index 6af159187..0590d0f20 100644 --- a/spine/vis/network.py +++ b/spine/vis/network.py @@ -2,12 +2,15 @@ import numpy as np +from spine.math.distance import closest_pair + from spine.utils.globals import COORD_COLS -from spine.utils.numba_local import closest_pair from 
.point import scatter_points from .cluster import scatter_clusters +__all__ = ['network_topology', 'network_schematic'] + def network_topology(points, clusts, edge_index, clust_labels=None, edge_labels=None, mode='scatter', color=None, diff --git a/spine/vis/out.py b/spine/vis/out.py index bdb276a7f..57e0c3e47 100644 --- a/spine/vis/out.py +++ b/spine/vis/out.py @@ -15,6 +15,8 @@ from .layout import ( layout3d, dual_figure3d, PLOTLY_COLORS_WGRAY, HIGH_CONTRAST_COLORS) +__all__ = ['Drawer'] + class Drawer: """Handles drawing the true/reconstructed output. diff --git a/spine/vis/particle.py b/spine/vis/particle.py index 91f79556f..741c7a690 100644 --- a/spine/vis/particle.py +++ b/spine/vis/particle.py @@ -8,6 +8,8 @@ from .point import scatter_points from .layout import HIGH_CONTRAST_COLORS +__all__ = ['scatter_particles'] + def scatter_particles(cluster_label, particles, part_col=PART_COL, markersize=1, **kwargs): diff --git a/spine/vis/point.py b/spine/vis/point.py index 9fc463d6b..5773fd292 100644 --- a/spine/vis/point.py +++ b/spine/vis/point.py @@ -5,6 +5,8 @@ from spine.utils.globals import COORD_COLS +__all__ = ['scatter_points'] + def scatter_points(points, color=None, markersize=2, linewidth=2, colorscale=None, cmin=None, cmax=None, opacity=None, diff --git a/spine/vis/train.py b/spine/vis/train.py index bc8d14c73..579114356 100644 --- a/spine/vis/train.py +++ b/spine/vis/train.py @@ -14,6 +14,8 @@ from .layout import PLOTLY_COLORS_TUPLE, color_rgba, apply_latex_style +__all__ = ['TrainDrawer'] + class TrainDrawer: """Class which centralizes function used to monitor a training process."""