Skip to content

Commit 9f2d89d

Browse files
authored
add mulstyle feature (#76)
* refactor linear combination multiplication * make LinearCombination multiplication recursive
1 parent dd960a5 commit 9f2d89d

12 files changed

+116
-82
lines changed

src/LinearMaps.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,23 @@ const MapOrMatrix{T} = Union{LinearMap{T},AbstractMatrix{T}}
1919

2020
Base.eltype(::LinearMap{T}) where {T} = T
2121

22+
abstract type MulStyle end
23+
24+
struct FiveArg <: MulStyle end
25+
struct ThreeArg <: MulStyle end
26+
27+
MulStyle(::FiveArg, ::FiveArg) = FiveArg()
28+
MulStyle(::ThreeArg, ::FiveArg) = ThreeArg()
29+
MulStyle(::FiveArg, ::ThreeArg) = ThreeArg()
30+
MulStyle(::ThreeArg, ::ThreeArg) = ThreeArg()
31+
MulStyle(::LinearMap) = ThreeArg() # default
32+
@static if VERSION v"1.3.0-alpha.115"
33+
MulStyle(::AbstractMatrix) = FiveArg()
34+
else
35+
MulStyle(::AbstractMatrix) = ThreeArg()
36+
end
37+
MulStyle(A::LinearMap, As::LinearMap...) = MulStyle(MulStyle(A), MulStyle(As...))
38+
2239
Base.isreal(A::LinearMap) = eltype(A) <: Real
2340
LinearAlgebra.issymmetric(::LinearMap) = false # default assumptions
2441
LinearAlgebra.ishermitian(A::LinearMap{<:Real}) = issymmetric(A)

src/blockmap.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ end
1414

1515
BlockMap{T}(maps::As, rows::S) where {T,As<:Tuple{Vararg{LinearMap}},S} = BlockMap{T,As,S}(maps, rows)
1616

17+
MulStyle(A::BlockMap) = MulStyle(A.maps...)
18+
1719
function check_dim(A::LinearMap, dim, n)
1820
n == size(A, dim) || throw(DimensionMismatch("Expected $n, got $(size(A, dim))"))
1921
return nothing

src/linearcombination.jl

Lines changed: 36 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ end
1313

1414
LinearCombination{T}(maps::As) where {T, As} = LinearCombination{T, As}(maps)
1515

16+
MulStyle(A::LinearCombination) = MulStyle(A.maps...)
17+
1618
# basic methods
1719
Base.size(A::LinearCombination) = size(A.maps[1])
1820
LinearAlgebra.issymmetric(A::LinearCombination) = all(issymmetric, A.maps) # sufficient but not necessary
@@ -61,84 +63,50 @@ Base.:(==)(A::LinearCombination, B::LinearCombination) = (eltype(A) == eltype(B)
6163
LinearAlgebra.transpose(A::LinearCombination) = LinearCombination{eltype(A)}(map(transpose, A.maps))
6264
LinearAlgebra.adjoint(A::LinearCombination) = LinearCombination{eltype(A)}(map(adjoint, A.maps))
6365

64-
# multiplication with vectors
65-
if VERSION < v"1.3.0-alpha.115"
66-
67-
function A_mul_B!(y::AbstractVector, A::LinearCombination, x::AbstractVector)
68-
# no size checking, will be done by individual maps
69-
A_mul_B!(y, A.maps[1], x)
70-
l = length(A.maps)
71-
if l>1
72-
z = similar(y)
73-
for n in 2:l
74-
A_mul_B!(z, A.maps[n], x)
75-
y .+= z
66+
# multiplication with vectors & matrices
67+
for Atype in (AbstractVector, AbstractMatrix)
68+
@eval Base.@propagate_inbounds function LinearAlgebra.mul!(y::$Atype, A::LinearCombination, x::$Atype,
69+
α::Number=true, β::Number=false)
70+
@boundscheck check_dim_mul(y, A, x)
71+
if iszero(α) # trivial cases
72+
iszero(β) && (fill!(y, zero(eltype(y))); return y)
73+
isone(β) && return y
74+
# β != 0, 1
75+
rmul!(y, β)
76+
return y
77+
else
78+
mul!(y, first(A.maps), x, α, β)
79+
return _mul!(MulStyle(A), y, A, x, α, β)
7680
end
7781
end
78-
return y
7982
end
8083

81-
else # 5-arg mul! is available for matrices
82-
83-
# map types that have an allocation-free 5-arg mul! implementation
84-
const FreeMap = Union{MatrixMap,UniformScalingMap}
85-
86-
function A_mul_B!(y::AbstractVector, A::LinearCombination{T,As}, x::AbstractVector) where {T, As<:Tuple{Vararg{FreeMap}}}
87-
# no size checking, will be done by individual maps
88-
A_mul_B!(y, A.maps[1], x)
89-
for n in 2:length(A.maps)
90-
mul!(y, A.maps[n], x, true, true)
91-
end
92-
return y
93-
end
94-
function A_mul_B!(y::AbstractVector, A::LinearCombination, x::AbstractVector)
95-
# no size checking, will be done by individual maps
96-
A_mul_B!(y, A.maps[1], x)
97-
l = length(A.maps)
98-
if l>1
99-
z = similar(y)
100-
for n in 2:l
101-
An = A.maps[n]
102-
if An isa FreeMap
103-
mul!(y, An, x, true, true)
104-
else
105-
A_mul_B!(z, A.maps[n], x)
106-
y .+= z
107-
end
108-
end
109-
end
110-
return y
84+
@inline _mul!(::FiveArg, y, A::LinearCombination, x, α::Number, β::Number) =
85+
__mul!(y, Base.tail(A.maps), x, α, nothing)
86+
@inline function _mul!(::ThreeArg, y, A::LinearCombination, x, α::Number, β::Number)
87+
z = similar(y)
88+
__mul!(y, Base.tail(A.maps), x, α, z)
11189
end
11290

113-
function LinearAlgebra.mul!(y::AbstractVector, A::LinearCombination{T,As}, x::AbstractVector, α::Number=true, β::Number=false) where {T, As<:Tuple{Vararg{FreeMap}}}
114-
length(y) == size(A, 1) || throw(DimensionMismatch("mul!"))
115-
if isone(α)
116-
iszero(β) && (A_mul_B!(y, A, x); return y)
117-
!isone(β) && rmul!(y, β)
118-
elseif iszero(α)
119-
iszero(β) && (fill!(y, zero(eltype(y))); return y)
120-
isone(β) && return y
121-
# β != 0, 1
122-
rmul!(y, β)
123-
return y
124-
else # α != 0, 1
125-
if iszero(β)
126-
A_mul_B!(y, A, x)
127-
rmul!(y, α)
128-
return y
129-
elseif !isone(β)
130-
rmul!(y, β)
131-
end # β-cases
132-
end # α-cases
91+
@inline __mul!(y, As::Tuple{Vararg{LinearMap}}, x, α, z) =
92+
__mul!(__mul!(y, first(As), x, α, z), Base.tail(As), x, α, z)
93+
@inline __mul!(y, A::Tuple{LinearMap}, x, α, z) = __mul!(y, first(A), x, α, z)
94+
@inline __mul!(y, A::LinearMap, x, α, z) = muladd!(MulStyle(A), y, A, x, α, z)
13395

134-
for An in A.maps
135-
mul!(y, An, x, α, true)
96+
@inline muladd!(::FiveArg, y, A, x, α, _) = mul!(y, A, x, α, true)
97+
@inline function muladd!(::ThreeArg, y, A, x, α, z)
98+
# TODO: replace by mul!(z, A, x)
99+
A_mul_B!(z, A, x)
100+
if isone(α)
101+
y .+= z
102+
else
103+
y .+= z .* α
136104
end
137105
return y
138106
end
139107

140-
end # VERSION
108+
A_mul_B!(y::AbstractVector, A::LinearCombination, x::AbstractVector) = mul!(y, A, x)
141109

142-
At_mul_B!(y::AbstractVector, A::LinearCombination, x::AbstractVector) = A_mul_B!(y, transpose(A), x)
110+
At_mul_B!(y::AbstractVector, A::LinearCombination, x::AbstractVector) = mul!(y, transpose(A), x)
143111

144-
Ac_mul_B!(y::AbstractVector, A::LinearCombination, x::AbstractVector) = A_mul_B!(y, adjoint(A), x)
112+
Ac_mul_B!(y::AbstractVector, A::LinearCombination, x::AbstractVector) = mul!(y, adjoint(A), x)

src/transpose.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ struct AdjointMap{T, A<:LinearMap{T}} <: LinearMap{T}
55
lmap::A
66
end
77

8+
MulStyle(A::Union{TransposeMap,AdjointMap}) = MulStyle(A.lmap)
9+
810
# transposition behavior of LinearMap objects
911
LinearAlgebra.transpose(A::TransposeMap) = A.lmap
1012
LinearAlgebra.adjoint(A::AdjointMap) = A.lmap

src/uniformscalingmap.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ UniformScalingMap(λ::Number, M::Int, N::Int) =
1111
UniformScalingMap::T, sz::Dims{2}) where {T} =
1212
(sz[1] == sz[2] ? UniformScalingMap(λ, sz[1]) : error("UniformScalingMap needs to be square"))
1313

14+
MulStyle(::UniformScalingMap) = FiveArg()
15+
1416
# properties
1517
Base.size(A::UniformScalingMap) = (A.M, A.M)
1618
Base.isreal(A::UniformScalingMap) = isreal(A.λ)
@@ -91,11 +93,22 @@ function _scaling!(y, J::UniformScalingMap, x, α::Number=true, β::Number=false
9193
rmul!(y, β)
9294
return y
9395
else # α != 0, 1
94-
iszero(β) && (y .= λ .* x .* α; return y)
95-
isone(β) && (y .+= λ .* x .* α; return y)
96-
# β != 0, 1
97-
y .= y .* β .+ λ .* x .* α
98-
return y
96+
if iszero(β)
97+
iszero(λ) && return fill!(y, zero(eltype(y)))
98+
isone(λ) && return y .= x .* α
99+
y .= λ .* x .* α
100+
return y
101+
elseif isone(β)
102+
iszero(λ) && return y
103+
isone(λ) && return y .+= x .* α
104+
y .+= λ .* x .* α
105+
return y
106+
else # β != 0, 1
107+
iszero(λ) && (rmul!(y, β); return y)
108+
isone(λ) && (y .= y .* β .+ x .* α; return y)
109+
y .= y .* β .+ λ .* x .* α
110+
return y
111+
end
99112
end
100113
end
101114

src/wrappedmap.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ end
1919

2020
const MatrixMap{T} = WrappedMap{T,<:AbstractMatrix}
2121

22+
MulStyle(A::WrappedMap) = MulStyle(A.lmap)
23+
2224
LinearAlgebra.transpose(A::MatrixMap{T}) where {T} =
2325
WrappedMap{T}(transpose(A.lmap); issymmetric=A._issymmetric, ishermitian=A._ishermitian, isposdef=A._isposdef)
2426
LinearAlgebra.adjoint(A::MatrixMap{T}) where {T} =

test/blockmap.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using Test, LinearMaps, LinearAlgebra
66
A11 = rand(elty, 10, 10)
77
A12 = rand(elty, 10, n2)
88
L = @inferred hcat(LinearMap(A11), LinearMap(A12))
9+
@test @inferred(LinearMaps.MulStyle(L)) === matrixstyle
910
@test L isa LinearMaps.BlockMap{elty}
1011
A = [A11 A12]
1112
x = rand(10+n2)
@@ -36,6 +37,7 @@ using Test, LinearMaps, LinearAlgebra
3637
A21 = rand(elty, 20, 10)
3738
L = @inferred vcat(LinearMap(A11), LinearMap(A21))
3839
@test L isa LinearMaps.BlockMap{elty}
40+
@test @inferred(LinearMaps.MulStyle(L)) === matrixstyle
3941
A = [A11; A21]
4042
x = rand(10)
4143
@test size(L) == size(A)
@@ -62,6 +64,7 @@ using Test, LinearMaps, LinearAlgebra
6264
A = [A11 A12; A21 A22]
6365
@inferred hvcat((2,2), LinearMap(A11), LinearMap(A12), LinearMap(A21), LinearMap(A22))
6466
L = [LinearMap(A11) LinearMap(A12); LinearMap(A21) LinearMap(A22)]
67+
@test @inferred(LinearMaps.MulStyle(L)) === matrixstyle
6568
@test @inferred !issymmetric(L)
6669
@test @inferred !ishermitian(L)
6770
x = rand(30)
@@ -102,12 +105,13 @@ using Test, LinearMaps, LinearAlgebra
102105
@test Matrix(adjoint(B)) == C'
103106
end
104107
end
105-
108+
106109
@testset "adjoint/transpose" begin
107110
for elty in (Float32, Float64, ComplexF64), transform in (transpose, adjoint)
108111
A12 = rand(elty, 10, 10)
109112
A = [I A12; transform(A12) I]
110113
L = [I LinearMap(A12); transform(LinearMap(A12)) I]
114+
@test @inferred(LinearMaps.MulStyle(L)) === matrixstyle
111115
if elty <: Complex
112116
if transform == transpose
113117
@test @inferred issymmetric(L)

test/kronecker.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using Test, LinearMaps, LinearAlgebra
99
LB = LinearMap(B)
1010
LK = @inferred kron(LA, LB)
1111
@test @inferred size(LK) == size(K)
12+
@test LinearMaps.MulStyle(LK) === LinearMaps.ThreeArg()
1213
for i in (1, 2)
1314
@test @inferred size(LK, i) == size(K, i)
1415
end
@@ -31,6 +32,11 @@ using Test, LinearMaps, LinearAlgebra
3132
@test @inferred kron(LA, LB)' == @inferred kron(LA', LB')
3233
@test (@inferred kron(LA, B)) == (@inferred kron(LA, LB)) == (@inferred kron(A, LB))
3334
@test @inferred ishermitian(kron(LA'LA, LB'LB))
35+
A = rand(2, 5); B = rand(4, 2)
36+
K = @inferred kron(A, LinearMap(B))
37+
@test Matrix(K) kron(A, B)
38+
K = @inferred kron(LinearMap(B), A)
39+
@test Matrix(K) kron(B, A)
3440
A = rand(3, 3); B = rand(2, 2); LA = LinearMap(A); LB = LinearMap(B)
3541
@test @inferred issymmetric(kron(LA'LA, LB'LB))
3642
@test @inferred ishermitian(kron(LA'LA, LB'LB))
@@ -59,7 +65,7 @@ using Test, LinearMaps, LinearAlgebra
5965
@test Matrix(kronsum(transform(LA), transform(LB))) transform(KSmat)
6066
@test Matrix(transform(LinearMap(kronsum(LA, LB)))) Matrix(transform(KS)) transform(KSmat)
6167
end
62-
@inferred kronsum(A, A, LB)
68+
@test @inferred(kronsum(A, A, LB)) == @inferred((A, A, B))
6369
@test Matrix(@inferred LA^⊕(3)) == Matrix(@inferred A^⊕(3)) Matrix(kronsum(LA, A, A))
6470
@test @inferred(kronsum(LA, LA, LB)) == @inferred(kronsum(LA, kronsum(LA, LB))) == @inferred(kronsum(A, A, B))
6571
@test Matrix(@inferred kronsum(A, B, A, B, A, B)) Matrix(@inferred kronsum(LA, LB, LA, LB, LA, LB))

test/linearcombination.jl

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using Test, LinearMaps, LinearAlgebra, BenchmarkTools
44
CS! = LinearMap{ComplexF64}(cumsum!,
55
(y, x) -> (copyto!(y, x); reverse!(y); cumsum!(y, y)), 10;
66
ismutating=true)
7-
v = rand(10)
7+
v = rand(ComplexF64, 10)
88
u = similar(v)
99
b = @benchmarkable mul!($u, $CS!, $v)
1010
@test run(b, samples=3).allocs == 0
@@ -13,12 +13,20 @@ using Test, LinearMaps, LinearAlgebra, BenchmarkTools
1313
@test mul!(u, L, v) n * cumsum(v)
1414
b = @benchmarkable mul!($u, $L, $v)
1515
@test run(b, samples=5).allocs <= 1
16+
for α in (false, true, rand(ComplexF64)), β in (false, true, rand(ComplexF64))
17+
@test mul!(copy(u), L, v, α, β) Matrix(L)*v*α + u*β
18+
end
1619

1720
A = 2 * rand(ComplexF64, (10, 10)) .- 1
1821
B = rand(ComplexF64, size(A)...)
1922
M = @inferred LinearMap(A)
2023
N = @inferred LinearMap(B)
24+
@test @inferred(LinearMaps.MulStyle(M)) === matrixstyle
25+
@test @inferred(LinearMaps.MulStyle(N)) === matrixstyle
2126
LC = @inferred M + N
27+
@test @inferred(LinearMaps.MulStyle(LC)) === matrixstyle
28+
@test @inferred(LinearMaps.MulStyle(LC + I)) === matrixstyle
29+
@test @inferred(LinearMaps.MulStyle(LC + 2.0*I)) === matrixstyle
2230
v = rand(ComplexF64, 10)
2331
w = similar(v)
2432
b = @benchmarkable mul!($w, $M, $v)
@@ -27,10 +35,14 @@ using Test, LinearMaps, LinearAlgebra, BenchmarkTools
2735
b = @benchmarkable mul!($w, $LC, $v)
2836
@test run(b, samples=3).allocs == 0
2937
for α in (false, true, rand(ComplexF64)), β in (false, true, rand(ComplexF64))
30-
b = @benchmarkable mul!($w, $LC, $v, $α, $β)
31-
@test run(b, samples=3).allocs == 0
32-
b = @benchmarkable mul!($w, $(LC + I), $v, $α, $β)
33-
@test run(b, samples=3).allocs == 0
38+
if testallocs
39+
b = @benchmarkable mul!($w, $LC, $v, $α, $β)
40+
@test run(b, samples=3).allocs == 0
41+
b = @benchmarkable mul!($w, $(I + LC), $v, $α, $β)
42+
@test run(b, samples=3).allocs == 0
43+
b = @benchmarkable mul!($w, $(LC + I), $v, $α, $β)
44+
@test run(b, samples=3).allocs == 0
45+
end
3446
y = rand(ComplexF64, size(v))
3547
@test mul!(copy(y), LC, v, α, β) Matrix(LC)*v*α + y*β
3648
@test mul!(copy(y), LC+I, v, α, β) Matrix(LC + I)*v*α + y*β

test/runtests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
using Test, LinearMaps
2+
import LinearMaps: FiveArg, ThreeArg
3+
4+
const matrixstyle = VERSION v"1.3.0-alpha.115" ? FiveArg() : ThreeArg()
5+
6+
const testallocs = true
27

38
include("linearmaps.jl")
49

0 commit comments

Comments
 (0)