Skip to content

Commit 834268b

Browse files
committed
Merge pull request #29 from KristofferC/pull-request/fc2cc4da
unify the result_type method
2 parents 3a2a801 + 8c9dc03 commit 834268b

File tree

3 files changed

+14
-10
lines changed

3 files changed

+14
-10
lines changed

src/generic.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ abstract Metric <: SemiMetric
2424

2525
# Generic functions
2626

27-
result_type(::PreMetric, T1::Type, T2::Type) = Float64
27+
result_type(::PreMetric, ::AbstractArray, ::AbstractArray) = Float64
2828

2929

3030
# Generic column-wise evaluation
@@ -62,19 +62,19 @@ end
6262

6363
function colwise(metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix)
6464
n = get_common_ncols(a, b)
65-
r = Array(result_type(metric, eltype(a), eltype(b)), n)
65+
r = Array(result_type(metric, a, b), n)
6666
colwise!(r, metric, a, b)
6767
end
6868

6969
function colwise(metric::PreMetric, a::AbstractVector, b::AbstractMatrix)
7070
n = size(b, 2)
71-
r = Array(result_type(metric, eltype(a), eltype(b)), n)
71+
r = Array(result_type(metric, a, b), n)
7272
colwise!(r, metric, a, b)
7373
end
7474

7575
function colwise(metric::PreMetric, a::AbstractMatrix, b::AbstractVector)
7676
n = size(a, 2)
77-
r = Array(result_type(metric, eltype(a), eltype(b)), n)
77+
r = Array(result_type(metric, a, b), n)
7878
colwise!(r, metric, a, b)
7979
end
8080

@@ -117,19 +117,19 @@ end
117117
function pairwise(metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix)
118118
m = size(a, 2)
119119
n = size(b, 2)
120-
r = Array(result_type(metric, eltype(a), eltype(b)), (m, n))
120+
r = Array(result_type(metric, a, b), (m, n))
121121
pairwise!(r, metric, a, b)
122122
end
123123

124124
function pairwise(metric::PreMetric, a::AbstractMatrix)
125125
n = size(a, 2)
126-
r = Array(result_type(metric, eltype(a), eltype(a)), (n, n))
126+
r = Array(result_type(metric, a, a), (n, n))
127127
pairwise!(r, metric, a)
128128
end
129129

130130
function pairwise(metric::SemiMetric, a::AbstractMatrix)
131131
n = size(a, 2)
132-
r = Array(result_type(metric, eltype(a), eltype(a)), (n, n))
132+
r = Array(result_type(metric, a, a), (n, n))
133133
pairwise!(r, metric, a)
134134
end
135135

src/mahalanobis.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ immutable SqMahalanobis{T} <: SemiMetric
88
qmat::Matrix{T}
99
end
1010

11-
result_type{T}(::Mahalanobis{T}, T1::Type, T2::Type) = T
12-
result_type{T}(::SqMahalanobis{T}, T1::Type, T2::Type) = T
11+
result_type{T}(::Mahalanobis{T}, ::AbstractArray, ::AbstractArray) = T
12+
result_type{T}(::SqMahalanobis{T}, ::AbstractArray, ::AbstractArray) = T
1313

1414
# SqMahalanobis
1515

src/metrics.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ cosine_dist(a::AbstractArray, b::AbstractArray) = evaluate(CosineDist(), a, b)
130130
_centralize(x::AbstractArray) = x .- mean(x)
131131
evaluate(::CorrDist, a::AbstractArray, b::AbstractArray) = cosine_dist(_centralize(a), _centralize(b))
132132
corr_dist(a::AbstractArray, b::AbstractArray) = evaluate(CorrDist(), a, b)
133+
result_type(::CorrDist, a::AbstractArray, b::AbstractArray) = result_type(CosineDist(), a, b)
133134

134135
# ChiSqDist
135136
@compat @inline eval_op(::ChiSqDist, ai, bi) = abs2(ai - bi) / (ai + bi)
@@ -166,9 +167,12 @@ end
166167
end
167168
return min_d, max_d
168169
end
170+
169171
eval_end(::SpanNormDist, s) = s[2] - s[1]
170172
spannorm_dist(a::AbstractArray, b::AbstractArray) = evaluate(SpanNormDist(), a, b)
171-
result_type(dist::SpanNormDist, T1::Type, T2::Type) = typeof(eval_op(dist, one(T1), one(T2)))
173+
function result_type{T1, T2}(dist::SpanNormDist, ::AbstractArray{T1}, ::AbstractArray{T2})
174+
typeof(eval_op(dist, one(T1), one(T2)))
175+
end
172176

173177

174178
###########################################################

0 commit comments

Comments
 (0)