1- import numba
21import numpy as np
32from scipy .sparse import csr_array
43from sklearn .utils ._param_validation import validate_params
@@ -30,10 +29,10 @@ def tanimoto_binary_similarity(
3029
3130 Parameters
3231 ----------
33- vec_a : { ndarray, sparse matrix}
32+ vec_a : ndarray or CSR sparse matrix
3433 First binary input array or sparse matrix.
3534
36- vec_b : { ndarray, sparse matrix}
35+ vec_b : ndarray or CSR sparse matrix
3736 Second binary input array or sparse matrix.
3837
3938 Returns
@@ -80,7 +79,7 @@ def tanimoto_binary_similarity(
8079 intersection = len (vec_a_idxs & vec_b_idxs )
8180 union = len (vec_a_idxs | vec_b_idxs )
8281
83- sim = intersection / union if union != 0 else 1.0
82+ sim = intersection / union if union != 0 else 1
8483 return float (sim )
8584
8685
@@ -112,10 +111,10 @@ def tanimoto_binary_distance(
112111
113112 Parameters
114113 ----------
115- vec_a : { ndarray, sparse matrix}
114+ vec_a : ndarray or CSR sparse matrix
116115 First binary input array or sparse matrix.
117116
118- vec_b : { ndarray, sparse matrix}
117+ vec_b : ndarray or CSR sparse matrix
119118 Second binary input array or sparse matrix.
120119
121120 References
@@ -176,10 +175,10 @@ def tanimoto_count_similarity(
176175
177176 Parameters
178177 ----------
179- vec_a : { ndarray, sparse matrix}
178+ vec_a : ndarray or CSR sparse matrix
180179 First count input array or sparse matrix.
181180
182- vec_b : { ndarray, sparse matrix}
181+ vec_b : ndarray or CSR sparse matrix
183182 Second count input array or sparse matrix.
184183
185184 Returns
@@ -229,7 +228,7 @@ def tanimoto_count_similarity(
229228 intersection = dot_ab
230229 union = dot_aa + dot_bb - dot_ab
231230
232- sim = intersection / union if union >= 1e-8 else 1.0
231+ sim = intersection / union if union >= 1e-8 else 1
233232 return float (sim )
234233
235234
@@ -260,10 +259,10 @@ def tanimoto_count_distance(
260259
261260 Parameters
262261 ----------
263- vec_a : { ndarray, sparse matrix}
262+ vec_a : ndarray or CSR sparse matrix
264263 First count input array or sparse matrix.
265264
266- vec_b : { ndarray, sparse matrix}
265+ vec_b : ndarray or CSR sparse matrix
267266 Second count input array or sparse matrix.
268267
269268 References
@@ -299,11 +298,14 @@ def tanimoto_count_distance(
299298
300299
301300@validate_params (
302- {"X" : ["array-like" ], "Y" : ["array-like" , None ]},
301+ {
302+ "X" : ["array-like" , csr_array ],
303+ "Y" : ["array-like" , csr_array , None ],
304+ },
303305 prefer_skip_nested_validation = True ,
304306)
305307def bulk_tanimoto_binary_similarity (
306- X : np .ndarray , Y : np .ndarray | None = None
308+ X : np .ndarray | csr_array , Y : np .ndarray | csr_array | None = None
307309) -> np .ndarray :
308310 r"""
309311 Bulk Tanimoto similarity for binary matrices.
@@ -317,12 +319,12 @@ def bulk_tanimoto_binary_similarity(
317319
318320 Parameters
319321 ----------
320- X : ndarray
321- First binary input array, of shape :math:`m \times m`
322+ X : ndarray or CSR sparse array
323+ First binary input array or sparse matrix , of shape :math:`m \times m`.
322324
323- Y : ndarray, default=None
324- Second binary input array, of shape :math:`n \times n`. If not passed, similarities
325- are computed between rows of X.
325+ Y : ndarray or CSR sparse array , default=None
326+ Second binary input array or sparse matrix , of shape :math:`n \times n`.
327+ If not passed, similarities are computed between rows of X.
326328
327329 Returns
328330 -------
@@ -345,39 +347,39 @@ def bulk_tanimoto_binary_similarity(
345347 array([[1. , 0.33333333],
346348 [0.5 , 0.5 ]])
347349 """
350+ if not isinstance (X , csr_array ):
351+ X = csr_array (X )
352+
348353 if Y is None :
349354 return _bulk_tanimoto_binary_similarity_single (X )
350355 else :
356+ if not isinstance (Y , csr_array ):
357+ Y = csr_array (Y )
358+
351359 return _bulk_tanimoto_binary_similarity_two (X , Y )
352360
353361
354- @ numba . njit ( parallel = True )
355- def _bulk_tanimoto_binary_similarity_single ( X : np . ndarray ) -> np . ndarray :
356- m = X . shape [ 0 ]
357- sims = np .empty (( m , m ))
def _bulk_tanimoto_binary_similarity_single(X: csr_array) -> np.ndarray:
    """
    Pairwise binary Tanimoto similarity between all rows of ``X``.

    Parameters
    ----------
    X : CSR sparse array
        Binary input array with one vector per row. Any nonzero entry is
        treated as a set bit.

    Returns
    -------
    sims : ndarray of shape (n_rows, n_rows)
        ``sims[i, j]`` is ``|a & b| / |a | b|`` for rows ``i`` and ``j``.
        Pairs whose union is empty (both rows all-zero) get similarity 1.0.
    """
    # Binarize and widen the dtype: a boolean sparse product saturates at
    # True, and small integer dtypes can overflow, instead of counting the
    # number of shared bits.
    X = X.astype(bool).astype(np.int64)

    intersection = (X @ X.T).toarray()
    row_sums = np.asarray(X.sum(axis=1)).ravel()
    # |a | b| = |a| + |b| - |a & b|
    unions = np.add.outer(row_sums, row_sums) - intersection

    # Pre-fill with 1.0: np.divide leaves masked-out entries of `out`
    # untouched, so an uninitialized buffer would leak garbage for pairs
    # with an empty union.
    sims = np.ones(intersection.shape, dtype=float)
    np.divide(intersection, unions, out=sims, where=unions != 0)

    return sims
368371
369372
370- @numba .njit (parallel = True )
371- def _bulk_tanimoto_binary_similarity_two (X : np .ndarray , Y : np .ndarray ) -> np .ndarray :
372- m = X .shape [0 ]
373- n = Y .shape [0 ]
374- sims = np .empty ((m , n ))
def _bulk_tanimoto_binary_similarity_two(X: csr_array, Y: csr_array) -> np.ndarray:
    """
    Pairwise binary Tanimoto similarity between rows of ``X`` and rows of ``Y``.

    Parameters
    ----------
    X : CSR sparse array
        First binary input array with one vector per row.

    Y : CSR sparse array
        Second binary input array with one vector per row. Must have the
        same number of columns as ``X``.

    Returns
    -------
    sims : ndarray of shape (n_rows_X, n_rows_Y)
        ``sims[i, j]`` is ``|a & b| / |a | b|`` for row ``i`` of ``X`` and
        row ``j`` of ``Y``. Pairs whose union is empty (both rows all-zero)
        get similarity 1.0.
    """
    # Binarize and widen the dtype: a boolean sparse product saturates at
    # True, and small integer dtypes can overflow, instead of counting the
    # number of shared bits.
    X = X.astype(bool).astype(np.int64)
    Y = Y.astype(bool).astype(np.int64)

    intersection = (X @ Y.T).toarray()

    row_sums_X = np.asarray(X.sum(axis=1)).ravel()
    row_sums_Y = np.asarray(Y.sum(axis=1)).ravel()

    # |a | b| = |a| + |b| - |a & b|
    unions = np.add.outer(row_sums_X, row_sums_Y) - intersection

    # Pre-fill with 1.0: np.divide leaves masked-out entries of `out`
    # untouched, so an uninitialized buffer would leak garbage for pairs
    # with an empty union.
    sims = np.ones(intersection.shape, dtype=float)
    np.divide(intersection, unions, out=sims, where=unions != 0)

    return sims
383385
@@ -390,7 +392,7 @@ def _bulk_tanimoto_binary_similarity_two(X: np.ndarray, Y: np.ndarray) -> np.nda
390392 prefer_skip_nested_validation = True ,
391393)
392394def bulk_tanimoto_binary_distance (
393- X : np .ndarray , Y : np .ndarray | None = None
395+ X : np .ndarray | csr_array , Y : np .ndarray | csr_array | None = None
394396) -> np .ndarray :
395397 r"""
396398 Bulk Tanimoto distance for vectors of binary values.
@@ -404,12 +406,12 @@ def bulk_tanimoto_binary_distance(
404406
405407 Parameters
406408 ----------
407- X : ndarray
408- First binary input array, of shape :math:`m \times m`
409+ X : ndarray or CSR sparse array
410+ First binary input array or sparse matrix , of shape :math:`m \times m`
409411
410- Y : ndarray, default=None
411- Second binary input array, of shape :math:`n \times n`. If not passed, distances
412- are computed between rows of X.
412+ Y : ndarray or CSR sparse array , default=None
413+ Second binary input array or sparse matrix , of shape :math:`n \times n`.
414+ If not passed, distances are computed between rows of X.
413415
414416 Returns
415417 -------
@@ -442,11 +444,14 @@ def bulk_tanimoto_binary_distance(
442444
443445
444446@validate_params (
445- {"X" : ["array-like" ], "Y" : ["array-like" , None ]},
447+ {
448+ "X" : ["array-like" , csr_array ],
449+ "Y" : ["array-like" , csr_array , None ],
450+ },
446451 prefer_skip_nested_validation = True ,
447452)
448453def bulk_tanimoto_count_similarity (
449- X : np .ndarray , Y : np .ndarray | None = None
454+ X : np .ndarray | csr_array , Y : np .ndarray | csr_array | None = None
450455) -> np .ndarray :
451456 r"""
452457 Bulk Tanimoto similarity for count matrices.
@@ -460,12 +465,12 @@ def bulk_tanimoto_count_similarity(
460465
461466 Parameters
462467 ----------
463- X : ndarray
464- First count input array, of shape :math:`m \times m`
468+ X : ndarray or CSR sparse array
469+ First count input array or sparse matrix , of shape :math:`m \times m`
465470
466- Y : ndarray, default=None
467- Second count input array, of shape :math:`n \times n`. If not passed, similarities
468- are computed between rows of X.
471+ Y : ndarray or CSR sparse array , default=None
472+ Second count input array or sparse matrix , of shape :math:`n \times n`.
473+ If not passed, similarities are computed between rows of X.
469474
470475 Returns
471476 -------
@@ -489,59 +494,41 @@ def bulk_tanimoto_count_similarity(
489494 [0.5 , 0.5 ]])
490495 """
491496 X = X .astype (float ) # Numba does not allow integers
497+ if not isinstance (X , csr_array ):
498+ X = csr_array (X )
492499
493500 if Y is None :
494501 return _bulk_tanimoto_count_similarity_single (X )
495502 else :
496- Y = Y .astype (float )
497- return _bulk_tanimoto_count_similarity_two (X , Y )
498-
499-
500- @numba .njit (parallel = True )
501- def _bulk_tanimoto_count_similarity_single (X : np .ndarray ) -> np .ndarray :
502- m = X .shape [0 ]
503- sims = np .empty ((m , m ))
503+ Y = Y .astype (float ) # Numba does not allow integers
504+ if not isinstance (Y , csr_array ):
505+ Y = csr_array (Y )
504506
505- for i in numba .prange (m ):
506- vec_a = X [i ]
507- sims [i , i ] = 1.0
507+ return _bulk_tanimoto_count_similarity_two (X , Y )
508508
509- for j in numba .prange (i + 1 , m ):
510- vec_b = X [j ]
511509
512- dot_aa = np .dot (vec_a , vec_a )
513- dot_bb = np .dot (vec_b , vec_b )
514- dot_ab = np .dot (vec_a , vec_b )
def _bulk_tanimoto_count_similarity_single(X: csr_array) -> np.ndarray:
    """
    Pairwise count Tanimoto similarity between all rows of ``X``.

    Parameters
    ----------
    X : CSR sparse array
        Count input array with one vector per row.

    Returns
    -------
    sims : ndarray of shape (n_rows, n_rows)
        ``sims[i, j]`` is ``<a, b> / (<a, a> + <b, b> - <a, b>)`` for rows
        ``i`` and ``j``. Pairs whose denominator is below ``1e-8`` (both rows
        all-zero) get similarity 1.0, and the diagonal is always 1.0.
    """
    # Work in floating point so dot products of integer counts cannot
    # overflow a narrow integer dtype.
    X = X.astype(float, copy=False)

    inter = (X @ X.T).toarray()
    row_norms = np.asarray(X.multiply(X).sum(axis=1)).ravel()
    unions = np.add.outer(row_norms, row_norms) - inter

    # Pre-fill with 1.0: np.divide leaves masked-out entries of `out`
    # untouched, so an uninitialized buffer would leak garbage for pairs
    # of all-zero rows.
    sims = np.ones(inter.shape, dtype=float)
    np.divide(inter, unions, out=sims, where=unions >= 1e-8)

    # Self-similarity is exactly 1 regardless of rounding.
    np.fill_diagonal(sims, 1)

    return sims
523521
524522
525- @numba .jit (parallel = True )
526- def _bulk_tanimoto_count_similarity_two (X : np .ndarray , Y : np .ndarray ) -> np .ndarray :
527- m = X .shape [0 ]
528- n = Y .shape [0 ]
529- sims = np .empty ((m , n ))
530-
531- for i in numba .prange (m ):
532- vec_a = X [i ]
533-
534- for j in numba .prange (n ):
535- vec_b = Y [j ]
536-
537- dot_aa = np .dot (vec_a , vec_a )
538- dot_bb = np .dot (vec_b , vec_b )
539- dot_ab = np .dot (vec_a , vec_b )
def _bulk_tanimoto_count_similarity_two(X: csr_array, Y: csr_array) -> np.ndarray:
    """
    Pairwise count Tanimoto similarity between rows of ``X`` and rows of ``Y``.

    Parameters
    ----------
    X : CSR sparse array
        First count input array with one vector per row.

    Y : CSR sparse array
        Second count input array with one vector per row. Must have the
        same number of columns as ``X``.

    Returns
    -------
    sims : ndarray of shape (n_rows_X, n_rows_Y)
        ``sims[i, j]`` is ``<a, b> / (<a, a> + <b, b> - <a, b>)`` for row
        ``i`` of ``X`` and row ``j`` of ``Y``. Pairs whose denominator is
        below ``1e-8`` (both rows all-zero) get similarity 1.0.
    """
    # Work in floating point so dot products of integer counts cannot
    # overflow a narrow integer dtype.
    X = X.astype(float, copy=False)
    Y = Y.astype(float, copy=False)

    inter = (X @ Y.T).toarray()
    row_norms_X = np.asarray(X.multiply(X).sum(axis=1)).ravel()
    row_norms_Y = np.asarray(Y.multiply(Y).sum(axis=1)).ravel()

    unions = np.add.outer(row_norms_X, row_norms_Y) - inter

    # Pre-fill with 1.0: np.divide leaves masked-out entries of `out`
    # untouched, so an uninitialized buffer would leak garbage for pairs
    # of all-zero rows.
    sims = np.ones(inter.shape, dtype=float)
    np.divide(inter, unions, out=sims, where=unions >= 1e-8)

    return sims
547534
@@ -554,7 +541,7 @@ def _bulk_tanimoto_count_similarity_two(X: np.ndarray, Y: np.ndarray) -> np.ndar
554541 prefer_skip_nested_validation = True ,
555542)
556543def bulk_tanimoto_count_distance (
557- X : np .ndarray , Y : np .ndarray | None = None
544+ X : np .ndarray | csr_array , Y : np .ndarray | csr_array | None = None
558545) -> np .ndarray :
559546 r"""
560547 Bulk Tanimoto distance for vectors of count values.
@@ -568,12 +555,12 @@ def bulk_tanimoto_count_distance(
568555
569556 Parameters
570557 ----------
571- X : ndarray
572- First count input array, of shape :math:`m \times m`
558+ X : ndarray or CSR sparse array
559+ First count input array or sparse matrix , of shape :math:`m \times m`
573560
574- Y : ndarray, default=None
575- Second count input array, of shape :math:`n \times n`. If not passed, distances
576- are computed between rows of X.
561+ Y : ndarray or CSR sparse array , default=None
562+ Second count input array or sparse matrix , of shape :math:`n \times n`.
563+ If not passed, distances are computed between rows of X.
577564
578565 Returns
579566 -------
0 commit comments