Skip to content

Commit 473d0db

Browse files
authored
Specialize findlast for integer AbstractUnitRanges and StepRanges (#54902)
For monotonic ranges, `findfirst` and `findlast` with `==(val)` as the predicate should be identical, as each value appears only once in the range. Since `findfirst` is specialized for some ranges, we may define `findlast` as well analogously. On v"1.12.0-DEV.770" ```julia julia> @Btime findlast(==(1), $(Ref(1:1_000))[]) 1.186 μs (0 allocations: 0 bytes) 1 ``` This PR ```julia julia> @Btime findlast(==(1), $(Ref(1:1_000))[]) 3.171 ns (0 allocations: 0 bytes) 1 ``` I've also specialized `findfirst(iszero, r::AbstractRange)` to make this be equivalent to `findfirst(==(0), ::AbstractRange)` for numerical ranges. Similarly, for `isone`. These now take the fast path as well. Thirdly, I've added some `convert` calls to address issues like ```julia julia> r = Int128(1):Int128(1):Int128(4); julia> findfirst(==(Int128(2)), r) |> typeof Int128 julia> keytype(r) Int64 ``` This PR ensures that the return type always corresponds to `keytype`, which is what the docstring promises. This PR also fixes ```julia julia> findfirst(==(0), UnitRange(-0.5, 0.5)) ERROR: InexactError: Int64(0.5) Stacktrace: [1] Int64 @ ./float.jl:994 [inlined] [2] findfirst(p::Base.Fix2{typeof(==), Int64}, r::UnitRange{Float64}) @ Base ./array.jl:2397 [3] top-level scope @ REPL[1]:1 ``` which now returns `nothing`, as expected.
1 parent fb5e96a commit 473d0db

File tree

2 files changed

+66
-6
lines changed

2 files changed

+66
-6
lines changed

base/array.jl

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2439,20 +2439,29 @@ end
24392439
findfirst(testf::Function, A::Union{AbstractArray, AbstractString}) =
24402440
findnext(testf, A, first(keys(A)))
24412441

2442-
findfirst(p::Union{Fix2{typeof(isequal),Int},Fix2{typeof(==),Int}}, r::OneTo{Int}) =
2443-
1 <= p.x <= r.stop ? p.x : nothing
2442+
findfirst(p::Union{Fix2{typeof(isequal),T},Fix2{typeof(==),T}}, r::OneTo) where {T<:Integer} =
2443+
1 <= p.x <= r.stop ? convert(keytype(r), p.x) : nothing
24442444

2445-
findfirst(p::Union{Fix2{typeof(isequal),T},Fix2{typeof(==),T}}, r::AbstractUnitRange) where {T<:Integer} =
2446-
first(r) <= p.x <= last(r) ? firstindex(r) + Int(p.x - first(r)) : nothing
2445+
findfirst(::typeof(iszero), ::OneTo) = nothing
2446+
findfirst(::typeof(isone), r::OneTo) = isempty(r) ? nothing : oneunit(keytype(r))
2447+
2448+
function findfirst(p::Union{Fix2{typeof(isequal),T},Fix2{typeof(==),T}}, r::AbstractUnitRange{<:Integer}) where {T<:Integer}
2449+
first(r) <= p.x <= last(r) || return nothing
2450+
i1 = first(keys(r))
2451+
return i1 + oftype(i1, p.x - first(r))
2452+
end
24472453

24482454
function findfirst(p::Union{Fix2{typeof(isequal),T},Fix2{typeof(==),T}}, r::StepRange{T,S}) where {T,S}
24492455
isempty(r) && return nothing
24502456
minimum(r) <= p.x <= maximum(r) || return nothing
2451-
d = convert(S, p.x - first(r))::S
2457+
d = p.x - first(r)
24522458
iszero(d % step(r)) || return nothing
2453-
return d ÷ step(r) + 1
2459+
return convert(keytype(r), d ÷ step(r) + 1)
24542460
end
24552461

2462+
findfirst(::typeof(iszero), r::AbstractRange) = findfirst(==(zero(first(r))), r)
2463+
findfirst(::typeof(isone), r::AbstractRange) = findfirst(==(one(first(r))), r)
2464+
24562465
"""
24572466
findprev(A, i)
24582467
@@ -2623,6 +2632,17 @@ end
26232632
findlast(testf::Function, A::Union{AbstractArray, AbstractString}) =
26242633
findprev(testf, A, last(keys(A)))
26252634

2635+
# for monotonic ranges, there is a unique index corresponding to a value, so findfirst and findlast are identical
2636+
function findlast(p::Union{Fix2{typeof(isequal),<:Integer},Fix2{typeof(==),<:Integer},typeof(iszero),typeof(isone)},
2637+
r::AbstractUnitRange{<:Integer})
2638+
findfirst(p, r)
2639+
end
2640+
2641+
function findlast(p::Union{Fix2{typeof(isequal),T},Fix2{typeof(==),T},typeof(iszero),typeof(isone)},
2642+
r::StepRange{T,S}) where {T,S}
2643+
findfirst(p, r)
2644+
end
2645+
26262646
"""
26272647
findall(f::Function, A)
26282648

test/ranges.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,15 +438,55 @@ end
438438
@test findfirst(isequal(3), Base.OneTo(10)) == 3
439439
@test findfirst(==(0), Base.OneTo(10)) === nothing
440440
@test findfirst(==(11), Base.OneTo(10)) === nothing
441+
@test @inferred((r -> Val(findfirst(iszero, r)))(Base.OneTo(10))) == Val(nothing)
442+
@test findfirst(isone, Base.OneTo(10)) === 1
443+
@test findfirst(isone, Base.OneTo(0)) === nothing
441444
@test findfirst(==(4), Int16(3):Int16(7)) === Int(2)
442445
@test findfirst(==(2), Int16(3):Int16(7)) === nothing
443446
@test findfirst(isequal(8), 3:7) === nothing
447+
@test findfirst(==(0), UnitRange(-0.5, 0.5)) === nothing
448+
@test findfirst(==(2), big(1):big(2)) === 2
444449
@test findfirst(isequal(7), 1:2:10) == 4
450+
@test findfirst(iszero, -5:5) == 6
451+
@test findfirst(iszero, 2:5) === nothing
452+
@test findfirst(iszero, 6:5) === nothing
453+
@test findfirst(isone, -5:5) == 7
454+
@test findfirst(isone, 2:5) === nothing
455+
@test findfirst(isone, 6:5) === nothing
445456
@test findfirst(==(7), 1:2:10) == 4
446457
@test findfirst(==(10), 1:2:10) === nothing
447458
@test findfirst(==(11), 1:2:10) === nothing
448459
@test findfirst(==(-7), 1:-1:-10) == 9
449460
@test findfirst(==(2),1:-1:2) === nothing
461+
@test findfirst(iszero, 5:-2:-5) === nothing
462+
@test findfirst(iszero, 6:-2:-6) == 4
463+
@test findfirst(==(Int128(2)), Int128(1):Int128(1):Int128(4)) === 2
464+
end
465+
@testset "findlast" begin
466+
@test findlast(==(1), Base.IdentityUnitRange(-1:1)) == 1
467+
@test findlast(isequal(3), Base.OneTo(10)) == 3
468+
@test findlast(==(0), Base.OneTo(10)) === nothing
469+
@test findlast(==(11), Base.OneTo(10)) === nothing
470+
@test @inferred((() -> Val(findlast(iszero, Base.OneTo(10))))()) == Val(nothing)
471+
@test findlast(isone, Base.OneTo(10)) == 1
472+
@test findlast(isone, Base.OneTo(0)) === nothing
473+
@test findlast(==(4), Int16(3):Int16(7)) === Int(2)
474+
@test findlast(==(2), Int16(3):Int16(7)) === nothing
475+
@test findlast(isequal(8), 3:7) === nothing
476+
@test findlast(==(0), UnitRange(-0.5, 0.5)) === nothing
477+
@test findlast(==(2), big(1):big(2)) === 2
478+
@test findlast(isequal(7), 1:2:10) == 4
479+
@test findlast(iszero, -5:5) == 6
480+
@test findlast(iszero, 2:5) === nothing
481+
@test findlast(iszero, 6:5) === nothing
482+
@test findlast(==(7), 1:2:10) == 4
483+
@test findlast(==(10), 1:2:10) === nothing
484+
@test findlast(==(11), 1:2:10) === nothing
485+
@test findlast(==(-7), 1:-1:-10) == 9
486+
@test findlast(==(2),1:-1:2) === nothing
487+
@test findlast(iszero, 5:-2:-5) === nothing
488+
@test findlast(iszero, 6:-2:-6) == 4
489+
@test findlast(==(Int128(2)), Int128(1):Int128(1):Int128(4)) === 2
450490
end
451491
@testset "reverse" begin
452492
@test reverse(reverse(1:10)) == 1:10

0 commit comments

Comments
 (0)