Skip to content

Commit 7d22a1e

Browse files
committed
Add test for general method
1 parent 0f77124 commit 7d22a1e

File tree

1 file changed

+38
-24
lines changed

1 file changed

+38
-24
lines changed

test/test_layoutarray.jl

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,39 @@ module TestLayoutArray
33
using ArrayLayouts, LinearAlgebra, FillArrays, Test, SparseArrays
44
using ArrayLayouts: sub_materialize, MemoryLayout, ColumnNorm, RowMaximum, CRowMaximum, @_layoutlmul
55
import ArrayLayouts: triangulardata
6+
import LinearAlgebra: Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal
67

7-
struct MyMatrix <: LayoutMatrix{Float64}
8-
A::Matrix{Float64}
8+
struct MyMatrix{T,M<:AbstractMatrix{T}} <: LayoutMatrix{T}
9+
A::M
910
end
1011

1112
Base.getindex(A::MyMatrix, k::Int, j::Int) = A.A[k,j]
1213
Base.setindex!(A::MyMatrix, v, k::Int, j::Int) = setindex!(A.A, v, k, j)
1314
Base.size(A::MyMatrix) = size(A.A)
1415
Base.strides(A::MyMatrix) = strides(A.A)
15-
Base.elsize(::Type{MyMatrix}) = sizeof(Float64)
16-
Base.cconvert(::Type{Ptr{Float64}}, A::MyMatrix) = A.A
17-
Base.unsafe_convert(::Type{Ptr{Float64}}, A::MyMatrix) = Base.unsafe_convert(Ptr{Float64}, A.A)
18-
MemoryLayout(::Type{MyMatrix}) = DenseColumnMajor()
16+
Base.elsize(::Type{<:MyMatrix{T}}) where {T} = sizeof(T)
17+
Base.cconvert(::Type{Ptr{T}}, A::MyMatrix{T}) where {T} = Base.cconvert(Ptr{T}, A.A)
18+
Base.unsafe_convert(::Type{Ptr{T}}, A::MyMatrix{T}) where {T} = Base.unsafe_convert(Ptr{T}, A.A)
19+
MemoryLayout(::Type{MyMatrix{T,M}}) where {T,M} = MemoryLayout(M)
1920
Base.copy(A::MyMatrix) = MyMatrix(copy(A.A))
21+
ArrayLayouts.bidiagonaluplo(M::MyMatrix) = ArrayLayouts.bidiagonaluplo(M.A)
22+
for MT in (:Diagonal, :Bidiagonal, :Tridiagonal, :SymTridiagonal)
23+
@eval $MT(M::MyMatrix) = $MT(M.A)
24+
end
2025

21-
struct MyVector{T} <: LayoutVector{T}
22-
A::Vector{T}
26+
struct MyVector{T,V<:AbstractVector{T}} <: LayoutVector{T}
27+
A::V
2328
end
2429

2530
MyVector(M::MyVector) = MyVector(M.A)
2631
Base.getindex(A::MyVector, k::Int) = A.A[k]
2732
Base.setindex!(A::MyVector, v, k::Int) = setindex!(A.A, v, k)
2833
Base.size(A::MyVector) = size(A.A)
2934
Base.strides(A::MyVector) = strides(A.A)
30-
Base.elsize(::Type{MyVector}) = sizeof(Float64)
31-
Base.cconvert(::Type{Ptr{T}}, A::MyVector{T}) where {T} = A.A
35+
Base.elsize(::Type{<:MyVector{T}}) where {T} = sizeof(T)
36+
Base.cconvert(::Type{Ptr{T}}, A::MyVector{T}) where {T} = Base.cconvert(Ptr{T}, A.A)
3237
Base.unsafe_convert(::Type{Ptr{T}}, A::MyVector{T}) where T = Base.unsafe_convert(Ptr{T}, A.A)
33-
MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
38+
MemoryLayout(::Type{MyVector{T,V}}) where {T,V} = MemoryLayout(V)
3439
Base.copy(A::MyVector) = MyVector(copy(A.A))
3540

