Skip to content

Commit ebf0312

Browse files
authored
Better Support for getindex for Ldiv (#35)
* Better Support for getindex for Ldiv * lmul! for LayoutArray * lmul! for Symmetric/Hermitian * Test scalar lmul!/rmul!/ldiv!/rdiv! * Update runtests.jl * Update runtests.jl * 3-arg dot, getindex for LayoutArrays * x'A should behave like a dual-vector * Update runtests.jl
1 parent 3cc11e7 commit ebf0312

File tree

7 files changed

+98
-13
lines changed

7 files changed

+98
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ArrayLayouts"
22
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
33
authors = ["Sheehan Olver <[email protected]>"]
4-
version = "0.4.2"
4+
version = "0.4.3"
55

66
[deps]
77
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"

src/ArrayLayouts.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,10 @@ macro layoutgetindex(Typ)
128128
esc(quote
129129
ArrayLayouts.@_layoutgetindex $Typ
130130
ArrayLayouts.@_layoutgetindex LinearAlgebra.AbstractTriangular{<:Any,<:$Typ}
131+
ArrayLayouts.@_layoutgetindex LinearAlgebra.Symmetric{<:Any,<:$Typ}
132+
ArrayLayouts.@_layoutgetindex LinearAlgebra.Hermitian{<:Any,<:$Typ}
133+
ArrayLayouts.@_layoutgetindex LinearAlgebra.Adjoint{<:Any,<:$Typ}
134+
ArrayLayouts.@_layoutgetindex LinearAlgebra.Transpose{<:Any,<:$Typ}
131135
end)
132136
end
133137

@@ -144,6 +148,15 @@ end
144148

145149
@layoutmatrix LayoutMatrix
146150

151+
for Typ in (:LayoutArray, :(Transpose{<:Any,<:LayoutMatrix}), :(Adjoint{<:Any,<:LayoutMatrix}), :(Symmetric{<:Any,<:LayoutMatrix}), :(Hermitian{<:Any,<:LayoutMatrix}))
152+
@eval begin
153+
LinearAlgebra.lmul!::Number, A::$Typ) = lmul!(α, A)
154+
LinearAlgebra.rmul!(A::$Typ, α::Number) = rmul!(A, α)
155+
LinearAlgebra.ldiv!::Number, A::$Typ) = ldiv!(α, A)
156+
LinearAlgebra.rdiv!(A::$Typ, α::Number) = rdiv!(A, α)
157+
end
158+
end
159+
147160
getindex(A::LayoutVector, kr::AbstractVector) = layout_getindex(A, kr)
148161

149162
_copyto!(_, _, dest::AbstractArray{T,N}, src::AbstractArray{V,N}) where {T,V,N} =

src/ldiv.jl

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,12 @@ end
4545
@inline eltype(M::Rdiv) = promote_type(eltype(M.A), Base.promote_op(inv, eltype(M.B)))
4646

4747
# Lazy getindex
48-
getindex(L::Ldiv{<:Any,<:Any,<:AbstractMatrix,<:AbstractVector}, k::Integer) = copyto!(similar(L), L)[k]
49-
getindex(L::Ldiv{<:Any,<:Any,<:AbstractMatrix,<:AbstractMatrix}, k::Integer,j::Integer) = Ldiv(L.A, L.B[:,j])[k]
48+
49+
getindex(L::Ldiv, k...) = _getindex(indextype(L), L, k)
50+
concretize(L::AbstractArray) = convert(Array,L)
51+
concretize(L::Ldiv) = ldiv(concretize(L.A), concretize(L.B))
52+
_getindex(::Type{Tuple{I}}, L::Ldiv, (k,)::Tuple{I}) where I = concretize(L)[k]
53+
_getindex(::Type{Tuple{I,J}}, L::Ldiv, (k,j)::Tuple{I,J}) where {I,J} = Ldiv(L.A, L.B[:,j])[k]
5054

