@@ -3,34 +3,39 @@ module TestLayoutArray
3
3
using ArrayLayouts, LinearAlgebra, FillArrays, Test, SparseArrays
4
4
using ArrayLayouts: sub_materialize, MemoryLayout, ColumnNorm, RowMaximum, CRowMaximum, @_layoutlmul
5
5
import ArrayLayouts: triangulardata
6
+ import LinearAlgebra: Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal
6
7
7
- struct MyMatrix <: LayoutMatrix{Float64 }
8
- A:: Matrix{Float64}
8
+ struct MyMatrix{T,M <: AbstractMatrix{T} } <: LayoutMatrix{T }
9
+ A:: M
9
10
end
10
11
11
12
Base. getindex (A:: MyMatrix , k:: Int , j:: Int ) = A. A[k,j]
12
13
Base. setindex! (A:: MyMatrix , v, k:: Int , j:: Int ) = setindex! (A. A, v, k, j)
13
14
Base. size (A:: MyMatrix ) = size (A. A)
14
15
Base. strides (A:: MyMatrix ) = strides (A. A)
15
- Base. elsize (:: Type{MyMatrix} ) = sizeof (Float64 )
16
- Base. cconvert (:: Type{Ptr{Float64 }} , A:: MyMatrix ) = A. A
17
- Base. unsafe_convert (:: Type{Ptr{Float64 }} , A:: MyMatrix ) = Base. unsafe_convert (Ptr{Float64 }, A. A)
18
- MemoryLayout (:: Type{MyMatrix} ) = DenseColumnMajor ( )
16
+ Base. elsize (:: Type{<: MyMatrix{T}} ) where {T} = sizeof (T )
17
+ Base. cconvert (:: Type{Ptr{T }} , A:: MyMatrix{T} ) where {T} = Base . cconvert (Ptr{T}, A. A)
18
+ Base. unsafe_convert (:: Type{Ptr{T }} , A:: MyMatrix{T} ) where {T} = Base. unsafe_convert (Ptr{T }, A. A)
19
+ MemoryLayout (:: Type{MyMatrix{T,M}} ) where {T,M} = MemoryLayout (M )
19
20
Base. copy (A:: MyMatrix ) = MyMatrix (copy (A. A))
21
+ ArrayLayouts. bidiagonaluplo (M:: MyMatrix ) = ArrayLayouts. bidiagonaluplo (M. A)
22
+ for MT in (:Diagonal , :Bidiagonal , :Tridiagonal , :SymTridiagonal )
23
+ @eval $ MT (M:: MyMatrix ) = $ MT (M. A)
24
+ end
20
25
21
- struct MyVector{T} <: LayoutVector{T}
22
- A:: Vector{T}
26
+ struct MyVector{T,V <: AbstractVector{T} } <: LayoutVector{T}
27
+ A:: V
23
28
end
24
29
25
30
MyVector (M:: MyVector ) = MyVector (M. A)
26
31
Base. getindex (A:: MyVector , k:: Int ) = A. A[k]
27
32
Base. setindex! (A:: MyVector , v, k:: Int ) = setindex! (A. A, v, k)
28
33
Base. size (A:: MyVector ) = size (A. A)
29
34
Base. strides (A:: MyVector ) = strides (A. A)
30
- Base. elsize (:: Type{MyVector} ) = sizeof (Float64 )
31
- Base. cconvert (:: Type{Ptr{T}} , A:: MyVector{T} ) where {T} = A. A
35
+ Base. elsize (:: Type{<: MyVector{T}} ) where {T} = sizeof (T )
36
+ Base. cconvert (:: Type{Ptr{T}} , A:: MyVector{T} ) where {T} = Base . cconvert (Ptr{T}, A. A)
32
37
Base. unsafe_convert (:: Type{Ptr{T}} , A:: MyVector{T} ) where T = Base. unsafe_convert (Ptr{T}, A. A)
33
- MemoryLayout (:: Type{MyVector} ) = DenseColumnMajor ( )
38
+ MemoryLayout (:: Type{MyVector{T,V}} ) where {T,V} = MemoryLayout (V )
34
39
Base. copy (A:: MyVector ) = MyVector (copy (A. A))
35
40
36
41
# These need to test dispatch reduces to ArrayLayouts.mul, etc.
@@ -44,7 +49,7 @@ Base.copy(A::MyVector) = MyVector(copy(A.A))
44
49
@test a[1 : 3 ] == a. A[1 : 3 ]
45
50
@test a[:] == a
46
51
@test (a' )[1 ,:] == (a' )[1 ,1 : 3 ] == a
47
- @test sprint (show, " text/plain" , a) == " 3-element $MyVector {Float64} :\n 1.0\n 2.0\n 3.0"
52
+ @test sprint (show, " text/plain" , a) == " $( summary (a)) :\n 1.0\n 2.0\n 3.0"
48
53
@test B* a ≈ B* a. A
49
54
@test B' * a ≈ B' * a. A
50
55
@test transpose (B)* a ≈ transpose (B)* a. A
@@ -142,16 +147,16 @@ Base.copy(A::MyVector) = MyVector(copy(A.A))
142
147
143
148
@testset " ldiv!" begin
144
149
c = MyVector (randn (5 ))
145
- if VERSION < v " 1.9"
146
- @test_broken ldiv! (lu (A), MyVector (copy (c))) ≈ A \ c
147
- else
148
- @test ldiv! (lu (A), MyVector (copy (c))) ≈ A \ c
149
- end
150
- if VERSION < v " 1.9" || VERSION >= v " 1.10-"
151
- @test_throws ErrorException ldiv! (qr (A), MyVector (copy (c)))
152
- else
153
- @test_throws MethodError ldiv! (qr (A), MyVector (copy (c)))
154
- end
150
+ # if VERSION < v"1.9"
151
+ # @test_broken ldiv!(lu(A), MyVector(copy(c))) ≈ A \ c
152
+ # else
153
+ @test ldiv! (lu (A), MyVector (copy (c))) ≈ A \ c
154
+ # end
155
+ # if VERSION < v"1.9" || VERSION >= v"1.10-"
156
+ # @test_throws ErrorException ldiv!(qr(A), MyVector(copy(c)))
157
+ # else
158
+ # @test_throws MethodError ldiv!(qr(A), MyVector(copy(c)))
159
+ # end
155
160
@test_throws ErrorException ldiv! (eigen (randn (5 ,5 )), c)
156
161
@test ArrayLayouts. ldiv! (svd (A. A), Vector (c)) ≈ ArrayLayouts. ldiv! (similar (c), svd (A. A), c) ≈ A \ c
157
162
if VERSION ≥ v " 1.8"
@@ -215,8 +220,8 @@ Base.copy(A::MyVector) = MyVector(copy(A.A))
215
220
@test B == Ones (5 ,5 )* A + 2.0 Bin
216
221
end
217
222
218
- C = MyMatrix ([1 2 ; 3 4 ])
219
- @test sprint (show, " text/plain" , C) == " 2×2 $MyMatrix :\n 1.0 2.0\n 3.0 4.0"
223
+ C = MyMatrix (Float64 [1 2 ; 3 4 ])
224
+ @test sprint (show, " text/plain" , C) == " $( summary (C)) :\n 1.0 2.0\n 3.0 4.0"
220
225
221
226
@testset " layoutldiv" begin
222
227
A = MyMatrix (randn (5 ,5 ))
@@ -343,18 +348,27 @@ Base.copy(A::MyVector) = MyVector(copy(A.A))
343
348
@testset " Diagonal * Bidiagonal/Tridiagonal with structured diags" begin
344
349
n = size (D,1 )
345
350
B = Bidiagonal (map (MyVector, (rand (n), rand (n- 1 )))... , :U )
351
+ MB = MyMatrix (B)
346
352
S = SymTridiagonal (map (MyVector, (rand (n), rand (n- 1 )))... )
353
+ MS = MyMatrix (S)
347
354
T = Tridiagonal (map (MyVector, (rand (n- 1 ), rand (n), rand (n- 1 )))... )
355
+ MT = MyMatrix (T)
348
356
DA, BA, SA, TA = map (Array, (D, B, S, T))
349
357
if VERSION >= v " 1.11"
350
358
@test D * B ≈ DA * BA
351
359
@test B * D ≈ BA * DA
360
+ @test D * MB ≈ DA * BA
361
+ @test MB * D ≈ BA * DA
352
362
end
353
363
if VERSION >= v " 1.12.0-DEV.824"
354
364
@test D * S ≈ DA * SA
365
+ @test D * MS ≈ DA * SA
355
366
@test D * T ≈ DA * TA
367
+ @test D * MT ≈ DA * TA
356
368
@test S * D ≈ SA * DA
369
+ @test MS * D ≈ SA * DA
357
370
@test T * D ≈ TA * DA
371
+ @test MT * D ≈ TA * DA
358
372
end
359
373
end
360
374
end
0 commit comments