Skip to content

Commit c21aab0

Browse files
devmotionararslan
authored andcommitted
Loosen type annotations and fix #130 (#134)
1 parent ba3ac88 commit c21aab0

File tree

4 files changed

+76
-31
lines changed

4 files changed

+76
-31
lines changed

src/common.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,10 @@ function sumsq_percol(a::AbstractMatrix{T}) where {T}
109109
return r
110110
end
111111

112-
function wsumsq_percol(w::AbstractArray{T1}, a::AbstractMatrix{T2}) where {T1, T2}
112+
function wsumsq_percol(w::AbstractArray, a::AbstractMatrix)
113113
m = size(a, 1)
114114
n = size(a, 2)
115-
T = typeof(one(T1) * one(T2))
115+
T = typeof(zero(eltype(w)) * abs2(zero(eltype(a))))
116116
r = Vector{T}(undef, n)
117117
for j = 1:n
118118
aj = view(a, :, j)

src/metrics.jl

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -252,22 +252,21 @@ end
252252
end
253253
return eval_end(d, s)
254254
end
255-
result_type(dist::UnionMetrics, ::AbstractArray{T1}, ::AbstractArray{T2}) where {T1, T2} =
256-
typeof(eval_end(dist, parameters(dist) === nothing ?
257-
eval_op(dist, one(T1), one(T2)) :
258-
eval_op(dist, one(T1), one(T2), one(eltype(dist)))))
255+
result_type(dist::UnionMetrics, a::AbstractArray, b::AbstractArray) =
256+
typeof(evaluate(dist, oneunit(eltype(a)), oneunit(eltype(b))))
257+
259258
eval_start(d::UnionMetrics, a::AbstractArray, b::AbstractArray) =
260259
zero(result_type(d, a, b))
261260
eval_end(d::UnionMetrics, s) = s
262261

263-
evaluate(dist::UnionMetrics, a::T, b::T) where {T <: Number} = eval_end(dist, eval_op(dist, a, b))
262+
evaluate(dist::UnionMetrics, a::Number, b::Number) = eval_end(dist, eval_op(dist, a, b))
264263

265264
# SqEuclidean
266265
@inline eval_op(::SqEuclidean, ai, bi) = abs2(ai - bi)
267266
@inline eval_reduce(::SqEuclidean, s1, s2) = s1 + s2
268267

269268
sqeuclidean(a::AbstractArray, b::AbstractArray) = evaluate(SqEuclidean(), a, b)
270-
sqeuclidean(a::T, b::T) where {T <: Number} = evaluate(SqEuclidean(), a, b)
269+
sqeuclidean(a::Number, b::Number) = evaluate(SqEuclidean(), a, b)
271270

