Skip to content

Commit 5a82edf

Browse files
authored
_copyto! (#8)
* _copyto! * add factorizations * increase coverage
1 parent e8e3144 commit 5a82edf

File tree

7 files changed

+81
-10
lines changed

7 files changed

+81
-10
lines changed

.travis.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@ os:
66
- windows
77
julia:
88
- 1.0
9-
- 1.1
10-
- 1.2
11-
- 1.3
129
- 1.4
1310
- nightly
1411
matrix:

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ArrayLayouts"
22
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
33
authors = ["Sheehan Olver <[email protected]>"]
4-
version = "0.2.0"
4+
version = "0.2.1"
55

66
[deps]
77
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"

src/ArrayLayouts.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, broadcas
3333
materialize!, eltypes
3434

3535
import LinearAlgebra: AbstractTriangular, AbstractQ, checksquare, pinv, fill!, tilebufsize, Abuf, Bbuf, Cbuf, dot, factorize, qr, lu, cholesky,
36-
norm2, norm1, normInf, normMinusInf
36+
norm2, norm1, normInf, normMinusInf, qr, lu, qr!, lu!
3737

3838
import LinearAlgebra.BLAS: BlasFloat, BlasReal, BlasComplex
3939

@@ -85,6 +85,7 @@ macro layoutmatrix(Typ)
8585
ArrayLayouts.@layoutldiv $Typ
8686
ArrayLayouts.@layoutmul $Typ
8787
ArrayLayouts.@layoutlmul $Typ
88+
ArrayLayouts.@layoutfactorizations $Typ
8889

8990
@inline Base.getindex(A::$Typ, kr::Colon, jr::Colon) = ArrayLayouts.layout_getindex(A, kr, jr)
9091
@inline Base.getindex(A::$Typ, kr::Colon, jr::AbstractUnitRange) = ArrayLayouts.layout_getindex(A, kr, jr)
@@ -98,6 +99,30 @@ end
9899

99100
@layoutmatrix LayoutMatrix
100101

102+
_copyto!(_, _, dest::AbstractArray{T,N}, src::AbstractArray{V,N}) where {T,V,N} =
103+
Base.invoke(copyto!, Tuple{AbstractArray{T,N},AbstractArray{V,N}}, dest, src)
104+
105+
106+
copyto!(dest::LayoutArray{<:Any,N}, src::LayoutArray{<:Any,N}) where N =
107+
_copyto!(MemoryLayout(typeof(dest)), MemoryLayout(typeof(src)), dest, src)
108+
copyto!(dest::AbstractArray{<:Any,N}, src::LayoutArray{<:Any,N}) where N =
109+
_copyto!(MemoryLayout(typeof(dest)), MemoryLayout(typeof(src)), dest, src)
110+
copyto!(dest::LayoutArray{<:Any,N}, src::AbstractArray{<:Any,N}) where N =
111+
_copyto!(MemoryLayout(typeof(dest)), MemoryLayout(typeof(src)), dest, src)
112+
113+
copyto!(dest::SubArray{<:Any,N,<:LayoutArray}, src::SubArray{<:Any,N,<:LayoutArray}) where N =
114+
_copyto!(MemoryLayout(typeof(dest)), MemoryLayout(typeof(src)), dest, src)
115+
copyto!(dest::SubArray{<:Any,N,<:LayoutArray}, src::LayoutArray{<:Any,N}) where N =
116+
_copyto!(MemoryLayout(typeof(dest)), MemoryLayout(typeof(src)), dest, src)
117+
copyto!(dest::LayoutArray{<:Any,N}, src::SubArray{<:Any,N,<:LayoutArray}) where N =
118+
_copyto!(MemoryLayout(typeof(dest)), MemoryLayout(typeof(src)), dest, src)
119+
copyto!(dest::SubArray{<:Any,N,<:LayoutArray}, src::AbstractArray{<:Any,N}) where N =
120+
_copyto!(MemoryLayout(typeof(dest)), MemoryLayout(typeof(src)), dest, src)
121+
copyto!(dest::AbstractArray{<:Any,N}, src::SubArray{<:Any,N,<:LayoutArray}) where N =
122+
_copyto!(MemoryLayout(typeof(dest)), MemoryLayout(typeof(src)), dest, src)
123+
124+
125+
101126
zero!(A::AbstractArray{T}) where T = fill!(A,zero(T))
102127
function zero!(A::AbstractArray{<:AbstractArray})
103128
for a in A

src/factorizations.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,29 @@ function copyto!(dest::AbstractArray, M::Ldiv{QLayout})
2525
end
2626

2727
materialize!(M::Ldiv{QLayout}) = materialize!(Lmul(M.A',M.B))
28+
29+
_qr(layout, axes, A; kwds...) = Base.invoke(qr, Tuple{AbstractMatrix{eltype(A)}}, A; kwds...)
30+
_qr(layout, axes, A, pivot::P; kwds...) where P = Base.invoke(qr, Tuple{AbstractMatrix{eltype(A)},P}, A, pivot; kwds...)
31+
_lu(layout, axes, A; kwds...) = Base.invoke(lu, Tuple{AbstractMatrix{eltype(A)}}, A; kwds...)
32+
_lu(layout, axes, A, pivot::P; kwds...) where P = Base.invoke(lu, Tuple{AbstractMatrix{eltype(A)},P}, A, pivot; kwds...)
33+
_qr!(layout, axes, A, args...; kwds...) = error("Overload _qr!(::$(typeof(layout)), axes, A)")
34+
_lu!(layout, axes, A, args...; kwds...) = error("Overload _lu!(::$(typeof(layout)), axes, A)")
35+
_factorize(layout, axes, A) = Base.invoke(factorize, Tuple{AbstractMatrix{eltype(A)}}, A)
36+
37+
macro _layoutfactorizations(Typ)
38+
esc(quote
39+
LinearAlgebra.qr(A::$Typ, args...; kwds...) = ArrayLayouts._qr(ArrayLayouts.MemoryLayout(typeof(A)), axes(A), A, args...; kwds...)
40+
LinearAlgebra.qr!(A::$Typ, args...; kwds...) = ArrayLayouts._qr!(ArrayLayouts.MemoryLayout(typeof(A)), axes(A), A, args...; kwds...)
41+
LinearAlgebra.lu(A::$Typ, pivot::Union{Val{false}, Val{true}}; kwds...) = ArrayLayouts._lu(ArrayLayouts.MemoryLayout(typeof(A)), axes(A), A, pivot; kwds...)
42+
LinearAlgebra.lu(A::$Typ{T}; kwds...) where T = ArrayLayouts._lu(ArrayLayouts.MemoryLayout(typeof(A)), axes(A), A; kwds...)
43+
LinearAlgebra.lu!(A::$Typ, args...; kwds...) = ArrayLayouts._lu!(ArrayLayouts.MemoryLayout(typeof(A)), axes(A), A, args...; kwds...)
44+
LinearAlgebra.factorize(A::$Typ) = ArrayLayouts._factorize(ArrayLayouts.MemoryLayout(typeof(A)), axes(A), A)
45+
end)
46+
end
47+
48+
macro layoutfactorizations(Typ)
49+
esc(quote
50+
ArrayLayouts.@_layoutfactorizations $Typ
51+
ArrayLayouts.@_layoutfactorizations SubArray{<:Any,2,<:$Typ}
52+
end)
53+
end

src/ldiv.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ end
6363
Rdiv(instantiate(L.A), instantiate(L.B))
6464
end
6565

66-
__ldiv!(::Mat, ::Mat, B) where Mat = error("Overload ldiv! for $Mat")
66+
__ldiv!(::Mat, ::Mat, B) where Mat = error("Overload materialize!(::Ldiv{$(typeof(MemoryLayout(Mat))),$(typeof(MemoryLayout(typeof(B))))})")
6767
__ldiv!(_, F, B) = ldiv!(F, B)
6868
@inline _ldiv!(A, B) = __ldiv!(A, factorize(A), B)
6969
@inline _ldiv!(A::Factorization, B) = ldiv!(A, B)
@@ -131,6 +131,11 @@ macro layoutldiv(Typ)
131131
ArrayLayouts.@_layoutldiv UnitUpperTriangular{T, <:$Typ{T}} where T
132132
ArrayLayouts.@_layoutldiv LowerTriangular{T, <:$Typ{T}} where T
133133
ArrayLayouts.@_layoutldiv UnitLowerTriangular{T, <:$Typ{T}} where T
134+
135+
ArrayLayouts.@_layoutldiv UpperTriangular{T, <:SubArray{T,2,<:$Typ{T}}} where T
136+
ArrayLayouts.@_layoutldiv UnitUpperTriangular{T, <:SubArray{T,2,<:$Typ{T}}} where T
137+
ArrayLayouts.@_layoutldiv LowerTriangular{T, <:SubArray{T,2,<:$Typ{T}}} where T
138+
ArrayLayouts.@_layoutldiv UnitLowerTriangular{T, <:SubArray{T,2,<:$Typ{T}}} where T
134139
end)
135140
end
136141

src/memorylayout.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ MemoryLayout(A::Type{UpperTriangular{T,P}}) where {T,P} = triangularlayout(Upper
393393
MemoryLayout(A::Type{UnitUpperTriangular{T,P}}) where {T,P} = triangularlayout(UnitUpperTriangularLayout, MemoryLayout(P))
394394
MemoryLayout(A::Type{LowerTriangular{T,P}}) where {T,P} = triangularlayout(LowerTriangularLayout, MemoryLayout(P))
395395
MemoryLayout(A::Type{UnitLowerTriangular{T,P}}) where {T,P} = triangularlayout(UnitLowerTriangularLayout, MemoryLayout(P))
396-
triangularlayout(_, ::MemoryLayout) = UnknownLayout()
396+
triangularlayout(_::Type{Tri}, ::MemoryLayout) where Tri = Tri{UnknownLayout}()
397397
triangularlayout(::Type{Tri}, ::ML) where {Tri, ML<:AbstractColumnMajor} = Tri{ML}()
398398
triangularlayout(::Type{Tri}, ::ML) where {Tri, ML<:AbstractRowMajor} = Tri{ML}()
399399
triangularlayout(::Type{Tri}, ::ML) where {Tri, ML<:ConjLayout{<:AbstractRowMajor}} = Tri{ML}()

test/runtests.jl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ struct MyMatrix <: LayoutMatrix{Float64}
1010
end
1111

1212
Base.getindex(A::MyMatrix, k::Int, j::Int) = A.A[k,j]
13+
Base.setindex!(A::MyMatrix, v, k::Int, j::Int) = setindex!(A.A, v, k, j)
1314
Base.size(A::MyMatrix) = size(A.A)
1415
Base.strides(A::MyMatrix) = strides(A.A)
1516
Base.unsafe_convert(::Type{Ptr{T}}, A::MyMatrix) where T = Base.unsafe_convert(Ptr{T}, A.A)
@@ -25,6 +26,22 @@ MemoryLayout(::Type{MyMatrix}) = DenseColumnMajor()
2526
@test ldiv!(Tri(A), copy(b)) ldiv!(Tri(A.A), copy(b))
2627
@test lmul!(Tri(A), copy(b)) lmul!(Tri(A.A), copy(b))
2728
end
29+
30+
@test copyto!(MyMatrix(Array{Float64}(undef,5,5)), A) == A
31+
@test copyto!(Array{Float64}(undef,5,5), A) == A
32+
@test copyto!(MyMatrix(Array{Float64}(undef,5,5)), A.A) == A
33+
@test copyto!(view(MyMatrix(Array{Float64}(undef,5,5)),1:3,1:3), view(A,1:3,1:3)) == A[1:3,1:3]
34+
@test copyto!(view(MyMatrix(Array{Float64}(undef,5,5)),:,:), A) == A
35+
@test copyto!(MyMatrix(Array{Float64}(undef,3,3)), view(A,1:3,1:3)) == A[1:3,1:3]
36+
@test copyto!(view(MyMatrix(Array{Float64}(undef,5,5)),:,:), A.A) == A
37+
@test copyto!(Array{Float64}(undef,3,3), view(A,1:3,1:3)) == A[1:3,1:3]
38+
39+
@test qr(A).factors qr(A.A).factors
40+
@test qr(A,Val(true)).factors qr(A.A,Val(true)).factors
41+
@test lu(A).factors lu(A.A).factors
42+
@test lu(A,Val(true)).factors lu(A.A,Val(true)).factors
43+
@test_throws ErrorException qr!(A)
44+
@test_throws ErrorException lu!(A)
2845
end
2946

3047
struct MyUpperTriangular{T} <: AbstractMatrix{T}
@@ -60,9 +77,10 @@ triangulardata(A::MyUpperTriangular) = triangulardata(A.A)
6077
A = randn(5,5)
6178
B = randn(5,5)
6279
x = randn(5)
80+
U = MyUpperTriangular(A)
6381

64-
@test lmul!(MyUpperTriangular(A), copy(x)) MyUpperTriangular(A) * x
65-
@test lmul!(MyUpperTriangular(A), copy(B)) MyUpperTriangular(A) * B
82+
@test lmul!(U, copy(x)) U * x
83+
@test lmul!(U, copy(B)) U * B
6684

67-
@test_skip lmul!(MyUpperTriangular(A),view(copy(B),collect(1:5),1:5)) MyUpperTriangular(A) * B
85+
@test_skip lmul!(U,view(copy(B),collect(1:5),1:5)) U * B
6886
end

0 commit comments

Comments
 (0)