Skip to content

Commit c6b529d

Browse files
ntdefararslan
authored andcommitted
Extend CosineDist support for number types other than Float64 (#61)
* Cleanup whitespace * Fix for #55: `CosineDist` fails with Ints The reason CosineDist was failing for integers was because `eval_start` was only defined for subtypes of AbstractFloat instead of Number. Since `zero` has a method in Base for any Number, it seemed reasonable to widen the net on the types of arrays that `eval_start` instead of writing a new method on Integers. * Change CosineDist type from Number to Real Doing this to avoid ambiguities with Complex numbers. * Add unit tests for CosineDist on integers * Fix type coercion for compatibility with v0.4
1 parent f6cee36 commit c6b529d

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

src/metrics.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ immutable RenyiDivergence{T <: Real} <: PreMetric
3737
is_zero = q zero(T)
3838
is_one = q one(T)
3939
is_inf = isinf(q)
40-
40+
4141
# Only positive Rényi divergences are defined
4242
!is_zero && q < zero(T) && throw(ArgumentError("Order of Rényi divergence not legal, $(q) < 0."))
43-
43+
4444
new(q - 1, !(is_zero || is_one || is_inf), is_zero, is_one, is_inf)
4545
end
4646
end
@@ -132,7 +132,7 @@ hamming(a::AbstractArray, b::AbstractArray) = evaluate(Hamming(), a, b)
132132
hamming{T <: Number}(a::T, b::T) = evaluate(Hamming(), a, b)
133133

134134
# Cosine dist
135-
function eval_start{T<:AbstractFloat}(::CosineDist, a::AbstractArray{T}, b::AbstractArray{T})
135+
function eval_start{T<:Real}(::CosineDist, a::AbstractArray{T}, b::AbstractArray{T})
136136
zero(T), zero(T), zero(T)
137137
end
138138
@inline eval_op(::CosineDist, ai, bi) = ai * bi, ai * ai, bi * bi

test/test_dists.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ for (x, y) in (([4., 5., 6., 7.], [3., 9., 8., 1.]),
5353

5454
@test_throws DimensionMismatch cosine_dist(1.:2, 1.:3)
5555
@test cosine_dist(x, y) (1.0 - 112. / sqrt(19530.))
56+
x_int, y_int = map(Int64, x), map(Int64, y)
57+
@test cosine_dist(x_int, y_int) == (1.0 - 112. / sqrt(19530.))
5658

5759
@test corr_dist(x, x) < 1.0e-14
5860
@test corr_dist(x, y) cosine_dist(x .- mean(x), vec(y) .- mean(y))

0 commit comments

Comments
 (0)