272271
# Euclidean
273272
@inline eval_op(::Euclidean, ai, bi) = abs2(ai - bi)
@@ -285,13 +284,13 @@ Base.eltype(d::PeriodicEuclidean) = eltype(d.periods)
285284
s3 = min(s2, p - s2)
286285
abs2(s3)
287286
end
287+
@inline function eval_op(d::PeriodicEuclidean, ai, bi)
288+
periods = d.periods
289+
p = isempty(periods) ? oneunit(eltype(periods)) : first(periods)
290+
eval_op(d, ai, bi, p)
291+
end
288292
@inline eval_reduce(::PeriodicEuclidean, s1, s2) = s1 + s2
289293
@inline eval_end(::PeriodicEuclidean, s) = sqrt(s)
290-
function evaluate(dist::PeriodicEuclidean, a::T, b::T) where {T <: Real}
291-
p = first(dist.periods)
292-
d = mod(abs(a - b), p)
293-
min(d, p - d)
294-
end
295294
peuclidean(a::AbstractArray, b::AbstractArray, p::AbstractArray{<: Real}) =
296295
evaluate(PeriodicEuclidean(p), a, b)
297296
peuclidean(a::Number, b::Number, p::Real) = evaluate(PeriodicEuclidean([p]), a, b)
@@ -300,38 +299,39 @@ peuclidean(a::Number, b::Number, p::Real) = evaluate(PeriodicEuclidean([p]), a,
300299
@inline eval_op(::Cityblock, ai, bi) = abs(ai - bi)
301300
@inline eval_reduce(::Cityblock, s1, s2) = s1 + s2
302301
cityblock(a::AbstractArray, b::AbstractArray) = evaluate(Cityblock(), a, b)
303-
cityblock(a::T, b::T) where {T <: Number} = evaluate(Cityblock(), a, b)
302+
cityblock(a::Number, b::Number) = evaluate(Cityblock(), a, b)
304303

305304
# Total variation
306305
@inline eval_op(::TotalVariation, ai, bi) = abs(ai - bi)
307306
@inline eval_reduce(::TotalVariation, s1, s2) = s1 + s2
308307
eval_end(::TotalVariation, s) = s / 2
309308
totalvariation(a::AbstractArray, b::AbstractArray) = evaluate(TotalVariation(), a, b)
310-
totalvariation(a::T, b::T) where {T <: Number} = evaluate(TotalVariation(), a, b)
309+
totalvariation(a::Number, b::Number) = evaluate(TotalVariation(), a, b)
311310

312311
# Chebyshev
313312
@inline eval_op(::Chebyshev, ai, bi) = abs(ai - bi)
314313
@inline eval_reduce(::Chebyshev, s1, s2) = max(s1, s2)
315314
# if only NaN, will output NaN
316315
@inline Base.@propagate_inbounds eval_start(::Chebyshev, a::AbstractArray, b::AbstractArray) = abs(a[1] - b[1])
317316
chebyshev(a::AbstractArray, b::AbstractArray) = evaluate(Chebyshev(), a, b)
318-
chebyshev(a::T, b::T) where {T <: Number} = evaluate(Chebyshev(), a, b)
317+
chebyshev(a::Number, b::Number) = evaluate(Chebyshev(), a, b)
319318

320319
# Minkowski
321320
@inline eval_op(dist::Minkowski, ai, bi) = abs(ai - bi).^dist.p
322321
@inline eval_reduce(::Minkowski, s1, s2) = s1 + s2
323322
eval_end(dist::Minkowski, s) = s.^(1 / dist.p)
324323
minkowski(a::AbstractArray, b::AbstractArray, p::Real) = evaluate(Minkowski(p), a, b)
325-
minkowski(a::T, b::T, p::Real) where {T <: Number} = evaluate(Minkowski(p), a, b)
324+
minkowski(a::Number, b::Number, p::Real) = evaluate(Minkowski(p), a, b)
326325

327326
# Hamming
328327
@inline eval_op(::Hamming, ai, bi) = ai != bi ? 1 : 0
329328
@inline eval_reduce(::Hamming, s1, s2) = s1 + s2
330329
hamming(a::AbstractArray, b::AbstractArray) = evaluate(Hamming(), a, b)
331-
hamming(a::T, b::T) where {T <: Number} = evaluate(Hamming(), a, b)
330+
hamming(a::Number, b::Number) = evaluate(Hamming(), a, b)
332331

333332
# Cosine dist
334-
@inline function eval_start(::CosineDist, a::AbstractArray{T}, b::AbstractArray{T}) where {T <: Real}
333+
@inline function eval_start(dist::CosineDist, a::AbstractArray, b::AbstractArray)
334+
T = Base.promote_typeof(eval_op(dist, oneunit(eltype(a)), oneunit(eltype(b)))...)
335335
zero(T), zero(T), zero(T)
336336
end
337337
@inline eval_op(::CosineDist, ai, bi) = ai * bi, ai * ai, bi * bi
@@ -360,12 +360,14 @@ result_type(::CorrDist, a::AbstractArray, b::AbstractArray) = result_type(Cosine
360360
chisq_dist(a::AbstractArray, b::AbstractArray) = evaluate(ChiSqDist(), a, b)
361361

362362
# KLDivergence
363-
@inline eval_op(::KLDivergence, ai, bi) = ai > 0 ? ai * log(ai / bi) : zero(ai)
363+
@inline eval_op(dist::KLDivergence, ai, bi) =
364+
ai > 0 ? ai * log(ai / bi) : zero(eval_op(dist, oneunit(ai), bi))
364365
@inline eval_reduce(::KLDivergence, s1, s2) = s1 + s2
365366
kl_divergence(a::AbstractArray, b::AbstractArray) = evaluate(KLDivergence(), a, b)
366367

367368
# GenKLDivergence
368-
@inline eval_op(::GenKLDivergence, ai, bi) = ai > 0 ? ai * log(ai / bi) - ai + bi : bi
369+
@inline eval_op(dist::GenKLDivergence, ai, bi) =
370+
ai > 0 ? ai * log(ai / bi) - ai + bi : oftype(eval_op(dist, oneunit(ai), bi), bi)
369371
@inline eval_reduce(::GenKLDivergence, s1, s2) = s1 + s2
370372
gkl_divergence(a::AbstractArray, b::AbstractArray) = evaluate(GenKLDivergence(), a, b)
371373

@@ -450,15 +452,17 @@ end
450452

451453
eval_end(::SpanNormDist, s) = s[2] - s[1]
452454
spannorm_dist(a::AbstractArray, b::AbstractArray) = evaluate(SpanNormDist(), a, b)
453-
function result_type(dist::SpanNormDist, ::AbstractArray{T1}, ::AbstractArray{T2}) where {T1, T2}
454-
typeof(eval_op(dist, one(T1), one(T2)))
455-
end
455+
result_type(dist::SpanNormDist, a::AbstractArray, b::AbstractArray) =
456+
typeof(eval_op(dist, oneunit(eltype(a)), oneunit(eltype(b))))
456457

457458

458459
# Jaccard
459460

460461
@inline eval_start(::Jaccard, a::AbstractArray{Bool}, b::AbstractArray{Bool}) = 0, 0
461-
@inline eval_start(::Jaccard, a::AbstractArray{T}, b::AbstractArray{T}) where {T} = zero(T), zero(T)
462+
@inline function eval_start(dist::Jaccard, a::AbstractArray, b::AbstractArray)
463+
T = Base.promote_typeof(eval_op(dist, oneunit(eltype(a)), oneunit(eltype(b)))...)
464+
zero(T), zero(T)
465+
end
462466
@inline function eval_op(::Jaccard, s1, s2)
463467
abs_m = abs(s1 - s2)
464468
abs_p = abs(s1 + s2)
@@ -478,7 +482,10 @@ jaccard(a::AbstractArray, b::AbstractArray) = evaluate(Jaccard(), a, b)
478482
# BrayCurtis
479483

480484
@inline eval_start(::BrayCurtis, a::AbstractArray{Bool}, b::AbstractArray{Bool}) = 0, 0
481-
@inline eval_start(::BrayCurtis, a::AbstractArray{T}, b::AbstractArray{T}) where {T} = zero(T), zero(T)
485+
@inline function eval_start(dist::BrayCurtis, a::AbstractArray, b::AbstractArray)
486+
T = Base.promote_typeof(eval_op(dist, oneunit(eltype(a)), oneunit(eltype(b)))...)
487+
zero(T), zero(T)
488+
end
482489
@inline function eval_op(::BrayCurtis, s1, s2)
483490
abs_m = abs(s1 - s2)
484491
abs_p = abs(s1 + s2)

src/wmetrics.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,12 @@ Base.eltype(x::UnionWeightedMetrics) = eltype(x.weights)
3939
#
4040
###########################################################
4141

42-
function evaluate(dist::UnionWeightedMetrics, a::T, b::T) where {T <: Number}
43-
eval_end(dist, eval_op(dist, a, b, one(eltype(dist))))
44-
end
45-
function result_type(dist::UnionWeightedMetrics, ::AbstractArray{T1}, ::AbstractArray{T2}) where {T1, T2}
46-
typeof(evaluate(dist, one(T1), one(T2)))
42+
function evaluate(dist::UnionWeightedMetrics, a::Number, b::Number)
43+
eval_end(dist, eval_op(dist, a, b, oneunit(eltype(dist))))
4744
end
45+
result_type(dist::UnionWeightedMetrics, a::AbstractArray, b::AbstractArray) =
46+
typeof(evaluate(dist, oneunit(eltype(a)), oneunit(eltype(b))))
47+
4848
@inline function eval_start(d::UnionWeightedMetrics, a::AbstractArray, b::AbstractArray)
4949
zero(result_type(d, a, b))
5050
end

test/test_dists.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,44 @@ end # testset
299299
@test_throws DimensionMismatch evaluate(Bregman(x -> sqeuclidean(x, zero(x)), x -> [1, 2]), [1, 2, 3], [1, 2, 3])
300300
end # testset
301301

302+
@testset "Different input types" begin
303+
for (x, y) in (([4, 5, 6, 7], [3.0, 9.0, 8.0, 1.0]),
304+
([4, 5, 6, 7], [3//1 8; 9 1]))
305+
@test (@inferred sqeuclidean(x, y)) == 57
306+
@test (@inferred euclidean(x, y)) == sqrt(57)
307+
@test (@inferred jaccard(x, y)) == convert(Base.promote_eltype(x, y), 13 // 28)
308+
@test (@inferred cityblock(x, y)) == 13
309+
@test (@inferred totalvariation(x, y)) == 6.5
310+
@test (@inferred chebyshev(x, y)) == 6
311+
@test (@inferred braycurtis(x, y)) == convert(Base.promote_eltype(x, y), 13 // 43)
312+
@test (@inferred minkowski(x, y, 2)) == sqrt(57)
313+
@test (@inferred peuclidean(x, y, fill(10, 4))) == sqrt(37)
314+
@test (@inferred peuclidean(x - vec(y), zero(y), fill(10, 4))) == peuclidean(x, y, fill(10, 4))
315+
@test (@inferred peuclidean(x, y, [10.0, 10.0, 10.0, Inf])) == sqrt(57)
316+
@test_throws DimensionMismatch cosine_dist(1.0:2, 1.0:3)
317+
@test (@inferred cosine_dist(x, y)) (1 - 112 / sqrt(19530))
318+
@test (@inferred corr_dist(x, y)) cosine_dist(x .- mean(x), vec(y) .- mean(y))
319+
@test (@inferred chisq_dist(x, y)) == sum((x - vec(y)).^2 ./ (x + vec(y)))
320+
@test (@inferred spannorm_dist(x, y)) == maximum(x - vec(y)) - minimum(x - vec(y))
321+
322+
@test (@inferred gkl_divergence(x, y)) sum(i -> x[i] * log(x[i] / y[i]) - x[i] + y[i], 1:length(x))
323+
324+
@test (@inferred meanad(x, y)) mean(Float64[abs(x[i] - y[i]) for i in 1:length(x)])
325+
@test (@inferred msd(x, y)) mean(Float64[abs2(x[i] - y[i]) for i in 1:length(x)])
326+
@test (@inferred rmsd(x, y)) sqrt(msd(x, y))
327+
@test (@inferred nrmsd(x, y)) sqrt(msd(x, y)) / (maximum(x) - minimum(x))
328+
329+
w = ones(Int, 4)
330+
@test sqeuclidean(x, y) wsqeuclidean(x, y, w)
331+
332+
w = rand(1:length(x), size(x))
333+
@test (@inferred wsqeuclidean(x, y, w)) dot((x - vec(y)).^2, w)
334+
@test (@inferred weuclidean(x, y, w)) == sqrt(wsqeuclidean(x, y, w))
335+
@test (@inferred wcityblock(x, y, w)) dot(abs.(x - vec(y)), w)
336+
@test (@inferred wminkowski(x, y, w, 2)) weuclidean(x, y, w)
337+
end
338+
end
339+
302340
@testset "mahalanobis" begin
303341
for T in (Float64, F64)
304342
x, y = T.([4.0, 5.0, 6.0, 7.0]), T.([3.0, 9.0, 8.0, 1.0])

0 commit comments

Comments
 (0)