Skip to content

Commit 12c13fb

Browse files
krcoolsdkarrasch
andauthored
Specialisation of mul! for left multiplication with scalars (#173)
Co-authored-by: Daniel Karrasch <[email protected]>
1 parent 98843dd commit 12c13fb

20 files changed

+374
-97
lines changed

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
88
[compat]
99
BenchmarkTools = "1"
1010
Documenter = "0.25, 0.26, 0.27"
11-
Literate = "2"
11+
Literate = "2"

docs/src/custom.jl

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,13 @@ end
2929
Base.size(A::MyFillMap) = A.size
3030

3131
# By a couple of defaults provided for all subtypes of `LinearMap`, we only need to define
32-
# a `LinearAlgebra.mul!` method to have minimal, operational type.
32+
# a `LinearMaps._unsafe_mul!` method to have a minimal, operational type. The (internal)
33+
# function `_unsafe_mul!` is called by `LinearAlgebra.mul!`, constructors, and conversions
34+
# and only needs to be concerned with the bare computing kernel. Dimension checking is done
35+
# on the level of `mul!` etc. Factoring out dimension checking is done to minimise overhead
36+
# caused by repetitive checking.
3337

34-
function LinearAlgebra.mul!(y::AbstractVecOrMat, A::MyFillMap, x::AbstractVector)
35-
LinearMaps.check_dim_mul(y, A, x)
38+
function LinearMaps._unsafe_mul!(y::AbstractVecOrMat, A::MyFillMap, x::AbstractVector)
3639
return fill!(y, iszero(A.λ) ? zero(eltype(y)) : A.λ*sum(x))
3740
end
3841

@@ -84,7 +87,7 @@ using BenchmarkTools
8487

8588
LinearMaps.MulStyle(A::MyFillMap) = FiveArg()
8689

87-
function LinearAlgebra.mul!(
90+
function LinearMaps._unsafe_mul!(
8891
y::AbstractVecOrMat,
8992
A::MyFillMap,
9093
x::AbstractVector,
@@ -126,7 +129,7 @@ typeof(A')
126129
try A'x catch e println(e) end
127130

128131
# If the operator is symmetric or Hermitian, the transpose and the adjoint, respectively,
129-
# of the linear map `A` is given by `A` itself. So let's define corresponding checks.
132+
# of the linear map `A` is given by `A` itself. So let us define corresponding checks.
130133

131134
LinearAlgebra.issymmetric(A::MyFillMap) = A.size[1] == A.size[2]
132135
LinearAlgebra.ishermitian(A::MyFillMap) = isreal(A.λ) && A.size[1] == A.size[2]
@@ -154,22 +157,35 @@ try MyFillMap(5.0, (3, 4))' * ones(3) catch e println(e) end
154157
# The first option is to write `LinearAlgebra.mul!` methods for the corresponding wrapped
155158
# map types; for instance,
156159

157-
function LinearAlgebra.mul!(
160+
function LinearMaps._unsafe_mul!(
158161
y::AbstractVecOrMat,
159162
transA::LinearMaps.TransposeMap{<:Any,<:MyFillMap},
160163
x::AbstractVector
161164
)
162-
LinearMaps.check_dim_mul(y, transA, x)
163165
λ = transA.lmap.λ
164166
return fill!(y, iszero(λ) ? zero(eltype(y)) : transpose(λ)*sum(x))
165167
end
166168

169+
# Now, the adjoint multiplication works.
170+
171+
MyFillMap(5.0, (3, 4))' * ones(3)
172+
167173
# If you have set the `MulStyle` trait to `FiveArg()`, you should provide a corresponding
168174
# 5-arg `mul!` method for `LinearMaps.TransposeMap{<:Any,<:MyFillMap}` and
169175
# `LinearMaps.AdjointMap{<:Any,<:MyFillMap}`.
170176

171177
# ### Path 2: Invariant `LinearMap` subtypes
172178

179+
# Before we start, let us delete the previously defined method to make sure we use the
180+
# following definitions.
181+
182+
Base.delete_method(
183+
first(methods(
184+
LinearMaps._unsafe_mul!,
185+
(AbstractVecOrMat, LinearMaps.TransposeMap{<:Any,<:MyFillMap}, AbstractVector))
186+
)
187+
)
188+
173189
# The seconnd option is when your class of linear maps that are modelled by your custom
174190
# `LinearMap` subtype are invariant under taking adjoints and transposes.
175191

docs/src/history.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
# Version history
22

33
## What's new in v3.7
4+
* `mul!(M::AbstractMatrix, A::LinearMap, s::Number, a, b)` methods are provided, mimicking
5+
similar methods in `Base.LinearAlgebra`. This version allows for the memory efficient
6+
implementation of in-place addition and conversion of a `LinearMap` to `Matrix`.
7+
Efficient specialisations for `WrappedMap`, `ScaledMap`, and `LinearCombination` are
8+
provided. If users supply the corresponding `_unsafe_mul!` method for their custom maps,
9+
conversion, construction, and inplace addition will benefit from this supplied efficient
10+
implementation. If no specialisation is supplied, a generic fallback is used that is based
11+
on feeding the canonical basis of unit vectors to the linear map.
412

513
* A new map type called `EmbeddedMap` is introduced. It is a wrapper of a "small" `LinearMap`
614
(or a suitably converted `AbstractVecOrMat`) embedded into a "larger" zero map. Hence,

docs/src/types.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ Base.:*(::AbstractMatrix,::LinearMap)
103103
LinearAlgebra.mul!(::AbstractVecOrMat,::LinearMap,::AbstractVector)
104104
LinearAlgebra.mul!(::AbstractVecOrMat,::LinearMap,::AbstractVector,::Number,::Number)
105105
LinearAlgebra.mul!(::AbstractMatrix,::AbstractMatrix,::LinearMap)
106+
LinearAlgebra.mul!(::AbstractVecOrMat,::LinearMap,::Number)
107+
LinearAlgebra.mul!(::AbstractVecOrMat,::LinearMap,::Number,::Number,::Number)
106108
*(::LinearAlgebra.AdjointAbsVec,::LinearMap)
107109
*(::LinearAlgebra.TransposeAbsVec,::LinearMap)
108110
```

src/LinearMaps.jl

Lines changed: 97 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -143,14 +143,14 @@ with either `A` or `B`.
143143
144144
## Examples
145145
```jldoctest; setup=(using LinearAlgebra, LinearMaps)
146-
julia> A=LinearMap([1.0 2.0; 3.0 4.0]); B=[1.0, 1.0]; Y = similar(B); mul!(Y, A, B);
146+
julia> A=LinearMap([1.0 2.0; 3.0 4.0]); B=ones(2); Y = similar(B); mul!(Y, A, B);
147147
148148
julia> Y
149149
2-element Array{Float64,1}:
150150
3.0
151151
7.0
152152
153-
julia> A=LinearMap([1.0 2.0; 3.0 4.0]); B=[1.0 1.0; 1.0 1.0]; Y = similar(B); mul!(Y, A, B);
153+
julia> A=LinearMap([1.0 2.0; 3.0 4.0]); B=ones(4,4); Y = similar(B); mul!(Y, A, B);
154154
155155
julia> Y
156156
2×2 Array{Float64,2}:
@@ -162,6 +162,35 @@ function mul!(y::AbstractVecOrMat, A::LinearMap, x::AbstractVector)
162162
check_dim_mul(y, A, x)
163163
return _unsafe_mul!(y, A, x)
164164
end
165+
# the following is of interest in, e.g., subspace-iteration methods
166+
function mul!(Y::AbstractMatrix, A::LinearMap, X::AbstractMatrix)
167+
check_dim_mul(Y, A, X)
168+
return _unsafe_mul!(Y, A, X)
169+
end
170+
171+
"""
172+
mul!(Y::AbstractMatrix, A::LinearMap, b::Number) -> Y
173+
174+
Scales the matrix representation of the linear map `A` by `b` and stores the result in `Y`,
175+
overwriting the existing value of `Y`.
176+
177+
## Examples
178+
```jldoctest; setup=(using LinearAlgebra, LinearMaps)
179+
julia> A = LinearMap{Int}(cumsum, 3); b = 2; Y = Matrix{Int}(undef, (3,3));
180+
181+
julia> mul!(Y, A, b)
182+
3×3 Matrix{Int64}:
183+
2 0 0
184+
2 2 0
185+
2 2 2
186+
```
187+
"""
188+
function mul!(y::AbstractVecOrMat, A::LinearMap, s::Number)
189+
size(y) == size(A) ||
190+
throw(
191+
DimensionMismatch("y has size $(size(y)), A has size $(size(A))."))
192+
return _unsafe_mul!(y, A, s)
193+
end
165194

166195
"""
167196
mul!(C::AbstractVecOrMat, A::LinearMap, B::AbstractVector, α, β) -> C
@@ -197,6 +226,46 @@ function mul!(y::AbstractVecOrMat, A::LinearMap, x::AbstractVector, α::Number,
197226
check_dim_mul(y, A, x)
198227
return _unsafe_mul!(y, A, x, α, β)
199228
end
229+
function mul!(Y::AbstractMatrix, A::LinearMap, X::AbstractMatrix, α::Number, β::Number)
230+
check_dim_mul(Y, A, X)
231+
return _unsafe_mul!(Y, A, X, α, β)
232+
end
233+
234+
"""
235+
mul!(Y::AbstractMatrix, A::LinearMap, b::Number, α::Number, β::Number) -> Y
236+
237+
Scales the matrix representation of the linear map `A` by `b*α`, adds the result to `Y*β`
238+
and stores the final result in `Y`, overwriting the existing value of `Y`.
239+
240+
## Examples
241+
```jldoctest; setup=(using LinearAlgebra, LinearMaps)
242+
julia> A = LinearMap{Int}(cumsum, 3); b = 2; Y = ones(Int, (3,3));
243+
244+
julia> mul!(Y, A, b, 2, 1)
245+
3×3 Matrix{Int64}:
246+
5 1 1
247+
5 5 1
248+
5 5 5
249+
```
250+
"""
251+
function mul!(y::AbstractMatrix, A::LinearMap, s::Number, α::Number, β::Number)
252+
size(y) == size(A) ||
253+
throw(
254+
DimensionMismatch("y has size $(size(y)), A has size $(size(A))."))
255+
return _unsafe_mul!(y, A, s, α, β)
256+
end
257+
258+
_unsafe_mul!(y, A::MapOrVecOrMat, x) = mul!(y, A, x)
259+
_unsafe_mul!(y, A::AbstractVecOrMat, x, α, β) = mul!(y, A, x, α, β)
260+
_unsafe_mul!(y::AbstractVecOrMat, A::LinearMap, x::AbstractVector, α, β) =
261+
_generic_mapvec_mul!(y, A, x, α, β)
262+
_unsafe_mul!(y::AbstractMatrix, A::LinearMap, x::AbstractMatrix) =
263+
_generic_mapmat_mul!(y, A, x)
264+
_unsafe_mul!(y::AbstractMatrix, A::LinearMap, x::AbstractMatrix, α::Number, β::Number) =
265+
_generic_mapmat_mul!(y, A, x, α, β)
266+
_unsafe_mul!(Y::AbstractMatrix, A::LinearMap, s::Number) = _generic_mapnum_mul!(Y, A, s)
267+
_unsafe_mul!(Y::AbstractMatrix, A::LinearMap, s::Number, α::Number, β::Number) =
268+
_generic_mapnum_mul!(Y, A, s, α, β)
200269

201270
function _generic_mapvec_mul!(y, A, x, α, β)
202271
# this function needs to call mul! for, e.g., AdjointMap{...,<:CustomMap}
@@ -226,33 +295,40 @@ function _generic_mapvec_mul!(y, A, x, α, β)
226295
end
227296
end
228297

229-
# the following is of interest in, e.g., subspace-iteration methods
230-
function mul!(Y::AbstractMatrix, A::LinearMap, X::AbstractMatrix)
231-
check_dim_mul(Y, A, X)
232-
return _unsafe_mul!(Y, A, X)
233-
end
234-
function mul!(Y::AbstractMatrix, A::LinearMap, X::AbstractMatrix, α::Number, β::Number)
235-
check_dim_mul(Y, A, X)
236-
return _unsafe_mul!(Y, A, X, α, β)
298+
function _generic_mapmat_mul!(Y, A, X)
299+
for (Xi, Yi) in zip(eachcol(X), eachcol(Y))
300+
mul!(Yi, A, Xi)
301+
end
302+
return Y
237303
end
238-
239-
function _generic_mapmat_mul!(Y, A, X, α=true, β=false)
304+
function _generic_mapmat_mul!(Y, A, X, α, β)
240305
for (Xi, Yi) in zip(eachcol(X), eachcol(Y))
241306
mul!(Yi, A, Xi, α, β)
242307
end
243308
return Y
244309
end
245310

246-
_unsafe_mul!(y, A::MapOrVecOrMat, x) = mul!(y, A, x)
247-
_unsafe_mul!(y, A::AbstractVecOrMat, x, α, β) = mul!(y, A, x, α, β)
248-
function _unsafe_mul!(y::AbstractVecOrMat, A::LinearMap, x::AbstractVector, α, β)
249-
return _generic_mapvec_mul!(y, A, x, α, β)
250-
end
251-
function _unsafe_mul!(y::AbstractMatrix, A::LinearMap, x::AbstractMatrix)
252-
return _generic_mapmat_mul!(y, A, x)
311+
function _generic_mapnum_mul!(Y, A, s)
312+
T = promote_type(eltype(A), typeof(s))
313+
ax2 = axes(A)[2]
314+
xi = zeros(T, ax2)
315+
@inbounds for (i, Yi) in zip(ax2, eachcol(Y))
316+
xi[i] = s
317+
mul!(Yi, A, xi)
318+
xi[i] = zero(T)
319+
end
320+
return Y
253321
end
254-
function _unsafe_mul!(y::AbstractMatrix, A::LinearMap, x::AbstractMatrix, α, β)
255-
return _generic_mapmat_mul!(y, A, x, α, β)
322+
function _generic_mapnum_mul!(Y, A, s, α, β)
323+
T = promote_type(eltype(A), typeof(s))
324+
ax2 = axes(A)[2]
325+
xi = zeros(T, ax2)
326+
@inbounds for (i, Yi) in zip(ax2, eachcol(Y))
327+
xi[i] = s
328+
mul!(Yi, A, xi, α, β)
329+
xi[i] = zero(T)
330+
end
331+
return Y
256332
end
257333

258334
include("left.jl") # left multiplication by a transpose or adjoint vector

0 commit comments

Comments
 (0)