Skip to content

Commit 25aecc9

Browse files
committed
Optimize indices through @generated reduce.
1 parent 703c19b commit 25aecc9

File tree

2 files changed

+142
-4
lines changed

2 files changed

+142
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "2.14.10"
3+
version = "2.14.11"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/ranges.jl

Lines changed: 141 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -443,21 +443,159 @@ specified, then the indices for visiting each index of `x` are returned.
443443
end
444444
@inline indices(x::AbstractUnitRange{<:Integer}) = Base.Slice(OptionallyStaticUnitRange(x))
445445

446+
"""
447+
reduce_tup(f::F, inds::Tuple{Vararg{Any,N}}) where {F,N}
448+
449+
An optimized `reduce` for tuples. `Base.reduce`'s `afoldl` will often not inline.
450+
Additionally, `reduce_tup` attempts to order the reduction in an optimal manner.
451+
452+
```julia
453+
julia> using StaticArrays, ArrayInterface, BenchmarkTools
454+
455+
julia> rsum(v::SVector) = ArrayInterface.reduce_tup(+, v.data)
456+
rsum (generic function with 2 methods)
457+
458+
julia> for n ∈ 2:16
459+
@show n
460+
v = @SVector rand(n)
461+
s1 = @btime sum(\$(Ref(v))[])
462+
s2 = @btime rsum(\$(Ref(v))[])
463+
end
464+
n = 2
465+
0.863 ns (0 allocations: 0 bytes)
466+
0.864 ns (0 allocations: 0 bytes)
467+
n = 3
468+
0.863 ns (0 allocations: 0 bytes)
469+
0.863 ns (0 allocations: 0 bytes)
470+
n = 4
471+
0.864 ns (0 allocations: 0 bytes)
472+
1.075 ns (0 allocations: 0 bytes)
473+
n = 5
474+
0.864 ns (0 allocations: 0 bytes)
475+
1.076 ns (0 allocations: 0 bytes)
476+
n = 6
477+
0.865 ns (0 allocations: 0 bytes)
478+
1.077 ns (0 allocations: 0 bytes)
479+
n = 7
480+
1.075 ns (0 allocations: 0 bytes)
481+
0.866 ns (0 allocations: 0 bytes)
482+
n = 8
483+
0.974 ns (0 allocations: 0 bytes)
484+
1.076 ns (0 allocations: 0 bytes)
485+
n = 9
486+
1.081 ns (0 allocations: 0 bytes)
487+
1.077 ns (0 allocations: 0 bytes)
488+
n = 10
489+
1.203 ns (0 allocations: 0 bytes)
490+
1.077 ns (0 allocations: 0 bytes)
491+
n = 11
492+
1.355 ns (0 allocations: 0 bytes)
493+
1.292 ns (0 allocations: 0 bytes)
494+
n = 12
495+
1.539 ns (0 allocations: 0 bytes)
496+
1.079 ns (0 allocations: 0 bytes)
497+
n = 13
498+
1.704 ns (0 allocations: 0 bytes)
499+
1.290 ns (0 allocations: 0 bytes)
500+
n = 14
501+
1.916 ns (0 allocations: 0 bytes)
502+
1.185 ns (0 allocations: 0 bytes)
503+
n = 15
504+
2.072 ns (0 allocations: 0 bytes)
505+
1.292 ns (0 allocations: 0 bytes)
506+
n = 16
507+
2.273 ns (0 allocations: 0 bytes)
508+
1.076 ns (0 allocations: 0 bytes)
509+
```
510+
511+
More importantly, `reduce_tup(_pick_range, inds)` often performs better than `reduce(_pick_range, inds)`.
512+
```julia
513+
julia> using ArrayInterface, BenchmarkTools
514+
515+
julia> inds = (Base.OneTo(100), 1:100, 1:ArrayInterface.StaticInt(100))
516+
(Base.OneTo(100), 1:100, 1:Static(100))
517+
518+
julia> @btime reduce(ArrayInterface._pick_range, \$(Ref(inds))[])
519+
6.000 ns (0 allocations: 0 bytes)
520+
Base.Slice(Static(1):Static(100))
521+
522+
julia> @btime ArrayInterface.reduce_tup(ArrayInterface._pick_range, \$(Ref(inds))[])
523+
2.578 ns (0 allocations: 0 bytes)
524+
Base.Slice(Static(1):Static(100))
525+
526+
julia> inds = (Base.OneTo(100), 1:100, 1:UInt(100))
527+
(Base.OneTo(100), 1:100, 0x0000000000000001:0x0000000000000064)
528+
529+
julia> @btime reduce(ArrayInterface._pick_range, \$(Ref(inds))[])
530+
6.191 ns (0 allocations: 0 bytes)
531+
Base.Slice(Static(1):100)
532+
533+
julia> @btime ArrayInterface.reduce_tup(ArrayInterface._pick_range, \$(Ref(inds))[])
534+
2.591 ns (0 allocations: 0 bytes)
535+
Base.Slice(Static(1):100)
536+
537+
julia> inds = (Base.OneTo(100), 1:100, 1:UInt(100), Int32(1):Int32(100))
538+
(Base.OneTo(100), 1:100, 0x0000000000000001:0x0000000000000064, 1:100)
539+
540+
julia> @btime reduce(ArrayInterface._pick_range, \$(Ref(inds))[])
541+
9.268 ns (0 allocations: 0 bytes)
542+
Base.Slice(Static(1):100)
543+
544+
julia> @btime ArrayInterface.reduce_tup(ArrayInterface._pick_range, \$(Ref(inds))[])
545+
2.570 ns (0 allocations: 0 bytes)
546+
Base.Slice(Static(1):100)
547+
```
548+
"""
549+
@generated function reduce_tup(f::F, inds::Tuple{Vararg{Any,N}}) where {F,N}
550+
q = Expr(:block, Expr(:meta, :inline))
551+
if N == 1
552+
push!(q.args, :(inds[1]))
553+
return q
554+
end
555+
splits = 0
556+
_N = N
557+
while _N > 1
558+
_Nhalf = _N >> 1
559+
for n ∈ 1:_Nhalf
560+
assign = Symbol(:r_,n,:_,splits)
561+
call = if splits == 0
562+
Expr(:call, :f, Expr(:ref, :inds, n), Expr(:ref, :inds, n + _Nhalf))
563+
else
564+
Expr(:call, :f, Symbol(:r_,n,:_,splits-1), Symbol(:r_,n + _Nhalf,:_,splits-1))
565+
end
566+
push!(q.args, Expr(:(=), assign, call))
567+
end
568+
for (i,n) ∈ enumerate((_Nhalf<<1)+1:_N)
569+
assign = Symbol(:r_,i,:_,splits)
570+
call = if _N == N
571+
Expr(:call, :f, assign, Expr(:ref, :inds, n))
572+
else
573+
Expr(:call, :f, assign, Symbol(:r_, n, :_, splits-1))
574+
end
575+
push!(q.args, Expr(:(=), assign, call))
576+
end
577+
splits += 1
578+
_N = _Nhalf
579+
end
580+
push!(q.args, Symbol(:r_,1,:_,splits - 1))
581+
q
582+
end
583+
446584
function indices(x::Tuple)
447585
inds = map(eachindex, x)
448-
return reduce(_pick_range, inds)
586+
return reduce_tup(_pick_range, inds)
449587
end
450588

451589
@inline indices(x, d) = indices(axes(x, d))
452590

453591
@inline function indices(x::Tuple{Vararg{Any,N}}, dim) where {N}
454592
inds = map(x_i -> indices(x_i, dim), x)
455-
return reduce(_pick_range, inds)
593+
return reduce_tup(_pick_range, inds)
456594
end
457595

458596
@inline function indices(x::Tuple{Vararg{Any,N}}, dim::Tuple{Vararg{Any,N}}) where {N}
459597
inds = map(indices, x, dim)
460-
return reduce(_pick_range, inds)
598+
return reduce_tup(_pick_range, inds)
461599
end
462600

463601
@inline function _pick_range(x, y)

0 commit comments

Comments
 (0)