Skip to content

Commit 2dbe5f2

Browse files
authored
Rework conversions, speed-up matrix Kronecker products (#108)
1 parent 68bda52 commit 2dbe5f2

File tree

7 files changed

+83
-35
lines changed

7 files changed

+83
-35
lines changed

src/LinearMaps.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,15 @@ convert_to_lmaps(A) = (convert_to_lmaps_(A),)
6868
@inline convert_to_lmaps(A, B, Cs...) =
6969
(convert_to_lmaps_(A), convert_to_lmaps_(B), convert_to_lmaps(Cs...)...)
7070

71+
# The (internal) multiplication logic is as follows:
72+
# - `*(A, x)` calls `mul!(y, A, x)` for appropriately-sized y
73+
# - `mul!` checks consistency of the sizes, and calls `_unsafe_mul!`,
74+
# which does not check sizes, but potentially one-based indexing if necessary
75+
# - by default, `_unsafe_mul!` is redirected back to `mul!`
76+
# - custom map types only need to implement 3-arg (vector) `mul!`, and
77+
# everything else (5-arg multiplication, application to matrices,
78+
# conversion to matrices) will just work
79+
7180
"""
7281
*(A::LinearMap, x::AbstractVector)::AbstractVector
7382
@@ -200,7 +209,7 @@ function _generic_mapmat_mul!(Y, A, X, α=true, β=false)
200209
return Y
201210
end
202211

203-
_unsafe_mul!(y, A::MapOrMatrix, x) = mul!(y, A, x)
212+
_unsafe_mul!(y, A::MapOrMatrix, x) = mul!(y, A, x)
204213
_unsafe_mul!(y, A::AbstractMatrix, x, α, β) = mul!(y, A, x, α, β)
205214
function _unsafe_mul!(y::AbstractVecOrMat, A::LinearMap, x::AbstractVector, α, β)
206215
return _generic_mapvec_mul!(y, A, x, α, β)

src/conversion.jl

Lines changed: 50 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Matrix: create matrix representation of LinearMap
2-
function Base.Matrix(A::LinearMap)
2+
function Base.Matrix{T}(A::LinearMap) where {T}
33
M, N = size(A)
4-
T = eltype(A)
54
mat = Matrix{T}(undef, (M, N))
65
v = fill(zero(T), N)
76
@inbounds for i in 1:N
@@ -11,8 +10,9 @@ function Base.Matrix(A::LinearMap)
1110
end
1211
return mat
1312
end
13+
Base.Matrix(A::LinearMap{T}) where {T} = Matrix{T}(A)
1414
Base.Array(A::LinearMap) = Matrix(A)
15-
Base.convert(::Type{Matrix}, A::LinearMap) = Matrix(A)
15+
Base.convert(::Type{T}, A::LinearMap) where {T<:Matrix} = T(A)
1616
Base.convert(::Type{Array}, A::LinearMap) = convert(Matrix, A)
1717
Base.convert(::Type{AbstractMatrix}, A::LinearMap) = convert(Matrix, A)
1818
Base.convert(::Type{AbstractArray}, A::LinearMap) = convert(AbstractMatrix, A)
@@ -42,21 +42,22 @@ function SparseArrays.sparse(A::LinearMap{T}) where {T}
4242
return SparseMatrixCSC(M, N, colptr, rowind, nzval)
4343
end
4444
Base.convert(::Type{SparseMatrixCSC}, A::LinearMap) = sparse(A)
45+
SparseArrays.SparseMatrixCSC(A::LinearMap) = sparse(A)
4546

4647
# special cases
4748

4849
# ScaledMap
49-
Base.Matrix(A::ScaledMap{<:Any,<:Any,<:MatrixMap}) = convert(Matrix, A.λ*A.lmap.lmap)
50+
Base.Matrix{T}(A::ScaledMap{<:Any,<:Any,<:MatrixMap}) where {T} = convert(Matrix{T}, A.λ*A.lmap.lmap)
5051
SparseArrays.sparse(A::ScaledMap{<:Any,<:Any,<:MatrixMap}) = convert(SparseMatrixCSC, A.λ*A.lmap.lmap)
5152

5253
# UniformScalingMap
53-
Base.Matrix(J::UniformScalingMap) = Matrix(J.λ*I, size(J))
54-
Base.convert(::Type{AbstractMatrix}, J::UniformScalingMap) = Diagonal(fill(J.λ, size(J, 1)))
54+
Base.Matrix{T}(J::UniformScalingMap) where {T} = Matrix{T}(J.λ*I, size(J))
55+
Base.convert(::Type{AbstractMatrix}, J::UniformScalingMap) = Diagonal(fill(J.λ, J.M))
5556

5657
# WrappedMap
57-
Base.Matrix(A::WrappedMap) = Matrix(A.lmap)
58+
Base.Matrix{T}(A::WrappedMap) where {T} = Matrix{T}(A.lmap)
59+
Base.convert(::Type{T}, A::WrappedMap) where {T<:Matrix} = convert(T, A.lmap)
5860
Base.convert(::Type{AbstractMatrix}, A::WrappedMap) = convert(AbstractMatrix, A.lmap)
59-
Base.convert(::Type{Matrix}, A::WrappedMap) = convert(Matrix, A.lmap)
6061
SparseArrays.sparse(A::WrappedMap) = sparse(A.lmap)
6162
Base.convert(::Type{SparseMatrixCSC}, A::WrappedMap) = convert(SparseMatrixCSC, A.lmap)
6263

@@ -67,68 +68,90 @@ for (TT, T) in ((AdjointMap, adjoint), (TransposeMap, transpose))
6768
end
6869

6970
# LinearCombination
70-
for (TT, T) in ((Type{Matrix}, Matrix), (Type{SparseMatrixCSC}, SparseMatrixCSC))
71-
@eval function Base.convert(::$TT, ΣA::LinearCombination{<:Any,<:Tuple{Vararg{MatrixMap}}})
72-
maps = ΣA.maps
73-
mats = map(A->getfield(A, :lmap), maps)
74-
return convert($T, sum(mats))
75-
end
71+
function Base.Matrix{T}(ΣA::LinearCombination{<:Any,<:Tuple{Vararg{MatrixMap}}}) where {T}
72+
maps = ΣA.maps
73+
mats = map(A->getfield(A, :lmap), maps)
74+
return Matrix{T}(sum(mats))
75+
end
76+
function SparseArrays.sparse(ΣA::LinearCombination{<:Any,<:Tuple{Vararg{MatrixMap}}})
77+
maps = ΣA.maps
78+
mats = map(A->getfield(A, :lmap), maps)
79+
return convert(SparseMatrixCSC, sum(mats))
7680
end
7781

7882
# CompositeMap
79-
for (TT, T) in ((Type{Matrix}, Matrix), (Type{SparseMatrixCSC}, SparseMatrixCSC))
80-
@eval function Base.convert(::$TT, AB::CompositeMap{<:Any,<:Tuple{MatrixMap,MatrixMap}})
83+
function Base.Matrix{T}(AB::CompositeMap{<:Any,<:Tuple{MatrixMap,LinearMap}}) where {T}
84+
B, A = AB.maps
85+
require_one_based_indexing(B)
86+
Y = Matrix{eltype(AB)}(undef, size(AB))
87+
@views for i in 1:size(Y, 2)
88+
_unsafe_mul!(Y[:, i], A, B.lmap[:, i])
89+
end
90+
return Y
91+
end
92+
for ((TA, fieldA), (TB, fieldB)) in (((MatrixMap, :lmap), (MatrixMap, :lmap)),
93+
((MatrixMap, :lmap), (UniformScalingMap, )),
94+
((UniformScalingMap, ), (MatrixMap, :lmap)))
95+
@eval function Base.convert(::Type{AbstractMatrix}, AB::CompositeMap{<:Any,<:Tuple{$TB,$TA}})
8196
B, A = AB.maps
82-
return convert($T, A.lmap*B.lmap)
97+
return A.$fieldA*B.$fieldB
8398
end
8499
end
85-
function Base.Matrix(λA::CompositeMap{<:Any,<:Tuple{MatrixMap,UniformScalingMap}})
100+
function Base.Matrix{T}(AB::CompositeMap{<:Any,<:Tuple{MatrixMap,MatrixMap}}) where {T}
101+
B, A = AB.maps
102+
return convert(Matrix{T}, A.lmap*B.lmap)
103+
end
104+
function SparseArrays.sparse(AB::CompositeMap{<:Any,<:Tuple{MatrixMap,MatrixMap}})
105+
B, A = AB.maps
106+
return convert(SparseMatrixCSC, A.lmap*B.lmap)
107+
end
108+
function Base.Matrix{T}(λA::CompositeMap{<:Any,<:Tuple{MatrixMap,UniformScalingMap}}) where {T}
86109
A, J = λA.maps
87-
return convert(Matrix, J.λ*A.lmap)
110+
return convert(Matrix{T}, J.λ*A.lmap)
88111
end
89112
function SparseArrays.sparse(λA::CompositeMap{<:Any,<:Tuple{MatrixMap,UniformScalingMap}})
90113
A, J = λA.maps
91114
return convert(SparseMatrixCSC, J.λ*A.lmap)
92115
end
93-
function Base.Matrix(Aλ::CompositeMap{<:Any,<:Tuple{UniformScalingMap,MatrixMap}})
116+
function Base.Matrix{T}(Aλ::CompositeMap{<:Any,<:Tuple{UniformScalingMap,MatrixMap}}) where {T}
94117
J, A =.maps
95-
return convert(Matrix, A.lmap*J.λ)
118+
return convert(Matrix{T}, A.lmap*J.λ)
96119
end
97120
function SparseArrays.sparse(Aλ::CompositeMap{<:Any,<:Tuple{UniformScalingMap,MatrixMap}})
98121
J, A =.maps
99122
return convert(SparseMatrixCSC, A.lmap*J.λ)
100123
end
101124

102125
# BlockMap & BlockDiagonalMap
103-
Base.Matrix(A::BlockMap) = hvcat(A.rows, convert.(Matrix, A.maps)...)
126+
Base.Matrix{T}(A::BlockMap) where {T} = hvcat(A.rows, convert.(Matrix{T}, A.maps)...)
104127
Base.convert(::Type{AbstractMatrix}, A::BlockMap) = hvcat(A.rows, convert.(AbstractMatrix, A.maps)...)
105-
function Base.convert(::Type{SparseMatrixCSC}, A::BlockMap)
128+
function SparseArrays.sparse(A::BlockMap)
106129
return hvcat(
107130
A.rows,
108131
convert(SparseMatrixCSC, first(A.maps)),
109132
convert.(AbstractMatrix, Base.tail(A.maps))...
110133
)
111134
end
112-
Base.Matrix(A::BlockDiagonalMap) = cat(convert.(Matrix, A.maps)...; dims=(1,2))
135+
Base.Matrix{T}(A::BlockDiagonalMap) where {T} = cat(convert.(Matrix{T}, A.maps)...; dims=(1,2))
113136
Base.convert(::Type{AbstractMatrix}, A::BlockDiagonalMap) = sparse(A)
114137
function SparseArrays.sparse(A::BlockDiagonalMap)
115138
return blockdiag(convert.(SparseMatrixCSC, A.maps)...)
116139
end
117140

118141
# KroneckerMap & KroneckerSumMap
119-
Base.Matrix(A::KroneckerMap) = kron(convert.(Matrix, A.maps)...)
142+
Base.Matrix{T}(A::KroneckerMap) where {T} = kron(convert.(Matrix{T}, A.maps)...)
120143
Base.convert(::Type{AbstractMatrix}, A::KroneckerMap) = kron(convert.(AbstractMatrix, A.maps)...)
121144
function SparseArrays.sparse(A::KroneckerMap)
122145
return kron(
123146
convert(SparseMatrixCSC, first(A.maps)),
124147
convert.(AbstractMatrix, Base.tail(A.maps))...
125148
)
126149
end
127-
function Base.Matrix(L::KroneckerSumMap)
150+
function Base.Matrix{T}(L::KroneckerSumMap) where {T}
128151
A, B = L.maps
129152
IA = Diagonal(ones(Bool, size(A, 1)))
130153
IB = Diagonal(ones(Bool, size(B, 1)))
131-
return kron(Matrix(A), IB) + kron(IA, Matrix(B))
154+
return kron(Matrix{T}(A), IB) + kron(IA, Matrix{T}(B))
132155
end
133156
function Base.convert(::Type{AbstractMatrix}, L::KroneckerSumMap)
134157
A, B = L.maps

src/kronecker.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,14 @@ Base.:(==)(A::KroneckerMap, B::KroneckerMap) = (eltype(A) == eltype(B) && A.maps
115115
end
116116
return y
117117
end
118-
@inline function _kronmul!(y, B::Union{MatrixMap,UniformScalingMap}, X, At::Union{MatrixMap,UniformScalingMap}, T)
118+
@inline function _kronmul!(y, B, X, At::Union{MatrixMap,UniformScalingMap}, T)
119119
na, ma = size(At)
120120
mb, nb = size(B)
121-
if nb*ma < mb*na
122-
_unsafe_mul!(reshape(y, (mb, ma)), B, convert(Matrix, X*At))
121+
Y = reshape(y, (mb, ma))
122+
if nb*ma < mb*na
123+
_unsafe_mul!(Y, B, Matrix(X*At))
123124
else
124-
_unsafe_mul!(reshape(y, (mb, ma)), convert(Matrix, B*X), At isa MatrixMap ? At.lmap : At.λ)
125+
_unsafe_mul!(Y, Matrix(B*X), At isa MatrixMap ? At.lmap : At.λ)
125126
end
126127
return y
127128
end

test/composition.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Test, LinearMaps, LinearAlgebra
1+
using Test, LinearMaps, LinearAlgebra, SparseArrays
22

33
@testset "composition" begin
44
F = @inferred LinearMap(cumsum, reverse cumsum reverse, 10; ismutating=false)
@@ -50,8 +50,9 @@ using Test, LinearMaps, LinearAlgebra
5050
R1 = rand(ComplexF64, 10, 10); L1 = LinearMap(R1)
5151
R2 = rand(ComplexF64, 10, 10); L2 = LinearMap(R2)
5252
R3 = rand(ComplexF64, 10, 10); L3 = LinearMap(R3)
53-
CompositeR = prod(R -> LinearMap(R), [R1, R2, R3])
53+
CompositeR = prod(LinearMap, [R1, R2, R3])
5454
@test @inferred L1 * L2 * L3 == CompositeR
55+
@test Matrix(L1 * L2) sparse(L1 * L2) R1 * R2
5556
@test @inferred transpose(CompositeR) == transpose(L3) * transpose(L2) * transpose(L1)
5657
@test @inferred adjoint(CompositeR) == L3' * L2' * L1'
5758
@test @inferred adjoint(adjoint((CompositeR))) == CompositeR
@@ -61,6 +62,18 @@ using Test, LinearMaps, LinearAlgebra
6162
Lc = @inferred adjoint(LinearMap(CompositeR))
6263
@test Lc * v R3' * R2' * R1' * v
6364

65+
# convert to AbstractMatrix
66+
for A in (LinearMap(sprandn(10, 10, 0.3)), LinearMap(rand()*I, 10))
67+
for B in (LinearMap(sprandn(10, 10, 0.3)), LinearMap(rand()*I, 10))
68+
AA = convert(AbstractMatrix, A*B)
69+
if A isa LinearMaps.UniformScalingMap && B isa LinearMaps.UniformScalingMap
70+
@test isdiag(AA)
71+
else
72+
@test issparse(AA)
73+
end
74+
end
75+
end
76+
6477
# test inplace operations
6578
w = similar(v)
6679
mul!(w, L, v)

test/linearcombination.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ using Test, LinearMaps, LinearAlgebra, BenchmarkTools
3737
@test @inferred(LinearMaps.MulStyle(LC)) === matrixstyle
3838
@test @inferred(LinearMaps.MulStyle(LC + I)) === matrixstyle
3939
@test @inferred(LinearMaps.MulStyle(LC + 2.0*I)) === matrixstyle
40+
@test sparse(LC) == Matrix(LC) == A+B
4041
v = rand(ComplexF64, 10)
4142
w = similar(v)
4243
b = @benchmarkable mul!($w, $M, $v)

test/linearmaps.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,6 @@ LinearAlgebra.mul!(y::AbstractVector, A::Union{SimpleFunctionMap,SimpleComplexFu
9696
@test FM == L
9797
@test F * v L * v
9898
Fs = sparse(F)
99-
@test Fs == L
99+
@test SparseMatrixCSC(F) == Fs == L
100100
@test Fs isa SparseMatrixCSC
101101
end

test/wrappedmap.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using Test, LinearMaps, LinearAlgebra
88
L = @inferred LinearMap{Float64}(A)
99
MA = @inferred LinearMap(SA)
1010
MB = @inferred LinearMap(SB)
11+
@test eltype(Matrix{Complex{Float32}}(LinearMap(A))) <: Complex
1112
@test size(L) == size(A)
1213
@test @inferred !issymmetric(L)
1314
@test @inferred issymmetric(MA)

0 commit comments

Comments
 (0)