diff --git a/src/knn.jl b/src/knn.jl index 6fb4bb1..0da7a50 100644 --- a/src/knn.jl +++ b/src/knn.jl @@ -11,7 +11,7 @@ end Performs a lookup of the `k` nearest neigbours to the `points` from the data in the `tree`. If `sortres = true` the result is sorted such that the results are in the order of increasing distance to the point. `skip` is an optional predicate -to determine if a point that would be returned should be skipped based on its +to determine if a point that would be returned should be skipped based on its index. """ function knn(tree::NNTree{V}, points::Vector{T}, k::Int, sortres=false, skip::F=always_false) where {V, T <: AbstractVector, F<:Function} @@ -27,7 +27,7 @@ function knn(tree::NNTree{V}, points::Vector{T}, k::Int, sortres=false, skip::F= end function knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist, idx, skip::F) where {V, T <: Number, F} - fill!(idx, -1) + fill!(idx, 1) fill!(dist, typemax(get_T(eltype(V)))) _knn(tree, point, idx, dist, skip) sortres && heap_sort_inplace!(dist, idx) diff --git a/test/runtests.jl b/test/runtests.jl index a257562..6e893ce 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,6 +31,7 @@ include("test_knn.jl") include("test_inrange.jl") include("test_monkey.jl") include("test_datafreetree.jl") +include("test_specialfloats.jl") @testset "periodic euclidean" begin pred = PeriodicEuclidean([Inf, 2.5]) diff --git a/test/test_specialfloats.jl b/test/test_specialfloats.jl new file mode 100644 index 0000000..e2a7d98 --- /dev/null +++ b/test/test_specialfloats.jl @@ -0,0 +1,53 @@ +using NearestNeighbors +using Test + +# Test for issue #125 +@testset "nan on query" begin + for _ in 1:111 + Ndim = 35 + Npt = 408 + + data = randn(Ndim, Npt) + tree = KDTree(data) + + pointnan = repeat([NaN], Ndim) + indnan,distnan = nn(tree, pointnan) + @test 1 <= indnan <= Npt + end +end + +# # Test for issue #78 +# @testset "infs on data" begin +# for _ in 1:11 +# coords = [ +# 29882.5 25974.3 Inf Inf 17821.8 Inf Inf Inf Inf Inf 16322.0; +# 9279.86 9286.35 Inf Inf 10320.4 Inf Inf Inf Inf Inf 11459.0; +# 0.0 0.0 Inf Inf 0.0 Inf Inf Inf Inf Inf 0.0] +# point = [17889.55, 2094.45, 0.0] + +# tree = BallTree(coords) +# idx, _ = knn(tree, point, 1) +# @test idx[1] == 5 +# end +# end + +# @testset "nan on data" begin +# for _ in 1:11 +# # Ndim = 35 +# # Npt = 408 + +# # data = randn(Ndim, Npt) +# # tree = KDTree(data) + +# # datanan = copy(data) +# # datanan[rand(1:Ndim),rand(1:Npt)] = NaN +# # treenan = KDTree(datanan) + +# # pointrand = randn(Ndim) + +# # @show indnan2, distnan2 = nn(tree, pointrand) +# # @show indnan2, distnan2 = nn(treenan, pointrand) +# # @test 1 <= indnan2 <= Npt + +# end +# end