Skip to content

Commit 9a366ad

Browse files
authored
Add in-place left-multiplication for absmat (#131)
1 parent 66cc03e commit 9a366ad

File tree

6 files changed

+104
-30
lines changed

6 files changed

+104
-30
lines changed

docs/src/types.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ Base.:*(::LinearMap,::AbstractMatrix)
102102
Base.:*(::AbstractMatrix,::LinearMap)
103103
LinearAlgebra.mul!(::AbstractVecOrMat,::LinearMap,::AbstractVector)
104104
LinearAlgebra.mul!(::AbstractVecOrMat,::LinearMap,::AbstractVector,::Number,::Number)
105+
LinearAlgebra.mul!(::AbstractMatrix,::AbstractMatrix,::LinearMap)
105106
*(::LinearAlgebra.AdjointAbsVec,::LinearMap)
106107
*(::LinearAlgebra.TransposeAbsVec,::LinearMap)
107108
```

src/kronecker.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ function _unsafe_mul!(y::AbstractVecOrMat, L::KroneckerSumMap, x::AbstractVector
288288
mb, nb = size(B)
289289
X = reshape(x, (nb, na))
290290
Y = reshape(y, (nb, na))
291-
_unsafe_mul!(Y, X, convert(AbstractMatrix, transpose(A)))
291+
_unsafe_mul!(Y, X, transpose(A))
292292
_unsafe_mul!(Y, B, X, true, true)
293293
return y
294294
end

src/left.jl

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,29 +38,61 @@ julia> A=LinearMap([1.0 2.0; 3.0 4.0]); x=[1.0, 1.0]; transpose(x)*A
3838
Base.:(*)(y::LinearAlgebra.TransposeAbsVec, A::LinearMap) = transpose(transpose(A) * transpose(y))
3939

4040
# multiplication with vector/matrix
41-
const AdjointAbsVecOrMat{T} = Adjoint{T,<:AbstractVecOrMat}
4241
const TransposeAbsVecOrMat{T} = Transpose{T,<:AbstractVecOrMat}
4342

44-
function mul!(x::AbstractMatrix, y::AdjointAbsVecOrMat, A::LinearMap)
45-
check_dim_mul(x, y, A)
46-
_unsafe_mul!(x', A', y')
47-
return x
43+
# handles both y::AbstractMatrix and y::AdjointAbsVecOrMat
44+
"""
45+
mul!(C::AbstractMatrix, A::AbstractMatrix, B::LinearMap) -> C
46+
47+
Calculates the matrix representation of `A*B` and stores the result in `C`,
48+
overwriting the existing value of `C`. Note that `C` must not be aliased with
49+
either `A` or `B`. The computation `C = A*B` is performed via `C' = B'A'`.
50+
51+
## Examples
52+
```jldoctest; setup=(using LinearAlgebra, LinearMaps)
53+
julia> A=[1.0 1.0; 1.0 1.0]; B=LinearMap([1.0 2.0; 3.0 4.0]); C = similar(A); mul!(C, A, B);
54+
55+
julia> C
56+
2×2 Array{Float64,2}:
57+
4.0 6.0
58+
4.0 6.0
59+
```
60+
"""
61+
function mul!(X::AbstractMatrix, Y::AbstractMatrix, A::LinearMap)
62+
check_dim_mul(X, Y, A)
63+
_unsafe_mul!(X', A', Y')
64+
return X
65+
end
66+
67+
function mul!(X::AbstractMatrix, Y::TransposeAbsVecOrMat, A::LinearMap)
68+
check_dim_mul(X, Y, A)
69+
_unsafe_mul!(transpose(X), transpose(A), transpose(Y))
70+
return X
4871
end
4972

50-
function mul!(x::AbstractMatrix, y::AdjointAbsVecOrMat, A::LinearMap, α::Number, β::Number)
51-
check_dim_mul(x, y, A)
52-
_unsafe_mul!(x', conj(α)*A', y', true, conj(β))
53-
return x
73+
# commutative case, handles both the abstract and adjoint case
74+
function mul!(X::AbstractMatrix{<:RealOrComplex}, Y::AbstractMatrix{<:RealOrComplex}, A::LinearMap{<:RealOrComplex},
75+
α::RealOrComplex, β::RealOrComplex)
76+
check_dim_mul(X, Y, A)
77+
_unsafe_mul!(X', A', Y', conj(α), conj(β))
78+
return X
5479
end
5580

56-
function mul!(x::AbstractMatrix, y::TransposeAbsVecOrMat, A::LinearMap)
57-
check_dim_mul(x, y, A)
58-
_unsafe_mul!(transpose(x), transpose(A), transpose(y))
59-
return x
81+
function mul!(X::AbstractMatrix{<:RealOrComplex}, Y::TransposeAbsVecOrMat{<:RealOrComplex}, A::LinearMap{<:RealOrComplex},
82+
α::RealOrComplex, β::RealOrComplex)
83+
check_dim_mul(X, Y, A)
84+
_unsafe_mul!(transpose(X), transpose(A), transpose(Y), α, β)
85+
return X
6086
end
6187

62-
function mul!(x::AbstractMatrix, y::TransposeAbsVecOrMat, A::LinearMap, α::Number, β::Number)
63-
check_dim_mul(x, y, A)
64-
_unsafe_mul!(transpose(x), α*transpose(A), transpose(y), true, β)
65-
return x
88+
# non-commutative case
89+
function mul!(X::AbstractMatrix, Y::AbstractMatrix, A::LinearMap, α::Number, β::Number)
90+
check_dim_mul(X, Y, A)
91+
if iszero(β)
92+
_unsafe_mul!(X', conj(α)*A', Y')
93+
else
94+
!isone(β) && rmul!(X, β)
95+
_unsafe_mul!(X', conj(α)*A', Y', true, true)
96+
end
97+
return X
6698
end

src/wrappedmap.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ for (In, Out) in ((AbstractVector, AbstractVecOrMat), (AbstractMatrix, AbstractM
6161
end
6262
end
6363

64+
mul!(Y::AbstractMatrix, X::AbstractMatrix, A::MatrixMap) = mul!(Y, X, A.lmap)
65+
# the following method is needed for disambiguation with left-multiplication
66+
mul!(Y::AbstractMatrix, X::TransposeAbsVecOrMat, A::MatrixMap) = mul!(Y, X, A.lmap)
67+
6468
if VERSION v"1.3.0-alpha.115"
6569
for (In, Out) in ((AbstractVector, AbstractVecOrMat), (AbstractMatrix, AbstractMatrix))
6670
@eval begin
@@ -82,6 +86,18 @@ if VERSION ≥ v"1.3.0-alpha.115"
8286
end
8387
end
8488
end
89+
90+
mul!(X::AbstractMatrix, Y::AbstractMatrix, A::MatrixMap, α::Number, β::Number) =
91+
mul!(X, Y, A.lmap, α, β)
92+
# the following method is needed for disambiguation with left-multiplication
93+
function mul!(Y::AbstractMatrix{<:RealOrComplex}, X::AbstractMatrix{<:RealOrComplex}, A::MatrixMap{<:RealOrComplex},
94+
α::RealOrComplex, β::RealOrComplex)
95+
return mul!(Y, X, A.lmap, α, β)
96+
end
97+
function mul!(Y::AbstractMatrix{<:RealOrComplex}, X::TransposeAbsVecOrMat{<:RealOrComplex}, A::MatrixMap{<:RealOrComplex},
98+
α::RealOrComplex, β::RealOrComplex)
99+
return mul!(Y, X, A.lmap, α, β)
100+
end
85101
end # VERSION
86102

87103
# combine LinearMap and Matrix objects: linear combinations and map composition
@@ -112,7 +128,8 @@ Base.:(*)(A₁::LinearMap, A₂::AbstractMatrix) = *(A₁, WrappedMap(A₂))
112128
*(X::AbstractMatrix, A::LinearMap)::CompositeMap
113129
114130
Return the `CompositeMap` `LinearMap(X)*A`, interpreting the matrix `X` as a linear
115-
operator. To compute the right-action of `A` on each row of `X`, call `Matrix(X*A)`.
131+
operator. To compute the right-action of `A` on each row of `X`, call `Matrix(X*A)`
132+
or `mul!(Y, X, A)` for the in-place version.
116133
117134
## Examples
118135
```jldoctest; setup=(using LinearMaps)

test/numbertypes.jl

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,41 @@
11
using Test, LinearMaps, LinearAlgebra, Quaternions
22

33
# type piracy because Quaternions.jl doesn't have it right
4-
Base.:(*)(z::Complex{T}, q::Quaternion{T}) where {T<:Real} = quat(z) * q
5-
Base.:(*)(q::Quaternion{T}, z::Complex{T}) where {T<:Real} = q * quat(z)
4+
Base.:(*)(z::Complex, q::Quaternion) = quat(z) * q
5+
Base.:(*)(q::Quaternion, z::Complex) = q * quat(z)
6+
Base.:(+)(q::Quaternion, z::Complex) = q + quat(z)
67

78
@testset "noncommutative number type" begin
89
x = Quaternion.(rand(10), rand(10), rand(10), rand(10))
910
v = rand(10)
1011
A = Quaternion.(rand(10,10), rand(10,10), rand(10,10), rand(10,10))
1112
B = rand(ComplexF64, 10, 10)
13+
C = similar(A)
1214
γ = Quaternion.(rand(4)...) # "Number"
1315
α = UniformScaling(γ)
1416
β = UniformScaling(Quaternion.(rand(4)...))
1517
λ = rand(ComplexF64)
1618
L = LinearMap(A)
17-
@test Array(L) == A
18-
@test Array(L') == A'
19-
@test Array(transpose(L)) == transpose(A)
20-
@test Array* L) == α * A
21-
@test Array(L * α) == A * α
22-
@test Array* L) == α * A
23-
@test Array(L * α ) == A * α
24-
@test Array* L') == α * A'
25-
@test Array((α * L')') * A')' A * conj(α)
19+
F = LinearMap{eltype(A)}(x -> A*x, y -> A'y, 10)
20+
@test Array(F) == A
21+
@test Array(F') == A'
22+
@test Array(transpose(F)) == transpose(A)
23+
@test Array* F) == α * A
24+
@test Array(F * α) == A * α
25+
@test Array* F) == α * A
26+
@test Array(F * α ) == A * α
27+
@test Array* F') == α * A'
28+
for M in (L, F)
29+
@test mul!(C, transpose(A), M) transpose(A)*A
30+
@test mul!(C, A', M) A'A
31+
@test mul!(C, A, M) A*A
32+
@test mul!(copy(C), M, A, γ, λ) A*A*γ + C*λ
33+
@test mul!(copy(C), A, M, γ, λ) A*A*γ + C*λ
34+
@test mul!(copy(C), A, M, γ, 0) A*A*γ
35+
@test mul!(copy(C), transpose(A), M, γ, λ) transpose(A)*A*γ + C*λ
36+
@test mul!(copy(C), adjoint(A), M, γ, λ) A'*A*γ + C*λ
37+
end
38+
@test Array((α * F')') * A')' A * conj(α)
2639
@test L * x A * x
2740
@test L' * x A' * x
2841
@test α * (L * x) α * (A * x)

test/wrappedmap.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,15 @@ using Test, LinearMaps, LinearAlgebra
7272
B = @inferred LinearMap(Hermitian(rand(ComplexF64, 10, 10)))
7373
@test adjoint(B) == B
7474
@test B == B'
75+
76+
N = 10
77+
Id = LinearMap(identity, N; issymmetric=true)
78+
A = rand(N, N)
79+
B = similar(A)
80+
@test mul!(copy(B), Id, A) == A
81+
@test mul!(B, A, Id) == B == A
82+
@test mul!(copy(B), A, LinearMap(Matrix(Id))) == A
83+
@test mul!(B, Id, A, true, true) 2A
84+
@test mul!(B, A, Id, true, true) == B == 3A
85+
@test mul!(copy(B), A, LinearMap(Matrix(Id)), true, true) == 4A
7586
end

0 commit comments

Comments
 (0)