Skip to content

Commit b91d803

Browse files
authored
Merge pull request #161 from ranocha/hr/reshaped_views
reshaped views and reinterpreted arrays
2 parents 6a8efa6 + d84f2d4 commit b91d803

File tree

3 files changed

+111
-22
lines changed

3 files changed

+111
-22
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "3.1.15"
3+
version = "3.1.16"
44

55
[deps]
66
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"

src/stridelayout.jl

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ end
8888
contiguous_axis(::Type{T}) -> StaticInt{N}
8989
9090
Returns the axis of an array of type `T` containing contiguous data.
91-
If no axis is contiguous, it returns `StaticInt{-1}`.
91+
If no axis is contiguous, it returns a `StaticInt{-1}`.
9292
If unknown, it returns `nothing`.
9393
"""
9494
contiguous_axis(x) = contiguous_axis(typeof(x))
@@ -297,7 +297,7 @@ contiguous_batch_size(::Type{<:Base.ReinterpretArray{T,N,S,A}}) where {T,N,S,A}
297297
"""
298298
is_column_major(A) -> True/False
299299
300-
Returns `Val{true}` if elements of `A` are stored in column major order. Otherwise returns `Val{false}`.
300+
Returns `True()` if elements of `A` are stored in column major order. Otherwise returns `False()`.
301301
"""
302302
is_column_major(A) = is_column_major(stride_rank(A), contiguous_batch_size(A))
303303
is_column_major(sr::Nothing, cbs) = False()
@@ -310,10 +310,11 @@ _is_column_major(sr::R, cbs::StaticInt) where {R} = False()
310310
_is_column_major(sr::R, cbs::Union{StaticInt{0},StaticInt{-1}}) where {R} = is_increasing(sr)
311311

312312
"""
313-
dense_dims(::Type{T}) -> NTuple{N,Bool}
313+
dense_dims(::Type{<:AbstractArray{N}}) -> NTuple{N,StaticBool}
314314
315315
Returns a tuple of indicators for whether each axis is dense.
316-
An axis `i` of array `A` is dense if `stride(A, i) * Base.size(A, i) == stride(A, j)` where `stride_rank(A)[i] + 1 == stride_rank(A)[j]`.
316+
An axis `i` of array `A` is dense if `stride(A, i) * Base.size(A, i) == stride(A, j)`
317+
where `stride_rank(A)[i] + 1 == stride_rank(A)[j]`.
317318
"""
318319
dense_dims(x) = dense_dims(typeof(x))
319320
function dense_dims(::Type{T}) where {T}
@@ -359,7 +360,7 @@ end
359360
if VERSION v"1.6.0-DEV.1581"
360361
@inline function dense_dims(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true}}
361362
ddb = dense_dims(B)
362-
IfElse.ifelse(Static.le(StaticInt(NB), StaticInt(NA)), (True(), ddb...), Base.tail(ddb))
363+
IfElse.ifelse(Static.le(StaticInt(NB), StaticInt(NA)), (True(), ddb...), Base.tail(ddb))
363364
end
364365
end
365366

@@ -413,6 +414,9 @@ function _reshaped_dense_dims(dense::D, ::True, ::Val{N}, ::Val{0}) where {D,N}
413414
return nothing
414415
end
415416
end
417+
function _reshaped_dense_dims(dense::Tuple{Static.False}, ::True, ::Val{N}, ::Val{0}) where {N}
418+
return return ntuple(_ -> False(), Val{N}())
419+
end
416420

417421
"""
418422
known_strides(::Type{T}[, dim]) -> Tuple
@@ -464,13 +468,16 @@ julia> A = rand(3,4);
464468
465469
julia> ArrayInterface.strides(A)
466470
(static(1), 3)
471+
```
467472
468473
Additionally, the behavior differs from `Base.strides` for adjoint vectors:
469474
475+
```julia
470476
julia> x = rand(5);
471477
472478
julia> ArrayInterface.strides(x')
473479
(static(1), static(1))
480+
```
474481
475482
This is to support the pattern of using just the first stride for linear indexing, `x[i]`,
476483
while still producing correct behavior when using valid cartesian indices, such as `x[1,i]`.
@@ -485,11 +492,23 @@ function strides(x)
485492
return Base.strides(x)
486493
end
487494
end
495+
496+
# Fixes the example of https://github.com/JuliaArrays/ArrayInterface.jl/issues/160
497+
# TODO: Should be generalized to reshaped arrays wrapping more general array types
498+
function strides(A::ReshapedArray{T,N,P}) where {T, N, P<:AbstractVector}
499+
if defines_strides(A)
500+
return size_to_strides(size(A), first(strides(parent(A))))
501+
else
502+
return Base.strides(A)
503+
end
504+
end
505+
488506
@inline bmap(f::F, t::Tuple{}, x::Number) where {F} = ()
489507
@inline bmap(f::F, t::Tuple{T}, x::Number) where {F, T} = (f(first(t),x), )
490508
@inline bmap(f::F, t::Tuple, x::Number) where {F} = (f(first(t),x), bmap(f, Base.tail(t), x)...)
491509
if VERSION v"1.6.0-DEV.1581"
492-
@inline @inline function strides(A::Base.ReinterpretArray{R, N, T, B, true}) where {R,N,T,B}
510+
# from `reinterpret(reshape, ...)`
511+
@inline function strides(A::Base.ReinterpretArray{R, N, T, B, true}) where {R,N,T,B}
493512
P = strides(parent(A))
494513
if sizeof(R) == sizeof(T)
495514
P
@@ -505,6 +524,18 @@ if VERSION ≥ v"1.6.0-DEV.1581"
505524
(One(), bmap(*, P, StaticInt(sizeof(T)) ÷ StaticInt(sizeof(R)))...)
506525
end
507526
end
527+
528+
# plain `reinterpret(...)`
529+
@inline function strides(A::Base.ReinterpretArray{R, N, T, B, false}) where {R,N,T,B}
530+
P = strides(parent(A))
531+
if sizeof(R) == sizeof(T)
532+
P
533+
elseif sizeof(R) > sizeof(T)
534+
(first(P), bmap(÷, Base.tail(P), StaticInt(sizeof(R)) ÷ StaticInt(sizeof(T)))...)
535+
else # sizeof(R) < sizeof(T)
536+
(first(P), bmap(*, Base.tail(P), StaticInt(sizeof(T)) ÷ StaticInt(sizeof(R)))...)
537+
end
538+
end
508539
end
509540
#@inline strides(A) = _strides(A, Base.strides(A), contiguous_axis(A))
510541

test/runtests.jl

Lines changed: 73 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -207,22 +207,22 @@ end
207207
@testset "Range Interface" begin
208208
@testset "Range Constructors" begin
209209
@test @inferred(StaticInt(1):StaticInt(10)) == 1:10
210-
@test @inferred(StaticInt(1):StaticInt(2):StaticInt(10)) == 1:2:10
210+
@test @inferred(StaticInt(1):StaticInt(2):StaticInt(10)) == 1:2:10
211211
@test @inferred(1:StaticInt(2):StaticInt(10)) == 1:2:10
212212
@test @inferred(StaticInt(1):StaticInt(2):10) == 1:2:10
213-
@test @inferred(StaticInt(1):2:StaticInt(10)) == 1:2:10
213+
@test @inferred(StaticInt(1):2:StaticInt(10)) == 1:2:10
214214
@test @inferred(1:2:StaticInt(10)) == 1:2:10
215215
@test @inferred(1:StaticInt(2):10) == 1:2:10
216-
@test @inferred(StaticInt(1):2:10) == 1:2:10
217-
@test @inferred(StaticInt(1):UInt(10)) === StaticInt(1):10
216+
@test @inferred(StaticInt(1):2:10) == 1:2:10
217+
@test @inferred(StaticInt(1):UInt(10)) === StaticInt(1):10
218218
@test @inferred(UInt(1):StaticInt(1):StaticInt(10)) === 1:StaticInt(10)
219219
@test @inferred(ArrayInterface.OptionallyStaticUnitRange{Int,Int}(1:10)) == 1:10
220220
@test @inferred(ArrayInterface.OptionallyStaticUnitRange(1:10)) == 1:10
221221

222222
@inferred(ArrayInterface.OptionallyStaticUnitRange(1:10))
223223

224-
@test @inferred(ArrayInterface.OptionallyStaticStepRange(StaticInt(1), 1, UInt(10))) == StaticInt(1):1:10
225-
@test @inferred(ArrayInterface.OptionallyStaticStepRange(UInt(1), 1, StaticInt(10))) == StaticInt(1):1:10
224+
@test @inferred(ArrayInterface.OptionallyStaticStepRange(StaticInt(1), 1, UInt(10))) == StaticInt(1):1:10
225+
@test @inferred(ArrayInterface.OptionallyStaticStepRange(UInt(1), 1, StaticInt(10))) == StaticInt(1):1:10
226226
@test @inferred(ArrayInterface.OptionallyStaticStepRange(1:10)) == 1:1:10
227227

228228
@test_throws ArgumentError ArrayInterface.OptionallyStaticUnitRange(1:2:10)
@@ -331,7 +331,6 @@ using OffsetArrays
331331
@test @inferred(ArrayInterface.defines_strides(D1))
332332
@test !@inferred(ArrayInterface.defines_strides(view(A, :, [1,2],1)))
333333
@test @inferred(ArrayInterface.defines_strides(DenseWrapper{Int,2,Matrix{Int}}))
334-
335334
@test @inferred(device(A)) === ArrayInterface.CPUPointer()
336335
@test @inferred(device(B)) === ArrayInterface.CPUIndex()
337336
@test @inferred(device(-1:19)) === ArrayInterface.CPUIndex()
@@ -372,7 +371,7 @@ using OffsetArrays
372371
@test @inferred(contiguous_axis(PermutedDimsArray(DummyZeros(3,4), (2, 1)))) === nothing
373372
@test @inferred(contiguous_axis(view(DummyZeros(3,4), 1, :))) === nothing
374373
@test @inferred(contiguous_axis(view(DummyZeros(3,4), 1, :)')) === nothing
375-
374+
376375
@test @inferred(ArrayInterface.contiguous_axis_indicator(@SArray(zeros(2,2,2)))) == (true,false,false)
377376
@test @inferred(ArrayInterface.contiguous_axis_indicator(A)) == (true,false,false)
378377
@test @inferred(ArrayInterface.contiguous_axis_indicator(B)) == (true,false,false)
@@ -424,7 +423,7 @@ using OffsetArrays
424423
@test @inferred(stride_rank(PermutedDimsArray(DummyZeros(3,4), (2, 1)))) === nothing
425424
@test @inferred(stride_rank(view(DummyZeros(3,4), 1, :))) === nothing
426425

427-
426+
428427
#=
429428
@btime ArrayInterface.is_column_major($(PermutedDimsArray(A,(3,1,2))))
430429
0.047 ns (0 allocations: 0 bytes)
@@ -494,11 +493,11 @@ using OffsetArrays
494493
@test @inferred(ArrayInterface.defines_strides(C1))
495494
@test @inferred(ArrayInterface.defines_strides(C2))
496495
@test @inferred(ArrayInterface.defines_strides(C3))
497-
496+
498497
@test @inferred(device(C1)) === ArrayInterface.CPUPointer()
499498
@test @inferred(device(C2)) === ArrayInterface.CPUPointer()
500499
@test @inferred(device(C3)) === ArrayInterface.CPUPointer()
501-
500+
502501
@test @inferred(contiguous_batch_size(C1)) === ArrayInterface.StaticInt(0)
503502
@test @inferred(contiguous_batch_size(C2)) === ArrayInterface.StaticInt(0)
504503
@test @inferred(contiguous_batch_size(C3)) === ArrayInterface.StaticInt(0)
@@ -510,7 +509,7 @@ using OffsetArrays
510509
@test @inferred(contiguous_axis(C1)) === StaticInt(1)
511510
@test @inferred(contiguous_axis(C2)) === StaticInt(0)
512511
@test @inferred(contiguous_axis(C3)) === StaticInt(2)
513-
512+
514513
@test @inferred(ArrayInterface.contiguous_axis_indicator(C1)) == (true,false,false,false)
515514
@test @inferred(ArrayInterface.contiguous_axis_indicator(C2)) == (false,false)
516515
@test @inferred(ArrayInterface.contiguous_axis_indicator(C3)) == (false,true)
@@ -535,6 +534,8 @@ end
535534
R = StaticInt(1):StaticInt(2);
536535
Rnr = reinterpret(Int32, R);
537536
Ar = reinterpret(Float32, A);
537+
A2 = zeros(4, 3, 5)
538+
A2r = reinterpret(ComplexF64, A2)
538539

539540
sv5 = @SVector(zeros(5)); v5 = Vector{Float64}(undef, 5);
540541
@test @inferred(ArrayInterface.size(sv5)) === (StaticInt(5),)
@@ -546,6 +547,8 @@ end
546547
@test @inferred(ArrayInterface.size(R)) === (StaticInt(2),)
547548
@test @inferred(ArrayInterface.size(Rnr)) === (StaticInt(4),)
548549
@test @inferred(ArrayInterface.known_length(Rnr)) === 4
550+
@test @inferred(ArrayInterface.size(A2)) === (4,3,5)
551+
@test @inferred(ArrayInterface.size(A2r)) === (2,3,5)
549552

550553
@test @inferred(ArrayInterface.size(S)) === (StaticInt(2), StaticInt(3), StaticInt(4))
551554
@test @inferred(ArrayInterface.size(Sp)) === (2, 2, StaticInt(3))
@@ -577,6 +580,8 @@ end
577580
@test @inferred(ArrayInterface.known_size(Ar)) === (nothing,nothing, nothing,)
578581
@test @inferred(ArrayInterface.known_size(Ar, static(1))) === nothing
579582
@test @inferred(ArrayInterface.known_size(Ar, static(4))) === 1
583+
@test @inferred(ArrayInterface.known_size(A2)) === (nothing, nothing, nothing)
584+
@test @inferred(ArrayInterface.known_size(A2r)) === (nothing, nothing, nothing)
580585

581586
@test @inferred(ArrayInterface.known_size(S)) === (2, 3, 4)
582587
@test @inferred(ArrayInterface.known_size(Wrapper(S))) === (2, 3, 4)
@@ -596,6 +601,8 @@ end
596601
@test @inferred(ArrayInterface.strides(A)) == strides(A)
597602
@test @inferred(ArrayInterface.strides(Ap)) == strides(Ap)
598603
@test @inferred(ArrayInterface.strides(Ar)) === (StaticInt{1}(), 6, 24)
604+
@test @inferred(ArrayInterface.strides(A2)) === (StaticInt(1), 4, 12)
605+
@test @inferred(ArrayInterface.strides(A2r)) === (StaticInt(1), 2, 6)
599606

600607
@test @inferred(ArrayInterface.strides(S)) === (StaticInt(1), StaticInt(2), StaticInt(6))
601608
@test @inferred(ArrayInterface.strides(Sp)) === (StaticInt(6), StaticInt(1), StaticInt(2))
@@ -618,6 +625,8 @@ end
618625
@test @inferred(ArrayInterface.known_strides(Ap)) === (1, nothing)
619626
@test @inferred(ArrayInterface.known_strides(Ar)) === (1, nothing, nothing)
620627
@test @inferred(ArrayInterface.known_strides(reshape(view(zeros(100), 1:60), (3,4,5)))) === (1, nothing, nothing)
628+
@test @inferred(ArrayInterface.known_strides(A2)) === (1, nothing, nothing)
629+
@test @inferred(ArrayInterface.known_strides(A2r)) === (1, nothing, nothing)
621630

622631
@test @inferred(ArrayInterface.known_strides(S)) === (1, 2, 6)
623632
@test @inferred(ArrayInterface.known_strides(Sp)) === (6, 1, 2)
@@ -635,6 +644,8 @@ end
635644
@test @inferred(ArrayInterface.offsets(A)) === (StaticInt(1), StaticInt(1), StaticInt(1))
636645
@test @inferred(ArrayInterface.offsets(Ap)) === (StaticInt(1), StaticInt(1))
637646
@test @inferred(ArrayInterface.offsets(Ar)) === (StaticInt(1), StaticInt(1), StaticInt(1))
647+
@test @inferred(ArrayInterface.offsets(A2)) === (StaticInt(1), StaticInt(1), StaticInt(1))
648+
@test @inferred(ArrayInterface.offsets(A2r)) === (StaticInt(1), StaticInt(1), StaticInt(1))
638649

639650
@test @inferred(ArrayInterface.offsets(S)) === (StaticInt(1), StaticInt(1), StaticInt(1))
640651
@test @inferred(ArrayInterface.offsets(Sp)) === (StaticInt(1), StaticInt(1), StaticInt(1))
@@ -649,6 +660,8 @@ end
649660
@test @inferred(ArrayInterface.known_offsets(Ar)) === (1, 1, 1)
650661
@test @inferred(ArrayInterface.known_offsets(Ar, static(1))) === 1
651662
@test @inferred(ArrayInterface.known_offsets(Ar, static(4))) === 1
663+
@test @inferred(ArrayInterface.known_offsets(A2)) === (1, 1, 1)
664+
@test @inferred(ArrayInterface.known_offsets(A2r)) === (1, 1, 1)
652665

653666
@test @inferred(ArrayInterface.known_offsets(S)) === (1, 1, 1)
654667
@test @inferred(ArrayInterface.known_offsets(Sp)) === (1, 1, 1)
@@ -675,7 +688,7 @@ end
675688
colormat = reinterpret(reshape, Float64, colors)
676689
@test @inferred(ArrayInterface.strides(colormat)) === (StaticInt(1), StaticInt(3))
677690
@test @inferred(ArrayInterface.dense_dims(colormat)) === (True(),True())
678-
@test @inferred(ArrayInterface.dense_dims(view(colormat,:,4))) === (True(),)
691+
@test @inferred(ArrayInterface.dense_dims(view(colormat,:,4))) === (True(),)
679692
@test @inferred(ArrayInterface.dense_dims(view(colormat,:,4:7))) === (True(),True())
680693
@test @inferred(ArrayInterface.dense_dims(view(colormat,2:3,:))) === (True(),False())
681694

@@ -702,7 +715,7 @@ end
702715
@test @inferred(ArrayInterface.strides(Ac2r)) === (StaticInt(1), StaticInt(2), 10)
703716
Ac2r_static = reinterpret(reshape, Float64, view(@MMatrix(rand(ComplexF64, 5, 7)), 2:4, 3:6));
704717
@test @inferred(ArrayInterface.strides(Ac2r_static)) === (StaticInt(1), StaticInt(2), StaticInt(10))
705-
718+
706719
Ac2t = reinterpret(reshape, Tuple{Float64,Float64}, view(rand(ComplexF64, 5, 7), 2:4, 3:6));
707720
@test @inferred(ArrayInterface.strides(Ac2t)) === (StaticInt(1), 5)
708721
Ac2t_static = reinterpret(reshape, Tuple{Float64,Float64}, view(@MMatrix(rand(ComplexF64, 5, 7)), 2:4, 3:6));
@@ -711,6 +724,51 @@ end
711724
end
712725
end
713726

727+
@testset "Reshaped views" begin
728+
u_base = randn(10, 10)
729+
u_view = view(u_base, 3, :)
730+
u_reshaped_view1 = reshape(u_view, 1, :)
731+
u_reshaped_view2 = reshape(u_view, 2, :)
732+
733+
@test @inferred(ArrayInterface.defines_strides(u_base))
734+
@test @inferred(ArrayInterface.defines_strides(u_view))
735+
@test @inferred(ArrayInterface.defines_strides(u_reshaped_view1))
736+
@test @inferred(ArrayInterface.defines_strides(u_reshaped_view2))
737+
738+
# See https://github.com/JuliaArrays/ArrayInterface.jl/issues/160
739+
@test @inferred(ArrayInterface.strides(u_base)) == (StaticInt(1), 10)
740+
@test @inferred(ArrayInterface.strides(u_view)) == (10,)
741+
@test @inferred(ArrayInterface.strides(u_reshaped_view1)) == (10, 10)
742+
@test @inferred(ArrayInterface.strides(u_reshaped_view2)) == (10, 20)
743+
744+
# See https://github.com/JuliaArrays/ArrayInterface.jl/issues/157
745+
@test @inferred(ArrayInterface.dense_dims(u_base)) == (True(), True())
746+
@test @inferred(ArrayInterface.dense_dims(u_view)) == (False(),)
747+
@test @inferred(ArrayInterface.dense_dims(u_reshaped_view1)) == (False(), False())
748+
@test @inferred(ArrayInterface.dense_dims(u_reshaped_view2)) == (False(), False())
749+
end
750+
751+
@testset "Reinterpreted reshaped views" begin
752+
u_base = randn(1, 4, 4, 5)
753+
u_vectors = reshape(reinterpret(SVector{1, eltype(u_base)}, u_base),
754+
Base.tail(size(u_base))...)
755+
u_view = view(u_vectors, 2, :, 3)
756+
u_view_reinterpreted = reinterpret(eltype(u_base), u_view)
757+
u_view_reshaped = reshape(u_view_reinterpreted, 1, length(u_view))
758+
759+
# See https://github.com/JuliaArrays/ArrayInterface.jl/issues/163
760+
@test @inferred(ArrayInterface.strides(u_base)) == (StaticInt(1), 1, 4, 16)
761+
@test @inferred(ArrayInterface.strides(u_vectors)) == (StaticInt(1), 4, 16)
762+
@test @inferred(ArrayInterface.strides(u_view)) == (4,)
763+
if VERSION v"1.6.0-DEV.1581"
764+
@test @inferred(ArrayInterface.strides(u_view_reinterpreted)) == (4,)
765+
@test @inferred(ArrayInterface.strides(u_view_reshaped)) == (4, 4)
766+
else
767+
@test_broken @inferred(ArrayInterface.strides(u_view_reinterpreted)) == (4,)
768+
@test_broken @inferred(ArrayInterface.strides(u_view_reshaped)) == (4, 4)
769+
end
770+
end
771+
714772
@test ArrayInterface.can_avx(ArrayInterface.can_avx) == false
715773

716774
@testset "can_change_size" begin
@@ -842,6 +900,6 @@ end
842900
@test @inferred(is_lazy_conjugate(d)) == false
843901
e = permutedims(d)
844902
@test @inferred(is_lazy_conjugate(e)) == false
845-
903+
846904
@test @inferred(is_lazy_conjugate([1,2,3]')) == false # We don't care about conj on `<:Real`
847905
end

0 commit comments

Comments
 (0)