Skip to content

Commit 6a3c516

Browse files
authored
Overload copymutable_oftype (#156)
* Overload copymutable_oftype * Update test_layoutarray.jl * Overload rdiv! * Update test_layoutarray.jl * Update test_layoutarray.jl * Transpose/AdjointFactorization
1 parent 610da35 commit 6a3c516

File tree

4 files changed

+56
-12
lines changed

4 files changed

+56
-12
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ArrayLayouts"
22
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
33
authors = ["Sheehan Olver <[email protected]>"]
4-
version = "1.0.13"
4+
version = "1.1.0"
55

66
[deps]
77
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"

src/ArrayLayouts.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,12 @@ copyto!(dest::AbstractMatrix, src::SubArray{<:Any,2,<:AdjOrTrans{<:Any,<:LayoutA
270270
# ambiguity from sparsematrix.jl
271271
copyto!(dest::LayoutMatrix, src::SparseArrays.AbstractSparseMatrixCSC) = _copyto!(dest, src)
272272
copyto!(dest::SubArray{<:Any,2,<:LayoutMatrix}, src::SparseArrays.AbstractSparseMatrixCSC) = _copyto!(dest, src)
273+
if isdefined(LinearAlgebra, :copymutable_oftype)
274+
LinearAlgebra.copymutable_oftype(A::Union{LayoutArray,Symmetric{<:Any,<:LayoutMatrix},Hermitian{<:Any,<:LayoutMatrix},
275+
UpperOrLowerTriangular{<:Any,<:LayoutMatrix},
276+
AdjOrTrans{<:Any,<:LayoutMatrix}}, ::Type{T}) where T = copymutable_oftype_layout(MemoryLayout(A), A, T)
277+
end
278+
copymutable_oftype_layout(_, A, ::Type{S}) where S = copyto!(similar(A, S), A)
273279

274280
# avoid bad copy in Base
275281
Base.map(::typeof(copy), D::Diagonal{<:LayoutArray}) = Diagonal(map(copy, D.diag))

src/ldiv.jl

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ for Typ in (:Ldiv, :Rdiv)
2323
end
2424
end
2525

26+
similar(A::Rdiv{<:DualLayout}, ::Type{T}, (ax1,ax2)) where T = dualadjoint(A.A)(similar(Array{T}, (ax2,)))
27+
2628
@inline _ldivaxes(::Tuple{}, ::Tuple{}) = ()
2729
@inline _ldivaxes(::Tuple{}, Bax::Tuple) = Bax
2830
@inline _ldivaxes(::Tuple{<:Any}, ::Tuple{<:Any}) = ()
@@ -80,8 +82,18 @@ __ldiv!(_, F, B) = LinearAlgebra.ldiv!(F, B)
8082

8183
@inline _ldiv!(dest, A, B; kwds...) = ldiv!(dest, factorize(A), B; kwds...)
8284
@inline _ldiv!(dest, A::Factorization, B; kwds...) = LinearAlgebra.ldiv!(dest, A, B; kwds...)
83-
@inline _ldiv!(dest, A::Transpose{<:Any,<:Factorization}, B; kwds...) = LinearAlgebra.ldiv!(dest, A, B; kwds...)
84-
@inline _ldiv!(dest, A::Adjoint{<:Any,<:Factorization}, B; kwds...) = LinearAlgebra.ldiv!(dest, A, B; kwds...)
85+
86+
if VERSION v"1.10-"
87+
using LinearAlgebra: TransposeFactorization, AdjointFactorization
88+
else
89+
const TransposeFactorization = Transpose
90+
const AdjointFactorization = Adjoint
91+
92+
end
93+
@inline _ldiv!(dest, A::TransposeFactorization{<:Any,<:Factorization}, B; kwds...) = LinearAlgebra.ldiv!(dest, A, B; kwds...)
94+
@inline _ldiv!(dest, A::AdjointFactorization{<:Any,<:Factorization}, B; kwds...) = LinearAlgebra.ldiv!(dest, A, B; kwds...)
95+
96+
8597

8698
@inline ldiv(A, B; kwds...) = materialize(Ldiv(A,B); kwds...)
8799
@inline rdiv(A, B; kwds...) = materialize(Rdiv(A,B); kwds...)
@@ -94,7 +106,10 @@ __ldiv!(_, F, B) = LinearAlgebra.ldiv!(F, B)
94106

95107
@inline materialize!(M::Ldiv) = _ldiv!(M.A, M.B)
96108
@inline materialize!(M::Rdiv) = ldiv!(M.B', M.A')'
97-
@inline copyto!(dest::AbstractArray, M::Rdiv; kwds...) = copyto!(dest', Ldiv(M.B', M.A'); kwds...)'
109+
@inline function copyto!(dest::AbstractArray, M::Rdiv; kwds...)
110+
adj = dualadjoint(dest)
111+
adj(copyto!(adj(dest), Ldiv(adj(M.B), adj(M.A)); kwds...))
112+
end
98113
@inline copyto!(dest::AbstractArray, M::Ldiv; kwds...) = _ldiv!(dest, M.A, copy(M.B); kwds...)
99114

100115
const MatLdivVec{styleA, styleB, T, V} = Ldiv{styleA, styleB, <:AbstractMatrix{T}, <:AbstractVector{V}}
@@ -139,6 +154,8 @@ macro _layoutldiv(Typ)
139154

140155
LinearAlgebra.ldiv!(A::Bidiagonal, B::$Typ; kwds...) = ArrayLayouts.ldiv!(A,B; kwds...)
141156

157+
LinearAlgebra.rdiv!(A::AbstractMatrix, B::$Typ; kwds...) = ArrayLayouts.rdiv!(A,B; kwds...)
158+
142159

143160
(\)(A::$Typ, x::AbstractVector; kwds...) = ArrayLayouts.ldiv(A,x; kwds...)
144161
(\)(A::$Typ, x::AbstractMatrix; kwds...) = ArrayLayouts.ldiv(A,x; kwds...)
@@ -171,6 +188,8 @@ macro _layoutldiv(Typ)
171188
(/)(A::$Typ, D::Diagonal; kwds...) = ArrayLayouts.rdiv(A,D; kwds...)
172189

173190
(/)(x::$Typ, A::$Typ; kwds...) = ArrayLayouts.rdiv(x,A; kwds...)
191+
(/)(D::Adjoint{<:Any,<:AbstractVector}, A::$Typ; kwds...) = ArrayLayouts.rdiv(D,A; kwds...)
192+
(/)(D::Transpose{<:Any,<:AbstractVector}, A::$Typ; kwds...) = ArrayLayouts.rdiv(D,A; kwds...)
174193
end
175194
if Typ :LayoutVector
176195
ret = quote

test/test_layoutarray.jl

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,14 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
6464
@test A[kr,jr] == A.A[kr,jr]
6565
end
6666
b = randn(5)
67+
B = randn(5,5)
6768
for Tri in (UpperTriangular, UnitUpperTriangular, LowerTriangular, UnitLowerTriangular)
6869
@test ldiv!(Tri(A), copy(b)) ldiv!(Tri(A.A), copy(b)) Tri(A.A) \ MyVector(b)
70+
@test ldiv!(Tri(A), copy(B)) ldiv!(Tri(A.A), copy(B)) Tri(A.A) \ MyMatrix(B)
71+
if VERSION v"1.9"
72+
@test rdiv!(copy(b)', Tri(A)) rdiv!(copy(b)', Tri(A.A)) MyVector(b)' / Tri(A.A)
73+
@test rdiv!(copy(B), Tri(A)) rdiv!(copy(B), Tri(A.A)) B / Tri(A.A)
74+
end
6975
@test lmul!(Tri(A), copy(b)) lmul!(Tri(A.A), copy(b)) Tri(A.A) * MyVector(b)
7076
end
7177

@@ -112,7 +118,7 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
112118
@test cholesky(S, CRowMaximum()) \ b ldiv!(cholesky(Matrix(S), CRowMaximum()), copy(MyVector(b)))
113119
@test cholesky(S) \ b Matrix(S) \ b Symmetric(Matrix(S)) \ b
114120
@test cholesky(S) \ b Symmetric(Matrix(S)) \ MyVector(b)
115-
if VERSION >= v"1.9-"
121+
if VERSION >= v"1.9"
116122
@test S \ b Matrix(S) \ b Symmetric(Matrix(S)) \ b
117123
@test S \ b Symmetric(Matrix(S)) \ MyVector(b)
118124
end
@@ -122,19 +128,19 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
122128
@test cholesky(S,CRowMaximum()).U cholesky(Matrix(S),CRowMaximum()).U
123129
@test cholesky(S) \ b Matrix(S) \ b Symmetric(Matrix(S), :L) \ b
124130
@test cholesky(S) \ b Symmetric(Matrix(S), :L) \ MyVector(b)
125-
if VERSION >= v"1.9-"
131+
if VERSION >= v"1.9"
126132
@test S \ b Matrix(S) \ b Symmetric(Matrix(S), :L) \ b
127133
@test S \ b Symmetric(Matrix(S), :L) \ MyVector(b)
128134
end
129135

130136
@testset "ldiv!" begin
131137
c = MyVector(randn(5))
132-
if VERSION < v"1.9-"
138+
if VERSION < v"1.9"
133139
@test_broken ldiv!(lu(A), MyVector(copy(c))) A \ c
134140
else
135141
@test ldiv!(lu(A), MyVector(copy(c))) A \ c
136142
end
137-
if VERSION < v"1.9-" || VERSION >= v"1.10-"
143+
if VERSION < v"1.9" || VERSION >= v"1.10-"
138144
@test_throws ErrorException ldiv!(qr(A), MyVector(copy(c)))
139145
else
140146
@test_throws MethodError ldiv!(qr(A), MyVector(copy(c)))
@@ -221,18 +227,24 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
221227
@test_broken ldiv!(A, t) A\t
222228
@test ldiv!(A, copy(X)) A\X
223229
@test A\T A\
224-
VERSION >= v"1.9-" && @test A/T A/
230+
VERSION >= v"1.9" && @test A/T A/
225231
@test_broken ldiv!(A, T) A\T
226232
@test B\A B\Matrix(A)
227233
@test D \ A D \ Matrix(A)
228234
@test transpose(B)\A transpose(B)\Matrix(A) Transpose(B)\A Adjoint(B)\A
229235
@test B'\A B'\Matrix(A)
230236
@test A\A I
231-
VERSION >= v"1.9-" && @test A/A I
237+
VERSION >= v"1.9" && @test A/A I
232238
@test A\MyVector(x) A\x
233239
@test A\MyMatrix(X) A\X
234240

235-
VERSION >= v"1.9-" && @test A/A A.A / A.A
241+
if VERSION >= v"1.9"
242+
@test A/A A.A / A.A
243+
@test x' / A x' / A.A
244+
@test transpose(x) / A transpose(x) / A.A
245+
@test transpose(x) / A isa Transpose
246+
@test x' / A isa Adjoint
247+
end
236248
end
237249

238250
@testset "dot" begin
@@ -455,6 +467,13 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
455467
@test UpperTriangular(A) * UnitUpperTriangular(A') UpperTriangular(A.A) * UnitUpperTriangular(A.A')
456468
@test UpperTriangular(A') * UnitUpperTriangular(A') UpperTriangular(A.A') * UnitUpperTriangular(A.A')
457469
end
470+
471+
if isdefined(LinearAlgebra, :copymutable_oftype)
472+
@testset "copymutable_oftype" begin
473+
A = MyMatrix(randn(3,3))
474+
@test LinearAlgebra.copymutable_oftype(A, BigFloat) == A
475+
end
476+
end
458477
end
459478

460479
struct MyUpperTriangular{T} <: AbstractMatrix{T}
@@ -498,5 +517,5 @@ triangulardata(A::MyUpperTriangular) = triangulardata(A.A)
498517
@test_skip lmul!(U,view(copy(B),collect(1:5),1:5)) U * B
499518

500519
@test MyMatrix(A) / U A / U
501-
VERSION >= v"1.9-" && @test U / MyMatrix(A) U / A
520+
VERSION >= v"1.9" && @test U / MyMatrix(A) U / A
502521
end

0 commit comments

Comments
 (0)