Skip to content

Commit bc21d92

Browse files
authored
Fix _unsafe_mul! for custom maps (#110)
1 parent 2dbe5f2 commit bc21d92

File tree

6 files changed

+56
-20
lines changed

6 files changed

+56
-20
lines changed

src/LinearMaps.jl

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,10 @@ julia> A*x
9393
```
9494
"""
9595
function Base.:(*)(A::LinearMap, x::AbstractVector)
96-
size(A, 2) == length(x) || throw(DimensionMismatch("linear map has dimensions ($mA,$nA), " *
97-
"vector has length $mB"))
98-
return _unsafe_mul!(similar(x, promote_type(eltype(A), eltype(x)), size(A, 1)), A, x)
96+
m, n = size(A)
97+
n == length(x) || throw(DimensionMismatch("linear map has dimensions ($m,$n), " *
98+
"vector has length $(length(x))"))
99+
return mul!(similar(x, promote_type(eltype(A), eltype(x)), m), A, x)
99100
end
100101

101102
"""
@@ -163,21 +164,23 @@ function mul!(y::AbstractVecOrMat, A::LinearMap, x::AbstractVector, α::Number,
163164
end
164165

165166
function _generic_mapvec_mul!(y, A, x, α, β)
167+
# this function needs to call mul! for, e.g., AdjointMap{...,<:CustomMap}
166168
if isone(α)
167-
iszero(β) && (_unsafe_mul!(y, A, x); return y)
168-
isone(β) && (y .+= A * x; return y)
169-
# β != 0, 1
169+
iszero(β) && return mul!(y, A, x)
170170
z = A * x
171-
y .= y.*β .+ z
171+
if isone(β)
172+
y .+= z
173+
else
174+
y .= y.*β .+ z
175+
end
172176
return y
173177
elseif iszero(α)
174-
iszero(β) && (fill!(y, zero(eltype(y))); return y)
178+
iszero(β) && return fill!(y, zero(eltype(y)))
175179
isone(β) && return y
176180
# β != 0, 1
177-
rmul!(y, β)
178-
return y
181+
return rmul!(y, β)
179182
else # α != 0, 1
180-
iszero(β) && (_unsafe_mul!(y, A, x); rmul!(y, α); return y)
183+
iszero(β) && return rmul!(mul!(y, A, x), α)
181184
z = A * x
182185
if isone(β)
183186
y .+= z .* α

src/functionmap.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ function Base.:(*)(A::AdjointMap{<:Any,<:FunctionMap}, x::AbstractVector)
7171
conj!(y)
7272
return y
7373
else
74-
error("adjoint not implemented for $A")
74+
error("adjoint not implemented for $(A.lmap)")
7575
end
7676
end
7777
function Base.:(*)(A::TransposeMap{<:Any,<:FunctionMap}, x::AbstractVector)
@@ -103,7 +103,7 @@ function Base.:(*)(A::TransposeMap{<:Any,<:FunctionMap}, x::AbstractVector)
103103
conj!(y)
104104
return y
105105
else
106-
error("transpose not implemented for $A")
106+
error("transpose not implemented for $(A.lmap)")
107107
end
108108
end
109109

@@ -129,7 +129,7 @@ function _unsafe_mul!(y::AbstractVecOrMat, At::TransposeMap{<:Any,<:FunctionMap}
129129
conj!(y)
130130
return y
131131
else
132-
error("transpose not implemented for $A")
132+
error("transpose not implemented for $(A.lmap)")
133133
end
134134
end
135135

@@ -144,6 +144,6 @@ function _unsafe_mul!(y::AbstractVecOrMat, Ac::AdjointMap{<:Any,<:FunctionMap},
144144
conj!(y)
145145
return y
146146
else
147-
error("adjoint not implemented for $A")
147+
error("adjoint not implemented for $(A.lmap)")
148148
end
149149
end

src/transpose.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ Base.:(==)(A::LinearMap, B::AdjointMap) = ishermitian(A) && B.lmap == A
5151
# multiplication with vector/matrices
5252
# # TransposeMap
5353
function _unsafe_mul!(y::AbstractVecOrMat, A::TransposeMap, x::AbstractVector)
54-
issymmetric(A.lmap) ? _unsafe_mul!(y, A.lmap, x) : error("transpose not implemented for $A")
54+
issymmetric(A.lmap) ? _unsafe_mul!(y, A.lmap, x) : error("transpose not implemented for $(A.lmap)")
5555
end
5656
function _unsafe_mul!(y::AbstractMatrix, A::TransposeMap, x::AbstractMatrix)
5757
issymmetric(A.lmap) ? _unsafe_mul!(y, A.lmap, x) : _generic_mapmat_mul!(y, A, x)
@@ -64,7 +64,7 @@ function _unsafe_mul!(y::AbstractMatrix, A::TransposeMap, x::AbstractMatrix, α:
6464
end
6565
# # AdjointMap
6666
function _unsafe_mul!(y::AbstractVecOrMat, A::AdjointMap, x::AbstractVector)
67-
ishermitian(A.lmap) ? _unsafe_mul!(y, A.lmap, x) : error("adjoint not implemented for $A")
67+
ishermitian(A.lmap) ? _unsafe_mul!(y, A.lmap, x) : error("adjoint not implemented for $(A.lmap)")
6868
end
6969
function _unsafe_mul!(y::AbstractMatrix, A::AdjointMap, x::AbstractMatrix)
7070
ishermitian(A.lmap) ? _unsafe_mul!(y, A.lmap, x) : _generic_mapmat_mul!(y, A, x)

src/wrappedmap.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ LinearAlgebra.ishermitian(A::WrappedMap) = A._ishermitian
3737
LinearAlgebra.isposdef(A::WrappedMap) = A._isposdef
3838

3939
# multiplication with vectors & matrices
40-
Base.:(*)(A::WrappedMap, x::AbstractVector) = *(A.lmap, x)
41-
4240
for (intype, outtype) in ((AbstractVector, AbstractVecOrMat), (AbstractMatrix, AbstractMatrix))
4341
@eval begin
4442
_unsafe_mul!(y::$outtype, A::WrappedMap, x::$intype) = _unsafe_mul!(y, A.lmap, x)

test/kronecker.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Test, LinearMaps, LinearAlgebra, SparseArrays
22

3-
@testset "kronecker products" begin
3+
@testset "Kronecker products and sums" begin
44
@testset "Kronecker product" begin
55
A = rand(ComplexF64, 3, 3)
66
B = rand(ComplexF64, 2, 2)
@@ -70,6 +70,7 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays
7070
@test_throws ArgumentError kronsum(LA, [B B]) # non-square map
7171
KSmat = kron(A, Matrix(I, 2, 2)) + kron(Matrix(I, 3, 3), B)
7272
@test Matrix(KS) Matrix(kron(A, LinearMap(I, 2)) + kron(LinearMap(I, 3), B))
73+
@test KS * ones(size(KS, 2)) KSmat * ones(size(KS, 2))
7374
@test size(KS) == size(kron(A, Matrix(I, 2, 2)))
7475
for transform in (identity, transpose, adjoint)
7576
@test Matrix(transform(KS)) transform(Matrix(KS)) transform(KSmat)

test/linearmaps.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,37 @@ LinearAlgebra.mul!(y::AbstractVector, A::Union{SimpleFunctionMap,SimpleComplexFu
9999
@test SparseMatrixCSC(F) == Fs == L
100100
@test Fs isa SparseMatrixCSC
101101
end
102+
103+
struct MyFillMap{T} <: LinearMaps.LinearMap{T}
104+
λ::T
105+
size::Dims{2}
106+
function MyFillMap::T, dims::Dims{2}) where {T}
107+
all(d -> d >= 0, dims) || throw(ArgumentError("dims of MyFillMap must be non-negative"))
108+
promote_type(T, typeof(λ)) == T || throw(InexactError())
109+
return new{T}(λ, dims)
110+
end
111+
end
112+
Base.size(A::MyFillMap) = A.size
113+
function LinearAlgebra.mul!(y::AbstractVecOrMat, A::MyFillMap, x::AbstractVector)
114+
LinearMaps.check_dim_mul(y, A, x)
115+
return fill!(y, iszero(A.λ) ? zero(eltype(y)) : A.λ*sum(x))
116+
end
117+
function LinearAlgebra.mul!(y::AbstractVecOrMat, transA::LinearMaps.TransposeMap{<:Any,<:MyFillMap}, x::AbstractVector)
118+
LinearMaps.check_dim_mul(y, transA, x)
119+
λ = transA.lmap.λ
120+
return fill!(y, iszero(λ) ? zero(eltype(y)) : transpose(λ)*sum(x))
121+
end
122+
123+
@testset "transpose of new LinearMap type" begin
124+
A = MyFillMap(5.0, (3, 3))
125+
x = ones(3)
126+
@test A * x == fill(15.0, 3)
127+
@test mul!(zeros(3), A, x) == mul!(zeros(3), A, x, 1, 0) == fill(15.0, 3)
128+
@test mul!(ones(3), A, x, 2, 2) == fill(32.0, 3)
129+
@test mul!(ones(3,3), A, reshape(collect(1:9), 3, 3), 2, 2)[:,1] == fill(62.0, 3)
130+
@test A'x == mul!(zeros(3), A', x)
131+
for α in (true, false, 0, 1, randn()), β in (true, false, 0, 1, randn())
132+
@test mul!(ones(3), A', x, α, β) == fill(β, 3) + fill(15α, 3)
133+
@test mul!(ones(3, 2), A', [x x], α, β) == fill(β, 3, 2) + fill(15α, 3, 2)
134+
end
135+
end

0 commit comments

Comments
 (0)