Skip to content

Commit adf8a41

Browse files
committed
Add tests for ReinterpretArray and ranges
1 parent 05c3eba commit adf8a41

File tree

2 files changed

+28
-10
lines changed

2 files changed

+28
-10
lines changed

src/stridelayout.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -304,11 +304,11 @@ axes_types(::Type{T}) where {T<:Transpose} = _perm_tuple(axes_types(parent_type(
304304
function axes_types(::Type{T}) where {I1,T<:PermutedDimsArray{<:Any,<:Any,I1}}
305305
return _perm_tuple(axes_types(parent_type(T)), Val(I1))
306306
end
307-
function axes_types(::Type{T}) where {T<:OptionallyStaticRange}
307+
function axes_types(::Type{T}) where {T<:AbstractRange}
308308
if known_length(T) === nothing
309309
return Tuple{OptionallyStaticUnitRange{One,Int}}
310310
else
311-
return Tuple{OptionallyStaticUnitRange{One,StaticInt{known_length(T) - 1}}}
311+
return Tuple{OptionallyStaticUnitRange{One,StaticInt{known_length(T)}}}
312312
end
313313
end
314314

@@ -331,16 +331,15 @@ end
331331
end
332332

333333
@inline function axes_types(::Type{T}) where {T<:Base.ReinterpretArray}
334-
return _reinterpret_axes_types(axes_type(parent_type(T)), eltype(T), eltype(parent_type(T)))
334+
return _reinterpret_axes_types(axes_types(parent_type(T)), eltype(T), eltype(parent_type(T)))
335335
end
336336
@generated function _reinterpret_axes_types(::Type{I}, ::Type{T}, ::Type{S}) where {I<:Tuple,T,S}
337337
out = Expr(:curly, :Tuple)
338-
for i in 1:length(T.parameters)
338+
for i in 1:length(I.parameters)
339339
if i === 1
340-
push!(out.args, :(reinterpret_axis_type($(I.parameters[1]), $T, $S)))
340+
push!(out.args, reinterpret_axis_type(I.parameters[1], T, S))
341341
else
342-
# FIXME double check this once I've slept
343-
push!(out.args, :($(I.parameters[i])))
342+
push!(out.args, I.parameters[i])
344343
end
345344
end
346345
Expr(:block, Expr(:meta, :inline), out)
@@ -362,7 +361,7 @@ end
362361
if known_length(A) === nothing
363362
return OptionallyStaticUnitRange{One,Int}
364363
else
365-
return OptionallyStaticUnitRange{One,StaticInt{Int(known_length(A) / (sizeof(T) / sizeof(S))) - 1}}
364+
return OptionallyStaticUnitRange{One,StaticInt{Int(known_length(A) / (sizeof(T) / sizeof(S)))}}
366365
end
367366
end
368367

@@ -531,6 +530,7 @@ end
531530
Expr(:block, Expr(:meta, :inline), t)
532531
end
533532

533+
@inline size(v::AbstractVector) = (static_length(axes_types(v, 1)),)
534534
@inline size(B::Union{Transpose{T,A},Adjoint{T,A}}) where {T,A<:AbstractMatrix{T}} = permute(size(parent(B)), Val{(2,1)}())
535535
@inline size(B::PermutedDimsArray{T,N,I1,I2,A}) where {T,N,I1,I2,A<:AbstractArray{T,N}} = permute(size(parent(B)), Val{I1}())
536536
@inline size(A::AbstractArray, ::StaticInt{N}) where {N} = size(A)[N]

test/runtests.jl

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,6 @@ using OffsetArrays
355355
@test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,:,[1,2]]))) === ArrayInterface.DenseDims((false,true,false))
356356
@test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,[1,2,3],:]))) === ArrayInterface.DenseDims((false,false,false))
357357

358-
359358
B = Array{Int8}(undef, 2,2,2,2);
360359
doubleperm = PermutedDimsArray(PermutedDimsArray(B,(4,2,3,1)), (4,2,1,3));
361360
@test collect(strides(B))[collect(stride_rank(doubleperm))] == collect(strides(doubleperm))
@@ -368,11 +367,18 @@ end
368367
Sp2 = @view(PermutedDimsArray(S,(3,2,1))[2:3,:,:]);
369368
Mp2 = @view(PermutedDimsArray(M,(3,1,2))[2:3,:,2])';
370369
D = @view(A[:,2:2:4,:])
370+
R = StaticInt(1):StaticInt(2)
371+
Rr = reinterpret(Int32, R)
372+
Ar = reinterpret(Float32, A)
371373

