Skip to content

Commit 293457e

Browse files
johnnychen94KristofferC
authored andcommitted
rework result_type (#140)
* rework result_type The essence of `result_type` is based on the `eltype(a)`, and `result_type(::PreMetric, ::AbstractArray, ::AbstractArray)` only serves as an convenient method. There's no functionality changes in this commit, only to reorganize the codes to ease future development. * update docstring * delete usage example
1 parent c21aab0 commit 293457e

File tree

4 files changed

+18
-12
lines changed

4 files changed

+18
-12
lines changed

src/generic.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,15 @@ abstract type Metric <: SemiMetric end
2424

2525
# Generic functions
2626

27-
result_type(::PreMetric, ::AbstractArray, ::AbstractArray) = Float64
27+
"""
28+
result_type(dist::PreMetric, Ta::Type, Tb::Type) -> T
29+
result_type(dist::PreMetric, a::AbstractArray, b::AbstractArray) -> T
30+
31+
Infer the result type of metric `dist` with input type `Ta` and `Tb`, or input
32+
data `a` and `b`.
33+
"""
34+
result_type(::PreMetric, ::Type, ::Type) = Float64 # fallback
35+
result_type(dist::PreMetric, a::AbstractArray, b::AbstractArray) = result_type(dist, eltype(a), eltype(b))
2836

2937

3038
# Generic column-wise evaluation

src/mahalanobis.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ struct SqMahalanobis{T} <: SemiMetric
88
qmat::Matrix{T}
99
end
1010

11-
result_type(::Mahalanobis{T}, ::AbstractArray, ::AbstractArray) where {T} = T
12-
result_type(::SqMahalanobis{T}, ::AbstractArray, ::AbstractArray) where {T} = T
11+
result_type(::Mahalanobis{T}, ::Type, ::Type) where {T} = T
12+
result_type(::SqMahalanobis{T}, ::Type, ::Type) where {T} = T
1313

1414
# SqMahalanobis
1515

src/metrics.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -252,9 +252,8 @@ end
252252
end
253253
return eval_end(d, s)
254254
end
255-
result_type(dist::UnionMetrics, a::AbstractArray, b::AbstractArray) =
256-
typeof(evaluate(dist, oneunit(eltype(a)), oneunit(eltype(b))))
257-
255+
result_type(dist::UnionMetrics, Ta::Type, Tb::Type) =
256+
typeof(evaluate(dist, oneunit(Ta), oneunit(Tb)))
258257
eval_start(d::UnionMetrics, a::AbstractArray, b::AbstractArray) =
259258
zero(result_type(d, a, b))
260259
eval_end(d::UnionMetrics, s) = s
@@ -352,7 +351,7 @@ evaluate(::CorrDist, a::AbstractArray, b::AbstractArray) = cosine_dist(_centrali
352351
# Ambiguity resolution
353352
evaluate(::CorrDist, a::Array, b::Array) = cosine_dist(_centralize(a), _centralize(b))
354353
corr_dist(a::AbstractArray, b::AbstractArray) = evaluate(CorrDist(), a, b)
355-
result_type(::CorrDist, a::AbstractArray, b::AbstractArray) = result_type(CosineDist(), a, b)
354+
result_type(::CorrDist, Ta::Type, Tb::Type) = result_type(CosineDist(), Ta, Tb)
356355

357356
# ChiSqDist
358357
@inline eval_op(::ChiSqDist, ai, bi) = (d = abs2(ai - bi) / (ai + bi); ifelse(ai != bi, d, zero(d)))
@@ -452,9 +451,8 @@ end
452451

453452
eval_end(::SpanNormDist, s) = s[2] - s[1]
454453
spannorm_dist(a::AbstractArray, b::AbstractArray) = evaluate(SpanNormDist(), a, b)
455-
result_type(dist::SpanNormDist, a::AbstractArray, b::AbstractArray) =
456-
typeof(eval_op(dist, oneunit(eltype(a)), oneunit(eltype(b))))
457-
454+
result_type(dist::SpanNormDist, Ta::Type, Tb::Type) =
455+
typeof(eval_op(dist, oneunit(Ta), oneunit(Tb)))
458456

459457
# Jaccard
460458

src/wmetrics.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ Base.eltype(x::UnionWeightedMetrics) = eltype(x.weights)
4242
function evaluate(dist::UnionWeightedMetrics, a::Number, b::Number)
4343
eval_end(dist, eval_op(dist, a, b, oneunit(eltype(dist))))
4444
end
45-
result_type(dist::UnionWeightedMetrics, a::AbstractArray, b::AbstractArray) =
46-
typeof(evaluate(dist, oneunit(eltype(a)), oneunit(eltype(b))))
45+
result_type(dist::UnionWeightedMetrics, Ta::Type, Tb::Type) =
46+
typeof(evaluate(dist, oneunit(Ta), oneunit(Tb)))
4747

4848
@inline function eval_start(d::UnionWeightedMetrics, a::AbstractArray, b::AbstractArray)
4949
zero(result_type(d, a, b))

0 commit comments

Comments
 (0)