Skip to content

Commit cae43c8

Browse files
committed
Accelerate ExpandNeighbours() with (R)cpp code
1 parent 3055fbc commit cae43c8

File tree

7 files changed

+164
-29
lines changed

7 files changed

+164
-29
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ importFrom(Matrix,t)
111111
importFrom(Rcpp,sourceCpp)
112112
importFrom(RhpcBLASctl,blas_get_num_procs)
113113
importFrom(RhpcBLASctl,blas_set_num_threads)
114+
importFrom(RhpcBLASctl,omp_get_max_threads)
114115
importFrom(RhpcBLASctl,omp_get_num_procs)
115116
importFrom(RhpcBLASctl,omp_set_num_threads)
116117
importFrom(Seurat,"DefaultAssay<-")

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# SeuratIntegrate (development version)
22

3+
* Speed up Dijkstra's algorithm-like used in `ExpandNeighbours` with a new c++
4+
implementation
5+
36
* The most suited matrix format is automatically chosen for corrected counts
47
output by integration methods (should be dense matrix most of the time)
58

R/RcppExports.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
22
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393
33

4+
dijkstra_cpp <- function(m, k) {
5+
.Call(`_SeuratIntegrate_dijkstra_cpp`, m, k)
6+
}
7+
48
n_zeros_dense_mat <- function(mat) {
59
.Call(`_SeuratIntegrate_n_zeros_dense_mat`, mat)
610
}

R/expand_knn_graph.R