3641
# These need to test dispatch reduces to ArrayLayouts.mul, etc.
@@ -44,7 +49,7 @@ Base.copy(A::MyVector) = MyVector(copy(A.A))
4449
@test a[1:3] == a.A[1:3]
4550
@test a[:] == a
4651
@test (a')[1,:] == (a')[1,1:3] == a
47-
@test sprint(show, "text/plain", a) == "3-element $MyVector{Float64}:\n 1.0\n 2.0\n 3.0"
52+
@test sprint(show, "text/plain", a) == "$(summary(a)):\n 1.0\n 2.0\n 3.0"
4853
@test B*a B*a.A
4954
@test B'*a B'*a.A
5055
@test transpose(B)*a transpose(B)*a.A
@@ -142,16 +147,16 @@ Base.copy(A::MyVector) = MyVector(copy(A.A))
142147

143148
@testset "ldiv!" begin
144149
c = MyVector(randn(5))
145-
if VERSION < v"1.9"
146-
@test_broken ldiv!(lu(A), MyVector(copy(c))) A \ c
147-
else
148-
@test ldiv!(lu(A), MyVector(copy(c))) A \ c
149-
end
150-
if VERSION < v"1.9" || VERSION >= v"1.10-"
151-
@test_throws ErrorException ldiv!(qr(A), MyVector(copy(c)))
152-
else
153-
@test_throws MethodError ldiv!(qr(A), MyVector(copy(c)))
154-
end
150+
# if VERSION < v"1.9"
151+
# @test_broken ldiv!(lu(A), MyVector(copy(c))) ≈ A \ c
152+
# else
153+
@test ldiv!(lu(A), MyVector(copy(c))) A \ c
154+
# end
155+
# if VERSION < v"1.9" || VERSION >= v"1.10-"
156+
# @test_throws ErrorException ldiv!(qr(A), MyVector(copy(c)))
157+
# else
158+
# @test_throws MethodError ldiv!(qr(A), MyVector(copy(c)))
159+
# end
155160
@test_throws ErrorException ldiv!(eigen(randn(5,5)), c)
156161
@test ArrayLayouts.ldiv!(svd(A.A), Vector(c)) ArrayLayouts.ldiv!(similar(c), svd(A.A), c) A \ c
157162
if VERSION v"1.8"
@@ -215,8 +220,8 @@ Base.copy(A::MyVector) = MyVector(copy(A.A))
215220
@test B == Ones(5,5)*A + 2.0Bin
216221
end
217222

218-
C = MyMatrix([1 2; 3 4])
219-
@test sprint(show, "text/plain", C) == "2×2 $MyMatrix:\n 1.0 2.0\n 3.0 4.0"
223+
C = MyMatrix(Float64[1 2; 3 4])
224+
@test sprint(show, "text/plain", C) == "$(summary(C)):\n 1.0 2.0\n 3.0 4.0"
220225

221226
@testset "layoutldiv" begin
222227
A = MyMatrix(randn(5,5))
@@ -343,18 +348,27 @@ Base.copy(A::MyVector) = MyVector(copy(A.A))
343348
@testset "Diagonal * Bidiagonal/Tridiagonal with structured diags" begin
344349
n = size(D,1)
345350
B = Bidiagonal(map(MyVector, (rand(n), rand(n-1)))..., :U)
351+
MB = MyMatrix(B)
346352
S = SymTridiagonal(map(MyVector, (rand(n), rand(n-1)))...)
353+
MS = MyMatrix(S)
347354
T = Tridiagonal(map(MyVector, (rand(n-1), rand(n), rand(n-1)))...)
355+
MT = MyMatrix(T)
348356
DA, BA, SA, TA = map(Array, (D, B, S, T))
349357
if VERSION >= v"1.11"
350358
@test D * B DA * BA
351359
@test B * D BA * DA
360+
@test D * MB DA * BA
361+
@test MB * D BA * DA
352362
end
353363
if VERSION >= v"1.12.0-DEV.824"
354364
@test D * S DA * SA
365+
@test D * MS DA * SA
355366
@test D * T DA * TA
367+
@test D * MT DA * TA
356368
@test S * D SA * DA
369+
@test MS * D SA * DA
357370
@test T * D TA * DA
371+
@test MT * D TA * DA
358372
end
359373
end
360374
end

0 commit comments

Comments
 (0)