diff --git a/src/host/indexing.jl b/src/host/indexing.jl index 8aee8b9d..5e744dbd 100644 --- a/src/host/indexing.jl +++ b/src/host/indexing.jl @@ -191,6 +191,7 @@ Base.getindex(ei::EachIndex, i::Int) = ei.indices[i] Base.IndexStyle(::Type{<:EachIndex}) = Base.IndexLinear() function Base.findfirst(f::Function, A::AnyGPUArray) + isempty(A) && return nothing indices = EachIndex(A) dummy_index = first(indices) diff --git a/test/testsuite.jl b/test/testsuite.jl index b48d7ccd..f2ec6388 100644 --- a/test/testsuite.jl +++ b/test/testsuite.jl @@ -16,6 +16,7 @@ using Test using Adapt +test_result(a, b; kwargs...) = a == b test_result(a::Number, b::Number; kwargs...) = ≈(a, b; kwargs...) test_result(a::Missing, b::Missing; kwargs...) = true test_result(a::Number, b::Missing; kwargs...) = false diff --git a/test/testsuite/indexing.jl b/test/testsuite/indexing.jl index bc19cd7e..d8943de3 100644 --- a/test/testsuite/indexing.jl +++ b/test/testsuite/indexing.jl @@ -170,6 +170,10 @@ end let x = rand(Float32, 10, 10) @test findfirst(>(0.5f0), x) == findfirst(>(0.5f0), AT(x)) end + + # emtpy + @test compare(findfirst, AT, Bool[]) + @test compare(x->findfirst(>(0.5f0), x), AT, Float32[]) end @testset "findmax & findmin" begin