Lines changed: 62 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@ NULL
2020
#' See \strong{Details} section for further explanations
2121
#' @param algo One of "dijkstra" or "diffusion". "diffusion" is suited for
2222
#' connectivity matrices only
23-
#' @param which.dijkstra one of "igraph", "fast" or "slow". "auto" (default)
24-
#' chooses for you. See \strong{Details} section
23+
#' @param which.dijkstra one of "cpp", "igraph", "fast" or "slow". "auto"
24+
#' (default) chooses for you. See \strong{Details} section
2525
#' @param dijkstra.ncores number of cores to use for Dijkstra's algorithm.
2626
#' Ignored when \code{which.dijkstra = "igraph"}
2727
#' @param dijkstra.tol number of sequential iterations with identical best
2828
#' neighbours found to consider that Dijkstra's algorithm should be stopped.
29-
#' Ignored when \code{which.dijkstra = "igraph"}
29+
#' Ignored when \code{which.dijkstra = "igraph" | "cpp"}
3030
#' @param diffusion.iter maximum number of iterations to reach \code{k.target}
3131
#' @param verbose whether to print progress messages
3232
#'
@@ -43,10 +43,29 @@ NULL
4343
#'
4444
#' One can choose to keep the graph as it is and consider it as a directed graph
4545
#' (\code{do.symmetrize = FALSE}).
46+
#'
4647
#' The alternative solution is to use all computed distances to extend the knn
4748
#' graph by making the matrix symmetric. Note that connectivity graphs are
4849
#' already symmetric, so the argument value should have no effect on the result.
4950
#'
51+
#' The choice of Dijkstra's algorithm implementation (\code{which.dijkstra}) is
52+
#' handled automatically by default. The igraph implementation is only preferred
53+
#' when the number of cells is low (1,000 at most). However, for a large
54+
#' \code{k.target} value (from approximately 1,000), igraph is faster on a single
55+
#' thread and thus can be imposed by setting \code{which.dijkstra = 'igraph'}.
56+
#'
57+
#' "fast" and "slow" are pure R implementations. "fast" is faster than "slow"
58+
#' but is not accurate on symmetric graph because it is suited for knn graphs
59+
#' with a constant k value. "slow" is slower but more accurate. No matter,
60+
#' \strong{"slow" and "fast" are both much slower than "cpp" and are deprecated}.
61+
#' They remain available in the even that some users have trouble running the
62+
#' "cpp" implementation.
63+
#'
64+
#' @note
65+
#' igraph implementation of Dijkstra's algorithm is single threaded only.
66+
#'
67+
#' "cpp" implementation is parallelized with OpenMP
68+
#'
5069
#' @importFrom SeuratObject DefaultAssay Cells as.Graph
5170
#'
5271
#' @export
@@ -56,7 +75,7 @@ setGeneric("ExpandNeighbours",
5675
graph.type = c("distances", "connectivities"),
5776
k.target = 90L, do.symmetrize = FALSE,
5877
algo = c("dijkstra", "diffusion"),
59-
which.dijkstra = c("auto", "igraph", "fast", "slow"),
78+
which.dijkstra = c("auto", "cpp", "igraph", "fast", "slow"),
6079
dijkstra.ncores = 1L, dijkstra.tol = 1L, diffusion.iter = 26L,
6180
assay = NULL, verbose = TRUE)
6281
standardGeneric("ExpandNeighbours"))
@@ -68,7 +87,7 @@ setMethod("ExpandNeighbours", "Seurat",
6887
graph.type = c("distances", "connectivities"),
6988
k.target = 90L, do.symmetrize = FALSE,
7089
algo = c("dijkstra", "diffusion"),
71-
which.dijkstra = c("auto", "igraph", "fast", "slow"),
90+
which.dijkstra = c("auto", "cpp", "igraph", "fast", "slow"),
7291
dijkstra.ncores = 1L, dijkstra.tol = 1L, diffusion.iter = 26L,
7392
assay = NULL, verbose = TRUE) {
7493
assay <- assay %||% DefaultAssay(object)
@@ -118,18 +137,19 @@ setMethod("ExpandNeighbours", "Seurat",
118137
setGeneric("expand_neighbours_dijkstra",
119138
function(object, graph.type = c("distances", "connectivities"),
120139
k.target = 90L, do.symmetrize = FALSE,
121-
which.dijkstra = c("auto", "igraph", "fast", "slow"),
140+
which.dijkstra = c("auto", "cpp", "igraph", "fast", "slow"),
122141
ncores = 1L, tol = 1L, verbose = TRUE)
123142
standardGeneric("expand_neighbours_dijkstra"))
124143

125144
#' @importFrom SeuratObject as.Neighbor
126-
#' @importFrom Matrix sparseMatrix drop0
145+
#' @importFrom Matrix sparseMatrix t drop0
146+
#' @importFrom RhpcBLASctl omp_get_max_threads omp_set_num_threads
127147
#' @keywords internal
128148
#' @noRd
129149
setMethod("expand_neighbours_dijkstra", "Matrix",
130150
function(object, graph.type = c("distances", "connectivities"),
131151
k.target = 90L, do.symmetrize = FALSE,
132-
which.dijkstra = c("auto", "igraph", "fast", "slow"),
152+
which.dijkstra = c("auto", "cpp", "igraph", "fast", "slow"),
133153
ncores = 1L, tol = 1L, verbose = TRUE) {
134154
n <- ncol(object)
135155
const.k <- is.kconstant(object)
@@ -142,14 +162,17 @@ setMethod("expand_neighbours_dijkstra", "Matrix",
142162
}
143163

144164
if (which.dijkstra == "auto") {
145-
if (n <= 1e4) {
146-
which.dijkstra <- "igraph"
147-
} else if (const.k) {
148-
which.dijkstra <- "fast"
165+
which.dijkstra <- if (n <= 1e4) {
166+
"igraph"
149167
} else {
150-
which.dijkstra <- "slow"
168+
"cpp"
151169
}
152170
}
171+
if (which.dijkstra == "cpp") {
172+
oomp <- omp_get_max_threads()
173+
omp_set_num_threads(ncores)
174+
on.exit(omp_set_num_threads(oomp))
175+
}
153176
object.symmetry <- const.k.symmetry <- NULL
154177
igraph.mode <- "directed"
155178
if (do.symmetrize && !isSymmetric(object)) {
@@ -189,7 +212,7 @@ setMethod("expand_neighbours_dijkstra", "Matrix",
189212
warning("When all cells have the same number k of nearest ",
190213
"neighbors, ", sQuote("fast"), " implementation of ",
191214
"Dijkstra's algorithm is recommended")
192-
} else {
215+
} else if (which.dijkstra == "fast"){
193216
object <- as.Neighbor(x = object)
194217
}
195218
} else if (which.dijkstra == "fast") {
@@ -204,6 +227,7 @@ setMethod("expand_neighbours_dijkstra", "Matrix",
204227
message(msg[verbose], appendLF = F)
205228
beginning <- Sys.time()
206229
res <- switch (which.dijkstra,
230+
cpp = dijkstra_cpp(m = t(drop0(object)), k = k.target),
207231
igraph = dijkstra.igraph(knnmat = object, k.target = k.target,
208232
mode = igraph.mode, weighted = T,
209233
diag = F),
@@ -225,17 +249,31 @@ setMethod("expand_neighbours_dijkstra", "Matrix",
225249
i <- rep(1:nrow(res$knn.idx), k.target)
226250
j <- as.vector(res$knn.idx)
227251
x <- as.vector(res$knn.dist)
228-
# correct igraph output when not enough neighbours
229-
infs <- which(is.infinite(x))
230-
if (length(infs) > 0) {
231-
x[is.infinite(x)] <- 0
232-
warning('Dijkstra (igraph) : could not find enough neighbours',
233-
' for ', length(infs), ' cell(s) ',
234-
paste0(infs, collapse = ', '),
235-
call. = F, immediate. = F)
252+
if (which.dijkstra == "igraph") {
253+
# correct igraph output when not enough neighbours
254+
infs <- which(is.infinite(x))
255+
if (length(infs) > 0) {
256+
less_neighbours <- sort(unique(i[infs]))
257+
i <- i[-infs]
258+
j <- j[-infs]
259+
x <- x[-infs]
260+
warning('Dijkstra (igraph): could not find enough neighbours',
261+
' for ', length(less_neighbours), ' cell(s) ',
262+
paste0(less_neighbours, collapse = ', '),
263+
call. = F, immediate. = F)
264+
}
236265
}
266+
237267
expanded.mat <- sparseMatrix(i = i, j = j, x = x, dims = rep(n, 2))
238-
expanded.mat <- drop0(expanded.mat)
268+
if (which.dijkstra == "cpp") {
269+
less_neighbours <- which(get.k(expanded.mat, 'all') < k.target)
270+
if (length(less_neighbours) > 0) {
271+
warning('Dijkstra (cpp): could not find enough neighbours',
272+
' for ', length(less_neighbours), ' cell(s) ',
273+
paste0(less_neighbours, collapse = ', '),
274+
call. = F, immediate. = F)
275+
}
276+
}
239277
if(do.symmetrize && which.dijkstra == "fast") {
240278
expanded.mat <- SymmetrizeKnn(expanded.mat, use.max = FALSE)
241279
}
@@ -249,7 +287,7 @@ setMethod("expand_neighbours_dijkstra", "Matrix",
249287
setMethod("expand_neighbours_dijkstra", "Neighbor",
250288
function(object, graph.type = c("distances", "connectivities"),
251289
k.target = 90L, do.symmetrize = FALSE,
252-
which.dijkstra = c("auto", "igraph", "fast", "slow"),
290+
which.dijkstra = c("auto", "cpp", "igraph", "fast", "slow"),
253291
ncores = 1L, tol = 1L, verbose = TRUE) {
254292

255293
return(expand_neighbours_dijkstra(as.Graph(object),

man/ExpandNeighbours.Rd

Lines changed: 24 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/RcppExports.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,18 @@ Rcpp::Rostream<true>& Rcpp::Rcout = Rcpp::Rcpp_cout_get();
1111
Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
1212
#endif
1313

14+
// dijkstra_cpp
15+
List dijkstra_cpp(const arma::sp_mat& m, const int k);
16+
RcppExport SEXP _SeuratIntegrate_dijkstra_cpp(SEXP mSEXP, SEXP kSEXP) {
17+
BEGIN_RCPP
18+
Rcpp::RObject rcpp_result_gen;
19+
Rcpp::RNGScope rcpp_rngScope_gen;
20+
Rcpp::traits::input_parameter< const arma::sp_mat& >::type m(mSEXP);
21+
Rcpp::traits::input_parameter< const int >::type k(kSEXP);
22+
rcpp_result_gen = Rcpp::wrap(dijkstra_cpp(m, k));
23+
return rcpp_result_gen;
24+
END_RCPP
25+
}
1426
// n_zeros_dense_mat
1527
int64_t n_zeros_dense_mat(const arma::mat& mat);
1628
RcppExport SEXP _SeuratIntegrate_n_zeros_dense_mat(SEXP matSEXP) {
@@ -35,6 +47,7 @@ END_RCPP
3547
}
3648

3749
static const R_CallMethodDef CallEntries[] = {
50+
{"_SeuratIntegrate_dijkstra_cpp", (DL_FUNC) &_SeuratIntegrate_dijkstra_cpp, 2},
3851
{"_SeuratIntegrate_n_zeros_dense_mat", (DL_FUNC) &_SeuratIntegrate_n_zeros_dense_mat, 1},
3952
{"_SeuratIntegrate_n_zeros_sparse_mat", (DL_FUNC) &_SeuratIntegrate_n_zeros_sparse_mat, 1},
4053
{NULL, NULL, 0}

src/dijkstra.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#include <RcppArmadillo.h>
2+
#include <queue>
3+
#include <unordered_set>
4+
5+
using namespace Rcpp ;
6+
7+
typedef std::pair<double, int> e;
8+
9+
// [[Rcpp::plugins(openmp)]]
10+
11+
// [[Rcpp::export]]
12+
List dijkstra_cpp(const arma::sp_mat& m, const int k) {
13+
14+
const int n_rows = m.n_rows;
15+
arma::uvec seq_idx = arma::regspace<arma::uvec>(0, 1, n_rows - 1);
16+
arma::umat knn_idx = arma::repmat(seq_idx, 1, k);
17+
arma::dmat knn_dist(n_rows, k, arma::fill::zeros);
18+
19+
#ifdef _OPENMP
20+
#pragma omp parallel num_threads(omp_get_max_threads())
21+
#pragma omp for
22+
#endif
23+
for (int root = 0; root < n_rows; root++) {
24+
arma::sp_mat::const_iterator it;
25+
std::priority_queue<e, std::vector<e>, std::greater<e> > pq;
26+
std::unordered_set<int> reached;
27+
reached.insert(root);
28+
arma::sp_mat col(m.col(root));
29+
for (it = col.begin(); it != col.end(); ++it){
30+
pq.push(std::make_pair((*it), it.row()));
31+
}
32+
33+
int i = 1;
34+
while(i < k) {
35+
e current = pq.top();
36+
int source = current.second;
37+
pq.pop();
38+
auto found = reached.find(source);
39+
if (found == reached.end()) {
40+
knn_idx(root, i) = source;
41+
knn_dist(root, i) = current.first;
42+
i++;
43+
reached.insert(source);
44+
45+
col = m.col(source);
46+
for (it = col.begin(); it != col.end(); ++it){
47+
pq.push(std::make_pair((*it) + current.first, it.row()));
48+
}
49+
} else if (pq.size() == 0) {
50+
break;
51+
}
52+
}
53+
}
54+
55+
return(List::create(Named("knn.idx") = knn_idx + 1,
56+
Named("knn.dist") = knn_dist));
57+
}

0 commit comments

Comments
 (0)