Skip to content

Commit 36ca34c

Browse files
authored
Fix strides(::ReinterpretArray) and extend strides(::ReshapedArray) (#264)
* Remove some version check. Update stridelayout.jl * Fix `strides(::Base.ReinterpretArray)` * Extend `strides(A::ReshapedArray)` `ReshapedArray` has no static size. Thus we only try to make the 1st stride static. * Replace `Integer` with `Union{Int,StaticInt}`.
1 parent 828ba3f commit 36ca34c

File tree

2 files changed

+231
-157
lines changed

2 files changed

+231
-157
lines changed

src/stridelayout.jl

Lines changed: 136 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -262,31 +262,29 @@ stride_rank(x, i) = stride_rank(x)[i]
262262
function stride_rank(::Type{R}) where {T,N,S,A<:Array{S},R<:Base.ReinterpretArray{T,N,S,A}}
263263
return nstatic(Val(N))
264264
end
265-
if VERSION v"1.6.0-DEV.1581"
266-
@inline function stride_rank(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true}}
265+
@inline function stride_rank(::Type{A}) where {NB,NA,B<:AbstractArray{<:Any,NB},A<:Base.ReinterpretArray{<:Any,NA,<:Any,B,true}}
267266
NA == NB ? stride_rank(B) : _stride_rank_reinterpret(stride_rank(B), gt(StaticInt{NB}(), StaticInt{NA}()))
268-
end
269-
@inline _stride_rank_reinterpret(sr, ::False) = (One(), map(Base.Fix2(+,One()),sr)...)
270-
@inline _stride_rank_reinterpret(sr::Tuple{One,Vararg}, ::True) = map(Base.Fix2(-,One()), tail(sr))
271-
# if the leading dim's `stride_rank` is not one, then that means the individual elements are split across an axis, which ArrayInterface
272-
# doesn't currently have a means of representing.
273-
@inline function contiguous_axis(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true}}
267+
end
268+
@inline _stride_rank_reinterpret(sr, ::False) = (One(), map(Base.Fix2(+, One()), sr)...)
269+
@inline _stride_rank_reinterpret(sr::Tuple{One,Vararg}, ::True) = map(Base.Fix2(-, One()), tail(sr))
270+
# if the leading dim's `stride_rank` is not one, then that means the individual elements are split across an axis, which ArrayInterface
271+
# doesn't currently have a means of representing.
272+
@inline function contiguous_axis(::Type{A}) where {NB,NA,B<:AbstractArray{<:Any,NB},A<:Base.ReinterpretArray{<:Any,NA,<:Any,B,true}}
274273
_reinterpret_contiguous_axis(stride_rank(B), dense_dims(B), contiguous_axis(B), gt(StaticInt{NB}(), StaticInt{NA}()))
275-
end
276-
@inline _reinterpret_contiguous_axis(::Any, ::Any, ::Any, ::False) = One()
277-
@inline _reinterpret_contiguous_axis(::Any, ::Any, ::Any, ::True) = Zero()
278-
@generated function _reinterpret_contiguous_axis(t::Tuple{One,Vararg{StaticInt,N}}, d::Tuple{True,Vararg{StaticBool,N}}, ::One, ::True) where {N}
274+
end
275+
@inline _reinterpret_contiguous_axis(::Any, ::Any, ::Any, ::False) = One()
276+
@inline _reinterpret_contiguous_axis(::Any, ::Any, ::Any, ::True) = Zero()
277+
@generated function _reinterpret_contiguous_axis(t::Tuple{One,Vararg{StaticInt,N}}, d::Tuple{True,Vararg{StaticBool,N}}, ::One, ::True) where {N}
279278
for n in 1:N
280-
if t.parameters[n+1].parameters[1] === 2
281-
if d.parameters[n+1] === True
282-
return :(StaticInt{$n}())
283-
else
284-
return :(Zero())
279+
if t.parameters[n+1].parameters[1] === 2
280+
if d.parameters[n+1] === True
281+
return :(StaticInt{$n}())
282+
else
283+
return :(Zero())
284+
end
285285
end
286-
end
287286
end
288287
:(Zero())
289-
end
290288
end
291289

292290
function stride_rank(::Type{Base.ReshapedArray{T, N, P, Tuple{Vararg{Base.SignedMultiplicativeInverse{Int},M}}}}) where {T,N,P,M}
@@ -411,11 +409,9 @@ end
411409
function dense_dims(::Type{S}) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}}
412410
return _dense_dims(S, dense_dims(A), Val(stride_rank(A)))
413411
end
414-
if VERSION v"1.6.0-DEV.1581"
415-
@inline function dense_dims(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true}}
416-
ddb = dense_dims(B)
417-
IfElse.ifelse(Static.le(StaticInt(NB), StaticInt(NA)), (True(), ddb...), Base.tail(ddb))
418-
end
412+
@inline function dense_dims(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true}}
413+
ddb = dense_dims(B)
414+
IfElse.ifelse(Static.le(StaticInt(NB), StaticInt(NA)), (True(), ddb...), Base.tail(ddb))
419415
end
420416

