Skip to content

Commit e561971

Browse files
authored
Add Rdiv (#1)
* Add Rdiv * Fix diagonal Ldiv, more lazy ldiv * Update ldiv.jl * Add tests * Update ldiv.jl
1 parent 92d06cc commit e561971

File tree

4 files changed

+102
-28
lines changed

4 files changed

+102
-28
lines changed

src/ArrayLayouts.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ else
4646
import Base: require_one_based_indexing
4747
end
4848

49-
export materialize, materialize!, MulAdd, muladd!, Ldiv, Lmul, Rmul, lmul, rmul, mul, MemoryLayout, AbstractStridedLayout,
49+
export materialize, materialize!, MulAdd, muladd!, Ldiv, Rdiv, Lmul, Rmul, lmul, rmul, mul, MemoryLayout, AbstractStridedLayout,
5050
DenseColumnMajor, ColumnMajor, ZerosLayout, FillLayout, AbstractColumnMajor, RowMajor, AbstractRowMajor,
5151
DiagonalLayout, ScalarLayout, SymTridiagonalLayout, HermitianLayout, SymmetricLayout, TriangularLayout,
5252
UnknownLayout, AbstractBandedLayout, ApplyBroadcastStyle, ConjLayout, AbstractFillLayout,

src/diagonal.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,13 @@ end
2929

3030
copy(M::Rmul{<:Any,<:DiagonalLayout}) = M.A .* permutedims(M.B.diag)
3131
copy(M::Rmul{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A .* getindex_value(M.B.diag)
32+
33+
copy(M::Ldiv{<:DiagonalLayout,<:DiagonalLayout}) = Diagonal(inv.(M.A.diag) .* M.B.diag)
34+
copy(M::Ldiv{<:DiagonalLayout}) = inv.(M.A.diag) .* M.B
35+
copy(M::Ldiv{<:DiagonalLayout{<:AbstractFillLayout}}) = inv(getindex_value(M.A.diag)) .* M.B
36+
copy(M::Ldiv{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout}) = Diagonal(inv(getindex_value(M.A.diag)) .* M.B.diag)
37+
38+
copy(M::Rdiv{<:DiagonalLayout,<:DiagonalLayout}) = Diagonal(M.A.diag .* inv.(M.B.diag))
39+
copy(M::Rdiv{<:Any,<:DiagonalLayout}) = M.A .* inv.(permutedims(M.B.diag))
40+
copy(M::Rdiv{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A .* inv(getindex_value(M.B.diag))
41+
copy(M::Rdiv{<:DiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = Diagonal(M.A.diag .* inv(getindex_value(M.B.diag)))

src/ldiv.jl

Lines changed: 60 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,37 @@
1+
for Typ in (:Ldiv, :Rdiv)
2+
@eval begin
3+
struct $Typ{StyleA, StyleB, AType, BType}
4+
A::AType
5+
B::BType
6+
end
17

8+
$Typ{StyleA, StyleB}(A::AType, B::BType) where {StyleA,StyleB,AType,BType} =
9+
$Typ{StyleA,StyleB,AType,BType}(A,B)
210

11+
$Typ(A::AType, B::BType) where {AType,BType} =
12+
$Typ{typeof(MemoryLayout(AType)),typeof(MemoryLayout(BType)),AType,BType}(A, B)
313

14+
BroadcastStyle(::Type{<:$Typ}) = ApplyBroadcastStyle()
15+
broadcastable(M::$Typ) = M
416

5-
struct Ldiv{StyleA, StyleB, AType, BType}
6-
A::AType
7-
B::BType
8-
end
9-
10-
Ldiv{StyleA, StyleB}(A::AType, B::BType) where {StyleA,StyleB,AType,BType} =
11-
Ldiv{StyleA,StyleB,AType,BType}(A,B)
17+
similar(A::$Typ, ::Type{T}, axes) where T = similar(Array{T}, axes)
18+
similar(A::$Typ, ::Type{T}) where T = similar(A, T, axes(A))
19+
similar(A::$Typ) = similar(A, eltype(A))
1220

13-
Ldiv(A::AType, B::BType) where {AType,BType} =
14-
Ldiv{typeof(MemoryLayout(AType)),typeof(MemoryLayout(BType)),AType,BType}(A, B)
15-
16-
struct LdivBroadcastStyle <: BroadcastStyle end
21+
copy(M::$Typ) = copyto!(similar(M), M)
22+
materialize(M::$Typ) = copy(instantiate(M))
23+
end
24+
end
1725

1826
size(L::Ldiv{<:Any,<:Any,<:Any,<:AbstractMatrix}) = (size(L.A, 2),size(L.B,2))
1927
size(L::Ldiv{<:Any,<:Any,<:Any,<:AbstractVector}) = (size(L.A, 2),)
2028
axes(L::Ldiv{<:Any,<:Any,<:Any,<:AbstractMatrix}) = (axes(L.A, 2),axes(L.B,2))
21-
axes(L::Ldiv{<:Any,<:Any,<:Any,<:AbstractVector}) = (axes(L.A, 2),)
29+
axes(L::Ldiv{<:Any,<:Any,<:Any,<:AbstractVector}) = (axes(L.A, 2),)
2230
length(L::Ldiv{<:Any,<:Any,<:Any,<:AbstractVector}) =size(L.A, 2)
2331

32+
size(L::Rdiv) = (size(L.A, 1),size(L.B,1))
33+
axes(L::Rdiv) = (axes(L.A, 1),axes(L.B,1))
34+
2435
_ldivaxes(::Tuple{}, ::Tuple{}) = ()
2536
_ldivaxes(::Tuple{}, Bax::Tuple) = Bax
2637
_ldivaxes(::Tuple{<:Any}, ::Tuple{<:Any}) = ()
@@ -32,27 +43,26 @@ _ldivaxes(Aax::Tuple{<:Any,<:Any}, Bax::Tuple{<:Any,<:Any}) = (last(Aax),last(Ba
3243

3344
ndims(L::Ldiv) = ndims(last(L.args))
3445
eltype(M::Ldiv) = promote_type(Base.promote_op(inv, eltype(M.A)), eltype(M.B))
46+
eltype(M::Rdiv) = promote_type(eltype(M.A), Base.promote_op(inv, eltype(M.B)))
47+
48+
49+
check_ldiv_axes(A, B) =
50+
axes(A,1) == axes(B,1) || throw(DimensionMismatch("First axis of A, $(axes(A,1)), and first axis of B, $(axes(B,1)) must match"))
51+
52+
check_rdiv_axes(A, B) =
53+
axes(A,2) == axes(B,2) || throw(DimensionMismatch("Second axis of A, $(axes(A,2)), and second axis of B, $(axes(B,2)) must match"))
3554

36-
BroadcastStyle(::Type{<:Ldiv}) = ApplyBroadcastStyle()
37-
broadcastable(M::Ldiv) = M
3855

39-
similar(A::Ldiv, ::Type{T}, axes) where T = similar(Array{T}, axes)
40-
similar(A::Ldiv, ::Type{T}) where T = similar(A, T, axes(A))
41-
similar(A::Ldiv) = similar(A, eltype(A))
4256

4357
function instantiate(L::Ldiv)
4458
check_ldiv_axes(L.A, L.B)
4559
Ldiv(instantiate(L.A), instantiate(L.B))
4660
end
4761

48-
49-
check_ldiv_axes(A, B) =
50-
axes(A,1) == axes(B,1) || throw(DimensionMismatch("First axis of A, $(axes(A,1)), and first axis of B, $(axes(B,1)) must match"))
51-
52-
53-
54-
copy(M::Ldiv) = copyto!(similar(M), M)
55-
materialize(M::Ldiv) = copy(instantiate(M))
62+
function instantiate(L::Rdiv)
63+
check_rdiv_axes(L.A, L.B)
64+
Rdiv(instantiate(L.A), instantiate(L.B))
65+
end
5666

5767
_ldiv!(A, B) = ldiv!(factorize(A), B)
5868
_ldiv!(A::Factorization, B) = ldiv!(A, B)
@@ -61,7 +71,11 @@ _ldiv!(dest, A, B) = ldiv!(dest, factorize(A), B)
6171
_ldiv!(dest, A::Factorization, B) = ldiv!(dest, A, B)
6272

6373

74+
6475
materialize!(M::Ldiv) = _ldiv!(M.A, M.B)
76+
materialize!(M::Rdiv) = materialize!(Lmul(M.B', M.A'))'
77+
copyto!(dest::AbstractArray, M::Rdiv) = copyto!(dest', Ldiv(M.B', M.A'))'
78+
6579
if VERSION v"1.1-pre"
6680
copyto!(dest::AbstractArray, M::Ldiv) = _ldiv!(dest, M.A, M.B)
6781
else
@@ -73,6 +87,13 @@ const MatLdivMat{styleA, styleB, T, V} = Ldiv{styleA, styleB, <:AbstractMatrix{T
7387
const BlasMatLdivVec{styleA, styleB, T<:BlasFloat} = MatLdivVec{styleA, styleB, T, T}
7488
const BlasMatLdivMat{styleA, styleB, T<:BlasFloat} = MatLdivMat{styleA, styleB, T, T}
7589

90+
const MatRdivMat{styleA, styleB, T, V} = Rdiv{styleA, styleB, <:AbstractMatrix{T}, <:AbstractMatrix{V}}
91+
const BlasMatRdivMat{styleA, styleB, T<:BlasFloat} = MatRdivMat{styleA, styleB, T, T}
92+
93+
# function materialize!(L::BlasMatLdivVec{<:AbstractColumnMajor,<:AbstractColumnMajor})
94+
95+
# end
96+
7697

7798
macro lazyldiv(Typ)
7899
esc(quote
@@ -83,5 +104,18 @@ macro lazyldiv(Typ)
83104

84105
Base.:\(A::$Typ, x::AbstractVector) = ArrayLayouts.materialize(ArrayLayouts.Ldiv(A,x))
85106
Base.:\(A::$Typ, x::AbstractMatrix) = ArrayLayouts.materialize(ArrayLayouts.Ldiv(A,x))
107+
108+
Base.:\(x::AbstractMatrix, A::$Typ) = ArrayLayouts.materialize(ArrayLayouts.Ldiv(x,A))
109+
Base.:\(x::Diagonal, A::$Typ) = ArrayLayouts.materialize(ArrayLayouts.Ldiv(x,A))
110+
111+
Base.:\(x::$Typ, A::$Typ) = ArrayLayouts.materialize(ArrayLayouts.Ldiv(x,A))
112+
113+
Base.:/(A::$Typ, x::AbstractVector) = ArrayLayouts.materialize(ArrayLayouts.Rdiv(A,x))
114+
Base.:/(A::$Typ, x::AbstractMatrix) = ArrayLayouts.materialize(ArrayLayouts.Rdiv(A,x))
115+
116+
Base.:/(x::AbstractMatrix, A::$Typ) = ArrayLayouts.materialize(ArrayLayouts.Rdiv(x,A))
117+
Base.:/(x::Diagonal, A::$Typ) = ArrayLayouts.materialize(ArrayLayouts.Rdiv(x,A))
118+
119+
Base.:/(x::$Typ, A::$Typ) = ArrayLayouts.materialize(ArrayLayouts.Rdiv(x,A))
86120
end)
87-
end
121+
end

test/test_ldiv.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ArrayLayouts, LinearAlgebra, Test
1+
using ArrayLayouts, LinearAlgebra, FillArrays, Test
22
import ArrayLayouts: ApplyBroadcastStyle
33

44
@testset "Ldiv" begin
@@ -92,4 +92,34 @@ import ArrayLayouts: ApplyBroadcastStyle
9292
A = [1 2 ; 3 4]; b = [5,6];
9393
@test eltype(Ldiv(A, b)) == Float64
9494
end
95+
96+
@testset "Rdiv" begin
97+
@testset "Float64 \\ *" begin
98+
A = randn(3,5)
99+
B = randn(5,5)
100+
M = Rdiv(A,B)
101+
102+
@test size(M) == (3,5)
103+
@test axes(M) == (Base.OneTo(3),Base.OneTo(5))
104+
@test similar(M) isa Matrix{Float64}
105+
end
106+
end
107+
108+
@testset "Diagonal" begin
109+
D = Diagonal(randn(5))
110+
F = Eye(5)
111+
A = randn(5,5)
112+
@test copy(Ldiv(D,F)) == D \ F
113+
@test copy(Ldiv(F,D)) == F \ D
114+
@test copy(Ldiv(D,A)) D \ A
115+
@test copy(Ldiv(A,D)) == A \ D
116+
@test copy(Ldiv(F,A)) F \ A
117+
@test copy(Ldiv(A,F)) == A \ F
118+
119+
@test copy(Rdiv(D,F)) == D / F
120+
@test copy(Rdiv(F,D)) == F / D
121+
@test copy(Rdiv(A,D)) A / D
122+
@test copy(Rdiv(A,F)) == A / F
123+
end
95124
end
125+

0 commit comments

Comments
 (0)