Skip to content

Commit b845b8c

Browse files
authored
Tanimoto sparse optimization (#489)
1 parent ecabd5f commit b845b8c

File tree

3 files changed

+260
-197
lines changed

3 files changed

+260
-197
lines changed

skfp/distances/tanimoto.py

Lines changed: 83 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import numba
21
import numpy as np
32
from scipy.sparse import csr_array
43
from sklearn.utils._param_validation import validate_params
@@ -30,10 +29,10 @@ def tanimoto_binary_similarity(
3029
3130
Parameters
3231
----------
33-
vec_a : {ndarray, sparse matrix}
32+
vec_a : ndarray or CSR sparse matrix
3433
First binary input array or sparse matrix.
3534
36-
vec_b : {ndarray, sparse matrix}
35+
vec_b : ndarray or CSR sparse matrix
3736
Second binary input array or sparse matrix.
3837
3938
Returns
@@ -80,7 +79,7 @@ def tanimoto_binary_similarity(
8079
intersection = len(vec_a_idxs & vec_b_idxs)
8180
union = len(vec_a_idxs | vec_b_idxs)
8281

83-
sim = intersection / union if union != 0 else 1.0
82+
sim = intersection / union if union != 0 else 1
8483
return float(sim)
8584

8685

@@ -112,10 +111,10 @@ def tanimoto_binary_distance(
112111
113112
Parameters
114113
----------
115-
vec_a : {ndarray, sparse matrix}
114+
vec_a : ndarray or CSR sparse matrix
116115
First binary input array or sparse matrix.
117116
118-
vec_b : {ndarray, sparse matrix}
117+
vec_b : ndarray or CSR sparse matrix
119118
Second binary input array or sparse matrix.
120119
121120
References
@@ -176,10 +175,10 @@ def tanimoto_count_similarity(
176175
177176
Parameters
178177
----------
179-
vec_a : {ndarray, sparse matrix}
178+
vec_a : ndarray or CSR sparse matrix
180179
First count input array or sparse matrix.
181180
182-
vec_b : {ndarray, sparse matrix}
181+
vec_b : ndarray or CSR sparse matrix
183182
Second count input array or sparse matrix.
184183
185184
Returns
@@ -229,7 +228,7 @@ def tanimoto_count_similarity(
229228
intersection = dot_ab
230229
union = dot_aa + dot_bb - dot_ab
231230

232-
sim = intersection / union if union >= 1e-8 else 1.0
231+
sim = intersection / union if union >= 1e-8 else 1
233232
return float(sim)
234233

235234

@@ -260,10 +259,10 @@ def tanimoto_count_distance(
260259
261260
Parameters
262261
----------
263-
vec_a : {ndarray, sparse matrix}
262+
vec_a : ndarray or CSR sparse matrix
264263
First count input array or sparse matrix.
265264
266-
vec_b : {ndarray, sparse matrix}
265+
vec_b : ndarray or CSR sparse matrix
267266
Second count input array or sparse matrix.
268267
269268
References
@@ -299,11 +298,14 @@ def tanimoto_count_distance(
299298

300299

301300
@validate_params(
302-
{"X": ["array-like"], "Y": ["array-like", None]},
301+
{
302+
"X": ["array-like", csr_array],
303+
"Y": ["array-like", csr_array, None],
304+
},
303305
prefer_skip_nested_validation=True,
304306
)
305307
def bulk_tanimoto_binary_similarity(
306-
X: np.ndarray, Y: np.ndarray | None = None
308+
X: np.ndarray | csr_array, Y: np.ndarray | csr_array | None = None
307309
) -> np.ndarray:
308310
r"""
309311
Bulk Tanimoto similarity for binary matrices.
@@ -317,12 +319,12 @@ def bulk_tanimoto_binary_similarity(
317319
318320
Parameters
319321
----------
320-
X : ndarray
321-
First binary input array, of shape :math:`m \times m`
322+
X : ndarray or CSR sparse array
323+
First binary input array or sparse matrix, of shape :math:`m \times m`.
322324
323-
Y : ndarray, default=None
324-
Second binary input array, of shape :math:`n \times n`. If not passed, similarities
325-
are computed between rows of X.
325+
Y : ndarray or CSR sparse array, default=None
326+
Second binary input array or sparse matrix, of shape :math:`n \times n`.
327+
If not passed, similarities are computed between rows of X.
326328
327329
Returns
328330
-------
@@ -345,39 +347,39 @@ def bulk_tanimoto_binary_similarity(
345347
array([[1. , 0.33333333],
346348
[0.5 , 0.5 ]])
347349
"""
350+
if not isinstance(X, csr_array):
351+
X = csr_array(X)
352+
348353
if Y is None:
349354
return _bulk_tanimoto_binary_similarity_single(X)
350355
else:
356+
if not isinstance(Y, csr_array):
357+
Y = csr_array(Y)
358+
351359
return _bulk_tanimoto_binary_similarity_two(X, Y)
352360

353361

354-
@numba.njit(parallel=True)
355-
def _bulk_tanimoto_binary_similarity_single(X: np.ndarray) -> np.ndarray:
356-
m = X.shape[0]
357-
sims = np.empty((m, m))
362+
def _bulk_tanimoto_binary_similarity_single(X: csr_array) -> np.ndarray:
363+
intersection = (X @ X.T).toarray()
364+
row_sums = np.asarray(X.sum(axis=1)).ravel()
365+
unions = np.add.outer(row_sums, row_sums) - intersection
358366

359-
for i in numba.prange(m):
360-
sims[i, i] = 1.0
361-
for j in numba.prange(i + 1, m):
362-
intersection = np.sum(np.logical_and(X[i], X[j]))
363-
union = np.sum(np.logical_or(X[i], X[j]))
364-
sim = intersection / union if union != 0 else 1.0
365-
sims[i, j] = sims[j, i] = sim
367+
sims = np.empty_like(intersection, dtype=float)
368+
np.divide(intersection, unions, out=sims, where=unions != 0)
366369

367370
return sims
368371

369372

370-
@numba.njit(parallel=True)
371-
def _bulk_tanimoto_binary_similarity_two(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
372-
m = X.shape[0]
373-
n = Y.shape[0]
374-
sims = np.empty((m, n))
373+
def _bulk_tanimoto_binary_similarity_two(X: csr_array, Y: csr_array) -> np.ndarray:
374+
intersection = (X @ Y.T).toarray()
375+
376+
row_sums_X = np.asarray(X.sum(axis=1)).ravel()
377+
row_sums_Y = np.asarray(Y.sum(axis=1)).ravel()
378+
379+
unions = np.add.outer(row_sums_X, row_sums_Y) - intersection
375380

376-
for i in numba.prange(m):
377-
for j in numba.prange(n):
378-
intersection = np.sum(np.logical_and(X[i], Y[j]))
379-
union = np.sum(np.logical_or(X[i], Y[j]))
380-
sims[i, j] = intersection / union if union != 0 else 1.0
381+
sims = np.empty_like(intersection, dtype=float)
382+
np.divide(intersection, unions, out=sims, where=unions != 0)
381383

382384
return sims
383385

@@ -390,7 +392,7 @@ def _bulk_tanimoto_binary_similarity_two(X: np.ndarray, Y: np.ndarray) -> np.nda
390392
prefer_skip_nested_validation=True,
391393
)
392394
def bulk_tanimoto_binary_distance(
393-
X: np.ndarray, Y: np.ndarray | None = None
395+
X: np.ndarray | csr_array, Y: np.ndarray | csr_array | None = None
394396
) -> np.ndarray:
395397
r"""
396398
Bulk Tanimoto distance for vectors of binary values.
@@ -404,12 +406,12 @@ def bulk_tanimoto_binary_distance(
404406
405407
Parameters
406408
----------
407-
X : ndarray
408-
First binary input array, of shape :math:`m \times m`
409+
X : ndarray or CSR sparse array
410+
First binary input array or sparse matrix, of shape :math:`m \times m`
409411
410-
Y : ndarray, default=None
411-
Second binary input array, of shape :math:`n \times n`. If not passed, distances
412-
are computed between rows of X.
412+
Y : ndarray or CSR sparse array, default=None
413+
Second binary input array or sparse matrix, of shape :math:`n \times n`.
414+
If not passed, distances are computed between rows of X.
413415
414416
Returns
415417
-------
@@ -442,11 +444,14 @@ def bulk_tanimoto_binary_distance(
442444

443445

444446
@validate_params(
445-
{"X": ["array-like"], "Y": ["array-like", None]},
447+
{
448+
"X": ["array-like", csr_array],
449+
"Y": ["array-like", csr_array, None],
450+
},
446451
prefer_skip_nested_validation=True,
447452
)
448453
def bulk_tanimoto_count_similarity(
449-
X: np.ndarray, Y: np.ndarray | None = None
454+
X: np.ndarray | csr_array, Y: np.ndarray | csr_array | None = None
450455
) -> np.ndarray:
451456
r"""
452457
Bulk Tanimoto similarity for count matrices.
@@ -460,12 +465,12 @@ def bulk_tanimoto_count_similarity(
460465
461466
Parameters
462467
----------
463-
X : ndarray
464-
First count input array, of shape :math:`m \times m`
468+
X : ndarray or CSR sparse array
469+
First binary input array or sparse matrix, of shape :math:`m \times m`
465470
466-
Y : ndarray, default=None
467-
Second count input array, of shape :math:`n \times n`. If not passed, similarities
468-
are computed between rows of X.
471+
Y : ndarray or CSR sparse array, default=None
472+
Second binary input array or sparse matrix, of shape :math:`n \times n`.
473+
If not passed, similarities are computed between rows of X.
469474
470475
Returns
471476
-------
@@ -489,59 +494,41 @@ def bulk_tanimoto_count_similarity(
489494
[0.5 , 0.5 ]])
490495
"""
491496
X = X.astype(float) # Numba does not allow integers
497+
if not isinstance(X, csr_array):
498+
X = csr_array(X)
492499

493500
if Y is None:
494501
return _bulk_tanimoto_count_similarity_single(X)
495502
else:
496-
Y = Y.astype(float)
497-
return _bulk_tanimoto_count_similarity_two(X, Y)
498-
499-
500-
@numba.njit(parallel=True)
501-
def _bulk_tanimoto_count_similarity_single(X: np.ndarray) -> np.ndarray:
502-
m = X.shape[0]
503-
sims = np.empty((m, m))
503+
Y = Y.astype(float) # Numba does not allow integers
504+
if not isinstance(Y, csr_array):
505+
Y = csr_array(Y)
504506

505-
for i in numba.prange(m):
506-
vec_a = X[i]
507-
sims[i, i] = 1.0
507+
return _bulk_tanimoto_count_similarity_two(X, Y)
508508

509-
for j in numba.prange(i + 1, m):
510-
vec_b = X[j]
511509

512-
dot_aa = np.dot(vec_a, vec_a)
513-
dot_bb = np.dot(vec_b, vec_b)
514-
dot_ab = np.dot(vec_a, vec_b)
510+
def _bulk_tanimoto_count_similarity_single(X: csr_array) -> np.ndarray:
511+
inter = (X @ X.T).toarray()
512+
row_norms = np.asarray(X.multiply(X).sum(axis=1)).ravel()
513+
unions = np.add.outer(row_norms, row_norms) - inter
515514

516-
intersection = dot_ab
517-
union = dot_aa + dot_bb - dot_ab
515+
sims = np.empty_like(inter, dtype=float)
516+
np.divide(inter, unions, out=sims, where=unions >= 1e-8)
518517

519-
sim = intersection / union if union >= 1e-8 else 1.0
520-
sims[i, j] = sims[j, i] = sim
518+
np.fill_diagonal(sims, 1)
521519

522520
return sims
523521

524522

525-
@numba.jit(parallel=True)
526-
def _bulk_tanimoto_count_similarity_two(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
527-
m = X.shape[0]
528-
n = Y.shape[0]
529-
sims = np.empty((m, n))
530-
531-
for i in numba.prange(m):
532-
vec_a = X[i]
533-
534-
for j in numba.prange(n):
535-
vec_b = Y[j]
536-
537-
dot_aa = np.dot(vec_a, vec_a)
538-
dot_bb = np.dot(vec_b, vec_b)
539-
dot_ab = np.dot(vec_a, vec_b)
523+
def _bulk_tanimoto_count_similarity_two(X: csr_array, Y: csr_array) -> np.ndarray:
524+
inter = (X @ Y.T).toarray()
525+
row_norms_X = np.asarray(X.multiply(X).sum(axis=1)).ravel()
526+
row_norms_Y = np.asarray(Y.multiply(Y).sum(axis=1)).ravel()
540527

541-
intersection = dot_ab
542-
union = dot_aa + dot_bb - dot_ab
528+
unions = np.add.outer(row_norms_X, row_norms_Y) - inter
543529

544-
sims[i, j] = intersection / union if union >= 1e-8 else 1.0
530+
sims = np.empty_like(inter, dtype=float)
531+
np.divide(inter, unions, out=sims, where=unions >= 1e-8)
545532

546533
return sims
547534

@@ -554,7 +541,7 @@ def _bulk_tanimoto_count_similarity_two(X: np.ndarray, Y: np.ndarray) -> np.ndar
554541
prefer_skip_nested_validation=True,
555542
)
556543
def bulk_tanimoto_count_distance(
557-
X: np.ndarray, Y: np.ndarray | None = None
544+
X: np.ndarray | csr_array, Y: np.ndarray | csr_array | None = None
558545
) -> np.ndarray:
559546
r"""
560547
Bulk Tanimoto distance for vectors of count values.
@@ -568,12 +555,12 @@ def bulk_tanimoto_count_distance(
568555
569556
Parameters
570557
----------
571-
X : ndarray
572-
First count input array, of shape :math:`m \times m`
558+
X : ndarray or CSR sparse array
559+
First binary input array or sparse matrix, of shape :math:`m \times m`
573560
574-
Y : ndarray, default=None
575-
Second count input array, of shape :math:`n \times n`. If not passed, distances
576-
are computed between rows of X.
561+
Y : ndarray or CSR sparse array, default=None
562+
Second binary input array or sparse matrix, of shape :math:`n \times n`.
563+
If not passed, distances are computed between rows of X.
577564
578565
Returns
579566
-------

0 commit comments

Comments
 (0)