421417
_dense_dims(::Type{S}, ::Nothing, ::Val{R}) where {R,N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} = nothing
@@ -561,70 +557,127 @@ strides(A::StrideIndex) = getfield(A, :strides)
561557
end
562558
end
563559

564-
# Fixes the example of https://github.com/JuliaArrays/ArrayInterface.jl/issues/160
565-
# TODO: Should be generalized to reshaped arrays wrapping more general array types
566-
function strides(A::ReshapedArray{T,N,P}) where {T, N, P<:AbstractVector}
567-
if defines_strides(A)
568-
return size_to_strides(size(A), first(strides(parent(A))))
560+
_is_column_dense(::A) where {A<:AbstractArray} =
561+
defines_strides(A) &&
562+
(ndims(A) == 0 || Bool(is_dense(A)) && Bool(is_column_major(A)))
563+
564+
# Fixes the example of https://github.com/JuliaArrays/ArrayInterfaceCore.jl/issues/160
565+
function strides(A::ReshapedArray)
566+
_is_column_dense(parent(A)) && return size_to_strides(size(A), One())
567+
pst = strides(parent(A))
568+
psz = size(parent(A))
569+
# Try dimension merging in order (starting from dim1).
570+
# `sz1` and `st1` are the `size`/`stride` of dim1 after dimension merging.
571+
# `n` indicates the last merged dimension.
572+
# note: `st1` should be static if possible
573+
sz1, st1, n = merge_adjacent_dim(psz, pst)
574+
n == ndims(A.parent) && return size_to_strides(size(A), st1)
575+
return _reshaped_strides(size(A), One(), sz1, st1, n, Dims(psz), Dims(pst))
576+
end
577+
578+
@inline function _reshaped_strides(::Dims{0}, reshaped, msz::Int, _, ::Int, ::Dims, ::Dims)
579+
reshaped == msz && return ()
580+
throw(ArgumentError("Input is not strided."))
581+
end
582+
function _reshaped_strides(asz::Dims, reshaped, msz::Int, mst, n::Int, apsz::Dims, apst::Dims)
583+
st = reshaped * mst
584+
reshaped = reshaped * asz[1]
585+
if length(asz) > 1 && reshaped == msz && asz[2] != 1
586+
msz, mst′, n = merge_adjacent_dim(apsz, apst, n + 1)
587+
reshaped = 1
588+
else
589+
mst′ = Int(mst)
590+
end
591+
sts = _reshaped_strides(tail(asz), reshaped, msz, mst′, n, apsz, apst)
592+
return (st, sts...)
593+
end
594+
595+
merge_adjacent_dim(::Tuple{}, ::Tuple{}) = 1, One(), 0
596+
merge_adjacent_dim(szs::Tuple{Any}, sts::Tuple{Any}) = Int(szs[1]), sts[1], 1
597+
function merge_adjacent_dim(szs::Tuple, sts::Tuple)
598+
if szs[1] isa One # Just ignore dimension with size 1
599+
sz, st, n = merge_adjacent_dim(tail(szs), tail(sts))
600+
return sz, st, n + 1
601+
elseif szs[2] isa One # Just ignore dimension with size 1
602+
sz, st, n = merge_adjacent_dim((szs[1], tail(tail(szs))...), (sts[1], tail(tail(sts))...))
603+
return sz, st, n + 1
604+
elseif (szs[1], szs[2], sts[1], sts[2]) isa NTuple{4,StaticInt} # the check could be done during compiling.
605+
if sts[2] == sts[1] * szs[1]
606+
szs′ = (szs[1] * szs[2], tail(tail(szs))...)
607+
sts′ = (sts[1], tail(tail(sts))...)
608+
sz, st, n = merge_adjacent_dim(szs′, sts′)
609+
return sz, st, n + 1
610+
else
611+
return Int(szs[1]), sts[1], 1
612+
end
613+
else # the check can't be done during compiling.
614+
sz, st, n = merge_adjacent_dim(Dims(szs), Dims(sts), 1)
615+
if (szs[1], sts[1]) isa NTuple{2,StaticInt} && szs[1] != 1
616+
# But the 1st stride might still be static.
617+
return sz, sts[1], n
618+
else
619+
return sz, st, n
620+
end
621+
end
622+
end
623+
624+
function merge_adjacent_dim(psz::Dims{N}, pst::Dims{N}, n::Int) where {N}
625+
sz, st = psz[n], pst[n]
626+
while n < N
627+
szₙ, stₙ = psz[n+1], pst[n+1]
628+
if sz == 1
629+
sz, st = szₙ, stₙ
630+
elseif stₙ == st * sz
631+
sz *= szₙ
632+
elseif szₙ != 1
633+
break
634+
end
635+
n += 1
636+
end
637+
return sz, st, n
638+
end
639+
640+
# `strides` for `Base.ReinterpretArray`
641+
function strides(A::Base.ReinterpretArray{T,<:Any,S,<:AbstractArray{S},IsReshaped}) where {T,S,IsReshaped}
642+
_is_column_dense(parent(A)) && return size_to_strides(size(A), One())
643+
stp = strides(parent(A))
644+
ET, ES = static(sizeof(T)), static(sizeof(S))
645+
ET === ES && return stp
646+
IsReshaped && ET < ES && return (One(), _reinterp_strides(stp, ET, ES)...)
647+
first(stp) == 1 || throw(ArgumentError("Parent must be contiguous in the 1st dimension!"))
648+
if IsReshaped
649+
# The wrapper tell us `A`'s parent has static size in dim1.
650+
# We can make the next stride static if the following dim is still dense.
651+
sr = stride_rank(parent(A))
652+
dd = dense_dims(parent(A))
653+
stp′ = _new_static(stp, sr, dd, ET ÷ ES)
654+
return _reinterp_strides(tail(stp′), ET, ES)
569655
else
570-
return Base.strides(A)
656+
return (One(), _reinterp_strides(tail(stp), ET, ES)...)
571657
end
572658
end
573-
function strides(A::ReshapedArray{T,N,P}) where {T, N, P}
574-
if defines_strides(A)
575-
return size_to_strides(size(A), static(1))
659+
_new_static(P,_,_,_) = P # This should never be called, just in case.
660+
@generated function _new_static(p::P, ::SR, ::DD, ::StaticInt{S}) where {S,N,P<:NTuple{N,Union{Int,StaticInt}},SR<:NTuple{N,StaticInt},DD<:NTuple{N,StaticBool}}
661+
sr = fieldtypes(SR)
662+
j = findfirst(T -> T() == sr[1]()+1, sr)
663+
if !isnothing(j) && !(fieldtype(P, j) <: StaticInt) && fieldtype(DD, j) === True
664+
return :(tuple($((i == j ? :(static($S)) : :(p[$i]) for i in 1:N)...)))
576665
else
577-
return Base.strides(A)
578-
end
579-
end
580-
581-
582-
@inline bmap(f::F, t::Tuple{}, x::Number) where {F} = ()
583-
@inline bmap(f::F, t::Tuple{T}, x::Number) where {F, T} = (f(first(t),x), )
584-
@inline bmap(f::F, t::Tuple, x::Number) where {F} = (f(first(t),x), bmap(f, Base.tail(t), x)...)
585-
@static if VERSION v"1.6.0-DEV.1581"
586-
# from `reinterpret(reshape, ...)`
587-
@inline function strides(A::Base.ReinterpretArray{R, N, T, B, true}) where {R,N,T,B}
588-
P = strides(parent(A))
589-
if sizeof(R) == sizeof(T)
590-
P
591-
elseif sizeof(R) > sizeof(T)
592-
x = Base.tail(P)
593-
fx = first(x)
594-
if fx isa Int
595-
(One(), bmap(Base.sdiv_int, Base.tail(x), fx)...)
596-
else
597-
(One(), bmap(÷, Base.tail(x), fx)...)
598-
end
666+
return :(p)
667+
end
668+
end
669+
@inline function _reinterp_strides(stp::Tuple, els::StaticInt, elp::StaticInt)
670+
if elp % els == 0
671+
N = elp ÷ els
672+
return map(i -> N * i, stp)
599673
else
600-
(One(), bmap(*, P, StaticInt(sizeof(T)) ÷ StaticInt(sizeof(R)))...)
601-
end
602-
end
603-
# plain `reinterpret(...)`
604-
@inline function strides(A::Base.ReinterpretArray{R, N, T, B, false}) where {R,N,T,B}
605-
P = strides(parent(A))
606-
if sizeof(R) == sizeof(T)
607-
P
608-
elseif sizeof(R) > sizeof(T)
609-
(first(P), bmap(÷, Base.tail(P), StaticInt(sizeof(R)) ÷ StaticInt(sizeof(T)))...)
610-
else # sizeof(R) < sizeof(T)
611-
(first(P), bmap(*, Base.tail(P), StaticInt(sizeof(T)) ÷ StaticInt(sizeof(R)))...)
612-
end
613-
end
614-
else
615-
# plain `reinterpret(...)`
616-
@inline function strides(A::Base.ReinterpretArray{R, N, T}) where {R,N,T}
617-
P = strides(parent(A))
618-
if sizeof(R) == sizeof(T)
619-
P
620-
elseif sizeof(R) > sizeof(T)
621-
(first(P), bmap(÷, Base.tail(P), StaticInt(sizeof(R)) ÷ StaticInt(sizeof(T)))...)
622-
else # sizeof(R) < sizeof(T)
623-
(first(P), bmap(*, Base.tail(P), StaticInt(sizeof(T)) ÷ StaticInt(sizeof(R)))...)
624-
end
625-
end
626-
end
627-
#@inline strides(A) = _strides(A, Base.strides(A), contiguous_axis(A))
674+
return map(stp) do i
675+
d, r = divrem(elp * i, els)
676+
iszero(r) || throw(ArgumentError("Parent's strides could not be exactly divided!"))
677+
d
678+
end
679+
end
680+
end
628681

629682
strides(::AbstractRange) = (One(),)
630683
function strides(x::VecAdjTrans)

0 commit comments

Comments
 (0)