@@ -619,22 +619,22 @@ function _pairwise!(r::AbstractMatrix, dist::Union{SqEuclidean,Euclidean},
619
619
R = inplace ? mul! (r, a' , b) : a' b
620
620
sa2 = sum (abs2, a, dims= 1 )
621
621
sb2 = sum (abs2, b, dims= 1 )
622
- threshT = convert ( eltype (r), dist . thresh )
623
- @inbounds if threshT <= 0
622
+ z² = zero ( real ( eltype (R)) )
623
+ @inbounds if dist . thresh <= 0
624
624
# If there's no chance of triggering the threshold, we can use @simd
625
625
for j = 1 : nb
626
626
sb = sb2[j]
627
627
@simd for i = 1 : na
628
- r[i, j] = eval_end (dist, (max (sa2[i] + sb - 2 real (R[i, j]), 0 )))
628
+ r[i, j] = eval_end (dist, (max (sa2[i] + sb - 2 real (R[i, j]), z² )))
629
629
end
630
630
end
631
631
else
632
632
for j = 1 : nb
633
633
sb = sb2[j]
634
634
for i = 1 : na
635
635
selfterms = sa2[i] + sb
636
- v = max (selfterms - 2 real (R[i, j]), 0 )
637
- if v < threshT * selfterms
636
+ v = max (selfterms - 2 real (R[i, j]), z² )
637
+ if v < dist . thresh * selfterms
638
638
# The distance is likely to be inaccurate, recalculate directly
639
639
# This reflects the following:
640
640
# while sqrt(x+ϵ) ≈ sqrt(x) + O(ϵ/sqrt(x)) when |x| >> ϵ,
@@ -658,22 +658,23 @@ function _pairwise!(r::AbstractMatrix, dist::Union{SqEuclidean,Euclidean}, a::Ab
658
658
# the following checks if a'*b can be stored in r directly, it fails for complex eltypes
659
659
R = inplace ? mul! (r, a' , a) : a' a
660
660
sa2 = sum (abs2, a, dims= 1 )
661
- threshT = convert (eltype (r), dist. thresh)
661
+ safe = dist. thresh <= 0
662
+ z² = zero (real (eltype (R)))
662
663
@inbounds for j = 1 : n
663
664
for i = 1 : (j - 1 )
664
665
r[i, j] = r[j, i]
665
666
end
666
- r[j, j] = 0
667
+ r[j, j] = zero ( eltype (r))
667
668
sa2j = sa2[j]
668
- if threshT <= 0
669
+ if safe
669
670
@simd for i = (j + 1 ): n
670
- r[i, j] = eval_end (dist, (max (sa2[i] + sa2j - 2 real (R[i, j]), 0 )))
671
+ r[i, j] = eval_end (dist, (max (sa2[i] + sa2j - 2 real (R[i, j]), z² )))
671
672
end
672
673
else
673
674
for i = (j + 1 ): n
674
675
selfterms = sa2[i] + sa2j
675
- v = max (selfterms - 2 real (R[i, j]), 0 )
676
- if v < threshT * selfterms
676
+ v = max (selfterms - 2 real (R[i, j]), z² )
677
+ if v < dist . thresh * selfterms
677
678
v = zero (v)
678
679
for k = 1 : m
679
680
v += abs2 (a[k, i] - a[k, j])
@@ -698,9 +699,10 @@ function _pairwise!(r::AbstractMatrix, dist::Union{WeightedSqEuclidean,WeightedE
698
699
# the following checks if a'*b can be stored in r directly, it fails for complex eltypes
699
700
inplace = promote_type (eltype (r), typeof (oneunit (eltype (a))' oneunit (eltype (b)))) === eltype (r)
700
701
R = inplace ? mul! (r, a' , w .* b) : a' * Diagonal (w)* b
702
+ z² = zero (real (eltype (R)))
701
703
for j = 1 : nb
702
704
@simd for i = 1 : na
703
- @inbounds r[i, j] = eval_end (dist, max (sa2[i] + sb2[j] - 2 real (R[i, j]), 0 ))
705
+ @inbounds r[i, j] = eval_end (dist, max (sa2[i] + sb2[j] - 2 real (R[i, j]), z² ))
704
706
end
705
707
end
706
708
r
@@ -715,14 +717,15 @@ function _pairwise!(r::AbstractMatrix, dist::Union{WeightedSqEuclidean,WeightedE
715
717
# the following checks if a'*b can be stored in r directly, it fails for complex eltypes
716
718
inplace = promote_type (eltype (r), typeof (oneunit (eltype (a))' oneunit (eltype (a)))) === eltype (r)
717
719
R = inplace ? mul! (r, a' , w .* a) : a' * Diagonal (w)* a
720
+ z² = zero (real (eltype (R)))
718
721
719
722
@inbounds for j = 1 : n
720
723
for i = 1 : (j - 1 )
721
724
r[i, j] = r[j, i]
722
725
end
723
- r[j, j] = 0
726
+ r[j, j] = zero ( eltype (r))
724
727
@simd for i = (j + 1 ): n
725
- r[i, j] = eval_end (dist, max (sa2[i] + sa2[j] - 2 real (R[i, j]), 0 ))
728
+ r[i, j] = eval_end (dist, max (sa2[i] + sa2[j] - 2 real (R[i, j]), z² ))
726
729
end
727
730
end
728
731
r
@@ -734,28 +737,30 @@ function _pairwise!(r::AbstractMatrix, ::CosineDist,
734
737
a:: AbstractMatrix , b:: AbstractMatrix )
735
738
require_one_based_indexing (r, a, b)
736
739
m, na, nb = get_pairwise_dims (r, a, b)
737
- mul! (r, a' , b)
740
+ inplace = promote_type (eltype (r), typeof (oneunit (eltype (a))' oneunit (eltype (b)))) === eltype (r)
741
+ R = inplace ? mul! (r, a' , b) : a' b
738
742
ra = norm_percol (a)
739
743
rb = norm_percol (b)
740
744
for j = 1 : nb
741
745
@simd for i = 1 : na
742
- @inbounds r[i, j] = max (1 - r [i, j] / (ra[i] * rb[j]), 0 )
746
+ @inbounds r[i, j] = max (1 - R [i, j] / (ra[i] * rb[j]), 0 )
743
747
end
744
748
end
745
749
r
746
750
end
747
751
function _pairwise! (r:: AbstractMatrix , :: CosineDist , a:: AbstractMatrix )
748
752
require_one_based_indexing (r, a)
749
753
m, n = get_pairwise_dims (r, a)
750
- mul! (r, a' , a)
754
+ inplace = promote_type (eltype (r), typeof (oneunit (eltype (a))' oneunit (eltype (a)))) === eltype (r)
755
+ R = inplace ? mul! (r, a' , a) : a' a
751
756
ra = norm_percol (a)
752
757
@inbounds for j = 1 : n
753
758
for i = 1 : (j - 1 )
754
759
r[i, j] = r[j, i]
755
760
end
756
- r[j, j] = 0
761
+ r[j, j] = zero ( eltype (r))
757
762
@simd for i = j + 1 : n
758
- r[i, j] = max (1 - r [i, j] / (ra[i] * ra[j]), 0 )
763
+ r[i, j] = max (1 - R [i, j] / (ra[i] * ra[j]), 0 )
759
764
end
760
765
end
761
766
r
0 commit comments