Skip to content

Commit 951f2a4

Browse files
chriselrodTokazama
andauthored
@inline StrideIndex (#343)
* Use generated `map_tuple_type`,`flatten_tuples` for pre-v1.8 Failure to elide certain calls on pre-v1.8 was causing performance problems downstream, so we keep the less desirable code in Julia versions that require it for performance Co-authored-by: Zachary P. Christensen <[email protected]>
1 parent be5c398 commit 951f2a4

File tree

4 files changed

+59
-19
lines changed

4 files changed

+59
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "6.0.22"
3+
version = "6.0.23"
44

55
[deps]
66
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"

lib/ArrayInterfaceCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterfaceCore"
22
uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2"
3-
version = "0.1.17"
3+
version = "0.1.18"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,23 @@ julia> ArrayInterfaceCore.map_tuple_type(sqrt, Tuple{1,4,16})
3535
3636
```
3737
"""
38-
@inline function map_tuple_type(f, @nospecialize(T::Type))
39-
ntuple(i -> f(fieldtype(T, i)), Val{fieldcount(T)}())
38+
function map_tuple_type end
39+
if VERSION >= v"1.8"
40+
@inline function map_tuple_type(f, @nospecialize(T::Type))
41+
ntuple(i -> f(fieldtype(T, i)), Val{fieldcount(T)}())
42+
end
43+
else
44+
function map_tuple_type(f::F, ::Type{T}) where {F,T<:Tuple}
45+
if @generated
46+
t = Expr(:tuple)
47+
for i in 1:fieldcount(T)
48+
push!(t.args, :(f($(fieldtype(T, i)))))
49+
end
50+
Expr(:block, Expr(:meta, :inline), t)
51+
else
52+
Tuple(f(fieldtype(T, i)) for i in 1:fieldcount(T))
53+
end
54+
end
4055
end
4156

4257
"""
@@ -58,20 +73,45 @@ julia> ArrayInterfaceCore.flatten_tuples((1, (2, (3,))))
5873
5974
```
6075
"""
61-
function flatten_tuples(t::Tuple)
62-
fields = _new_field_positions(t)
63-
ntuple(Val{nfields(fields)}()) do k
64-
i, j = getfield(fields, k)
65-
i = length(t) - i
66-
@inbounds j === 0 ? getfield(t, i) : getfield(getfield(t, i), j)
76+
function flatten_tuples end
77+
if VERSION >= v"1.8"
78+
function flatten_tuples(t::Tuple)
79+
fields = _new_field_positions(t)
80+
ntuple(Val{nfields(fields)}()) do k
81+
i, j = getfield(fields, k)
82+
i = length(t) - i
83+
@inbounds j === 0 ? getfield(t, i) : getfield(getfield(t, i), j)
84+
end
85+
end
86+
_new_field_positions(::Tuple{}) = ()
87+
@nospecialize
88+
_new_field_positions(x::Tuple) = (_fl1(x, x[1])..., _new_field_positions(Base.tail(x))...)
89+
_fl1(x::Tuple, x1::Tuple) = ntuple(Base.Fix1(tuple, length(x) - 1), Val(length(x1)))
90+
_fl1(x::Tuple, x1) = ((length(x) - 1, 0),)
91+
@specialize
92+
else
93+
@inline function flatten_tuples(t::Tuple)
94+
if @generated
95+
texpr = Expr(:tuple)
96+
for i in 1:fieldcount(t)
97+
p = fieldtype(t, i)
98+
if p <: Tuple
99+
for j in 1:fieldcount(p)
100+
push!(texpr.args, :(@inbounds(getfield(getfield(t, $i), $j))))
101+
end
102+
else
103+
push!(texpr.args, :(@inbounds(getfield(t, $i))))
104+
end
105+
end
106+
Expr(:block, Expr(:meta, :inline), texpr)
107+
else
108+
_flatten(t)
109+
end
67110
end
111+
_flatten(::Tuple{}) = ()
112+
@inline _flatten(t::Tuple{Any,Vararg{Any}}) = (getfield(t, 1), _flatten(Base.tail(t))...)
113+
@inline _flatten(t::Tuple{Tuple,Vararg{Any}}) = (getfield(t, 1)..., _flatten(Base.tail(t))...)
68114
end
69-
_new_field_positions(::Tuple{}) = ()
70-
@nospecialize
71-
_new_field_positions(x::Tuple) = (_fl1(x, x[1])..., _new_field_positions(Base.tail(x))...)
72-
_fl1(x::Tuple, x1::Tuple) = ntuple(Base.Fix1(tuple, length(x) - 1), Val(length(x1)))
73-
_fl1(x::Tuple, x1) = ((length(x) - 1, 0),)
74-
@specialize
75115

76116
"""
77117
parent_type(::Type{T}) -> Type

src/array_index.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@ struct StrideIndex{N,R,C,S,O} <: ArrayIndex{N}
88
strides::S
99
offsets::O
1010

11-
function StrideIndex{N,R,C}(s::S, o::O) where {N,R,C,S,O}
11+
@inline function StrideIndex{N,R,C}(s::S, o::O) where {N,R,C,S,O}
1212
return new{N,R::NTuple{N,Int},C,S,O}(s, o)
1313
end
14-
function StrideIndex{N,R,C}(a::A) where {N,R,C,A}
14+
@inline function StrideIndex{N,R,C}(a::A) where {N,R,C,A}
1515
return StrideIndex{N,R,C}(strides(a), offsets(a))
1616
end
17-
function StrideIndex(a::A) where {A}
17+
@inline function StrideIndex(a::A) where {A}
1818
return StrideIndex{ndims(A),known(stride_rank(A)),known(contiguous_axis(A))}(a)
1919
end
2020
end

0 commit comments

Comments
 (0)