374+
372375
@test @inferred(ArrayInterface.size(A)) === (3,4,5)
373376
@test @inferred(ArrayInterface.size(Ap)) === (2,5)
374377
@test @inferred(ArrayInterface.size(A)) === size(A)
375378
@test @inferred(ArrayInterface.size(Ap)) === size(Ap)
379+
@test @inferred(ArrayInterface.size(R)) === (StaticInt(2),)
380+
@test @inferred(ArrayInterface.size(Rr)) === (StaticInt(4),)
381+
376382

377383
@test @inferred(ArrayInterface.size(S)) === (StaticInt(2), StaticInt(3), StaticInt(4))
378384
@test @inferred(ArrayInterface.size(Sp)) === (2, 2, StaticInt(3))
@@ -394,6 +400,9 @@ end
394400

395401
@test @inferred(ArrayInterface.known_size(A)) === (nothing, nothing, nothing)
396402
@test @inferred(ArrayInterface.known_size(Ap)) === (nothing,nothing)
403+
@test @inferred(ArrayInterface.known_size(R)) === (2,)
404+
@test @inferred(ArrayInterface.known_size(Rr)) === (4,)
405+
@test @inferred(ArrayInterface.known_size(Ar)) === (nothing,nothing, nothing,)
397406

398407
@test @inferred(ArrayInterface.known_size(S)) === (2, 3, 4)
399408
@test @inferred(ArrayInterface.known_size(Sp)) === (nothing, nothing, 3)
@@ -410,6 +419,8 @@ end
410419
@test @inferred(ArrayInterface.strides(Ap)) === (StaticInt(1), 12)
411420
@test @inferred(ArrayInterface.strides(A)) == strides(A)
412421
@test @inferred(ArrayInterface.strides(Ap)) == strides(Ap)
422+
@test @inferred(ArrayInterface.strides(Ar)) === (StaticInt{1}(), 6, 24)
423+
413424

414425
@test @inferred(ArrayInterface.strides(S)) === (StaticInt(1), StaticInt(2), StaticInt(6))
415426
@test @inferred(ArrayInterface.strides(Sp)) === (StaticInt(6), StaticInt(1), StaticInt(2))
@@ -427,7 +438,8 @@ end
427438

428439
@test @inferred(ArrayInterface.known_strides(A)) === (1, nothing, nothing)
429440
@test @inferred(ArrayInterface.known_strides(Ap)) === (1, nothing)
430-
441+
@test @inferred(ArrayInterface.known_strides(Ar)) === (1, nothing, nothing)
442+
431443
@test @inferred(ArrayInterface.known_strides(S)) === (1, 2, 6)
432444
@test @inferred(ArrayInterface.known_strides(Sp)) === (6, 1, 2)
433445
@test @inferred(ArrayInterface.known_strides(Sp2)) === (6, 2, 1)
@@ -441,6 +453,7 @@ end
441453

442454
@test @inferred(ArrayInterface.offsets(A)) === (StaticInt(1), StaticInt(1), StaticInt(1))
443455
@test @inferred(ArrayInterface.offsets(Ap)) === (StaticInt(1), StaticInt(1))
456+
@test @inferred(ArrayInterface.offsets(Ar)) === (StaticInt(1), StaticInt(1), StaticInt(1))
444457

445458
@test @inferred(ArrayInterface.offsets(S)) === (StaticInt(1), StaticInt(1), StaticInt(1))
446459
@test @inferred(ArrayInterface.offsets(Sp)) === (StaticInt(1), StaticInt(1), StaticInt(1))
@@ -452,6 +465,7 @@ end
452465

453466
@test @inferred(ArrayInterface.known_offsets(A)) === (1, 1, 1)
454467
@test @inferred(ArrayInterface.known_offsets(Ap)) === (1, 1)
468+
@test @inferred(ArrayInterface.known_offsets(Ar)) === (1, 1, 1)
455469

456470
@test @inferred(ArrayInterface.known_offsets(S)) === (1, 1, 1)
457471
@test @inferred(ArrayInterface.known_offsets(Sp)) === (1, 1, 1)
@@ -461,6 +475,10 @@ end
461475
@test @inferred(ArrayInterface.known_offsets(Mp)) === (1, 1)
462476
@test @inferred(ArrayInterface.known_offsets(Mp2)) === (1, 1)
463477

478+
@test @inferred(ArrayInterface.known_offsets(R)) === (1,)
479+
@test @inferred(ArrayInterface.known_offsets(Rr)) === (1,)
480+
@test @inferred(ArrayInterface.known_offsets(1:10)) === (1,)
481+
464482
O = OffsetArray(A, 3, 7, 10);
465483
Op = PermutedDimsArray(O,(3,1,2));
466484
@test @inferred(ArrayInterface.offsets(O)) === (4, 8, 11)

0 commit comments

Comments
 (0)