5155
check_ldiv_axes(A, B) =
5256
axes(A,1) == axes(B,1) || throw(DimensionMismatch("First axis of A, $(axes(A,1)), and first axis of B, $(axes(B,1)) must match"))
@@ -105,10 +109,25 @@ const BlasMatLdivMat{styleA, styleB, T<:BlasFloat} = MatLdivMat{styleA, styleB,
105109
const MatRdivMat{styleA, styleB, T, V} = Rdiv{styleA, styleB, <:AbstractMatrix{T}, <:AbstractMatrix{V}}
106110
const BlasMatRdivMat{styleA, styleB, T<:BlasFloat} = MatRdivMat{styleA, styleB, T, T}
107111

108-
# function materialize!(L::BlasMatLdivVec{<:AbstractColumnMajor,<:AbstractColumnMajor})
109-
110-
# end
112+
materialize!(M::Ldiv{ScalarLayout}) = Base.invoke(LinearAlgebra.ldiv!, Tuple{Number,AbstractArray}, M.A, M.B)
113+
materialize!(M::Rdiv{<:Any,ScalarLayout}) = Base.invoke(LinearAlgebra.rdiv!, Tuple{AbstractArray,Number}, M.A, M.B)
111114

115+
function materialize!(M::Ldiv{ScalarLayout,<:SymmetricLayout})
116+
ldiv!(M.A, symmetricdata(M.B))
117+
M.B
118+
end
119+
function materialize!(M::Ldiv{ScalarLayout,<:HermitianLayout})
120+
ldiv!(M.A, hermitiandata(M.B))
121+
M.B
122+
end
123+
function materialize!(M::Rdiv{<:SymmetricLayout,ScalarLayout})
124+
rdiv!(symmetricdata(M.A), M.B)
125+
M.A
126+
end
127+
function materialize!(M::Rdiv{<:HermitianLayout,ScalarLayout})
128+
rdiv!(hermitiandata(M.A), M.B)
129+
M.A
130+
end
112131

113132
macro _layoutldiv(Typ)
114133
ret = quote

src/lmul.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,22 @@ materialize!(M::Rmul) = LinearAlgebra.rmul!(M.A,M.B)
7373
materialize!(M::Lmul{ScalarLayout}) = Base.invoke(LinearAlgebra.lmul!, Tuple{Number,AbstractArray}, M.A, M.B)
7474
materialize!(M::Rmul{<:Any,ScalarLayout}) = Base.invoke(LinearAlgebra.rmul!, Tuple{AbstractArray,Number}, M.A, M.B)
7575

76-
76+
function materialize!(M::Lmul{ScalarLayout,<:SymmetricLayout})
77+
lmul!(M.A, symmetricdata(M.B))
78+
M.B
79+
end
80+
function materialize!(M::Lmul{ScalarLayout,<:HermitianLayout})
81+
lmul!(M.A, hermitiandata(M.B))
82+
M.B
83+
end
84+
function materialize!(M::Rmul{<:SymmetricLayout,ScalarLayout})
85+
rmul!(symmetricdata(M.A), M.B)
86+
M.A
87+
end
88+
function materialize!(M::Rmul{<:HermitianLayout,ScalarLayout})
89+
rmul!(hermitiandata(M.A), M.B)
90+
M.A
91+
end
7792

7893
macro _layoutlmul(Typ)
7994
esc(quote

src/mul.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ function _getindex(::Type{Tuple{AA,BB}}, M::Mul, (k, j)::Tuple{AA,BB}) where {AA
4646
end
4747

4848
# linear indexing
49-
_getindex(::Type{NTuple{2,Int}}, M::Mul, k::Tuple{Int}) = M[Base._ind2sub(axes(M), k...)...]
49+
_getindex(::Type{NTuple{2,Int}}, M, k::Tuple{Int}) = M[Base._ind2sub(axes(M), k...)...]
5050

51-
_getindex(::Type{Tuple{Int}}, M::Mul, (k,)::Tuple{CartesianIndex{1}}) = M[convert(Int, k)]
52-
_getindex(::Type{NTuple{2,Int}}, M::Mul, (kj,)::Tuple{CartesianIndex{2}}) = M[kj[1], kj[2]]
51+
_getindex(::Type{Tuple{Int}}, M, (k,)::Tuple{CartesianIndex{1}}) = M[convert(Int, k)]
52+
_getindex(::Type{NTuple{2,Int}}, M, (kj,)::Tuple{CartesianIndex{2}}) = M[kj[1], kj[2]]
5353

5454
"""
5555
indextype(A)
@@ -245,3 +245,7 @@ dot(a, b) = materialize(Dot(a, b))
245245
@inline LinearAlgebra.dot(a::LayoutArray, b::SubArray{<:Any,N,<:LayoutArray}) where N = dot(a,b)
246246
@inline LinearAlgebra.dot(a::SubArray{<:Any,N,<:LayoutArray}, b::SubArray{<:Any,N,<:LayoutArray}) where N = dot(a,b)
247247

248+
# Temporary until layout 3-arg dot is added.
249+
# We go to generic fallback as layout-arrays are structured
250+
dot(x, A, y) = dot(x, mul(A, y))
251+
LinearAlgebra.dot(x::AbstractVector, A::LayoutMatrix, y::AbstractVector) = dot(x, A, y)

src/muladd.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,4 +417,13 @@ fillzeros(::Type{T}, ax) where T = Zeros{T}(ax)
417417
# Fill
418418
###
419419

420-
copy(M::MulAdd{<:AbstractFillLayout,<:AbstractFillLayout,<:AbstractFillLayout}) = M.α*M.A*M.B + M.β*M.C
420+
copy(M::MulAdd{<:AbstractFillLayout,<:AbstractFillLayout,<:AbstractFillLayout}) = M.α*M.A*M.B + M.β*M.C
421+
422+
###
423+
# DualLayout
424+
###
425+
426+
function similar(M::MulAdd{<:DualLayout,<:Any,ZerosLayout}, ::Type{T}, (x,y)) where T
427+
@assert x Base.OneTo(1)
428+
similar(M.A', T, y)'
429+
end

test/runtests.jl

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,36 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
126126
@test A\MyVector(x) A\x
127127
@test A\MyMatrix(X) A\X
128128
end
129+
130+
@testset "dot" begin
131+
A = MyMatrix(randn(5,5))
132+
b = randn(5)
133+
@test dot(b, A, b) b'*(A*b) b'A*b
134+
end
129135
end
130136

131137
@testset "l/rmul!" begin
132-
b = randn(5)
133-
@test ArrayLayouts.lmul!(2, MyVector(copy(b))) == ArrayLayouts.rmul!(MyVector(copy(b)), 2) == 2b
138+
b = MyVector(randn(5))
139+
A = MyMatrix(randn(5,5))
140+
@test lmul!(2, deepcopy(b)) == rmul!(deepcopy(b), 2) == 2b
141+
@test lmul!(2, deepcopy(A)) == rmul!(deepcopy(A), 2) == 2A
142+
@test lmul!(2, deepcopy(A)') == rmul!(deepcopy(A)', 2) == 2A'
143+
@test lmul!(2, transpose(deepcopy(A))) == rmul!(transpose(deepcopy(A)), 2) == 2transpose(A)
144+
@test lmul!(2, Symmetric(deepcopy(A))) == rmul!(Symmetric(deepcopy(A)), 2) == 2Symmetric(A)
145+
@test lmul!(2, Hermitian(deepcopy(A))) == rmul!(Hermitian(deepcopy(A)), 2) == 2Hermitian(A)
146+
147+
C = randn(ComplexF64,5,5)
148+
@test ArrayLayouts.lmul!(2, Hermitian(copy(C))) == ArrayLayouts.rmul!(Hermitian(copy(C)), 2) == 2Hermitian(C)
149+
150+
if VERSION v"1.5"
151+
@test ldiv!(2, deepcopy(b)) == rdiv!(deepcopy(b), 2) == 2\b
152+
@test ldiv!(2, deepcopy(A)) == rdiv!(deepcopy(A), 2) == 2\A
153+
@test ldiv!(2, deepcopy(A)') == rdiv!(deepcopy(A)', 2) == 2\A'
154+
@test ldiv!(2, transpose(deepcopy(A))) == rdiv!(transpose(deepcopy(A)), 2) == 2\transpose(A)
155+
@test ldiv!(2, Symmetric(deepcopy(A))) == rdiv!(Symmetric(deepcopy(A)), 2) == 2\Symmetric(A)
156+
@test ldiv!(2, Hermitian(deepcopy(A))) == rdiv!(Hermitian(deepcopy(A)), 2) == 2\Hermitian(A)
157+
@test ArrayLayouts.ldiv!(2, Hermitian(copy(C))) == ArrayLayouts.rdiv!(Hermitian(copy(C)), 2) == 2\Hermitian(C)
158+
end
134159
end
135160
end
136161

0 commit comments

Comments
 (0)