Skip to content

Commit 67ad5c4

Browse files
dkarraschJutho
authored andcommitted
UniformScalingMap improvements (#49)
* speed up mul with UniformScalingMap, inherit type from LinearMap * update as per review * make code more symmetric
1 parent 9ecd792 commit 67ad5c4

File tree

1 file changed

+23
-13
lines changed

1 file changed

+23
-13
lines changed

src/uniformscalingmap.jl

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,30 @@ LinearAlgebra.transpose(A::UniformScalingMap) = A
1919
LinearAlgebra.adjoint(A::UniformScalingMap) = UniformScalingMap(conj(A.λ), size(A))
2020

2121
# multiplication with vector
22+
Base.:(*)(A::UniformScalingMap, x::AbstractVector) =
23+
length(x) == A.M ? A.λ * x : throw(DimensionMismatch("A_mul_B!"))
24+
2225
# call of LinearAlgebra.generic_mul! since order of arguments in mul! in stdlib/LinearAlgebra/src/generic.jl
2326
# TODO: either leave it as is or use mul! (and lower bound on version) once fixed in LinearAlgebra
24-
A_mul_B!(y::AbstractVector, A::UniformScalingMap, x::AbstractVector) =
25-
(length(x) == length(y) == A.M ? LinearAlgebra.generic_mul!(y, A.λ, x) : throw(DimensionMismatch("A_mul_B!")))
26-
Base.:(*)(A::UniformScalingMap, x::AbstractVector) = A.λ * x
27-
28-
At_mul_B!(y::AbstractVector, A::UniformScalingMap, x::AbstractVector) =
29-
(length(x) == length(y) == A.M ? LinearAlgebra.generic_mul!(y, A.λ, x) : throw(DimensionMismatch("At_mul_B!")))
30-
31-
Ac_mul_B!(y::AbstractVector, A::UniformScalingMap, x::AbstractVector) =
32-
(length(x) == length(y) == A.M ? LinearAlgebra.generic_mul!(y, conj(A.λ), x) : throw(DimensionMismatch("Ac_mul_B!")))
27+
function A_mul_B!(y::AbstractVector, A::UniformScalingMap, x::AbstractVector)
28+
(length(x) == length(y) == A.M || throw(DimensionMismatch("A_mul_B!")))
29+
if iszero(A.λ)
30+
return fill!(y, 0)
31+
elseif isone(A.λ)
32+
return copyto!(y, x)
33+
else
34+
return LinearAlgebra.generic_mul!(y, A.λ, x)
35+
end
36+
end
37+
At_mul_B!(y::AbstractVector, A::UniformScalingMap, x::AbstractVector) = A_mul_B!(y, transpose(A), x)
38+
Ac_mul_B!(y::AbstractVector, A::UniformScalingMap, x::AbstractVector) = A_mul_B!(y, adjoint(A), x)
3339

3440
# combine LinearMap and UniformScaling objects in linear combinations
35-
Base.:(+)(A1::LinearMap, A2::UniformScaling{T}) where {T} = A1 + UniformScalingMap{T}(A2[1,1], size(A1, 1))
36-
Base.:(+)(A1::UniformScaling{T}, A2::LinearMap) where {T} = UniformScalingMap{T}(A1[1,1], size(A2, 1)) + A2
37-
Base.:(-)(A1::LinearMap, A2::UniformScaling{T}) where {T} = A1 - UniformScalingMap{T}(A2[1,1], size(A1, 1))
38-
Base.:(-)(A1::UniformScaling{T}, A2::LinearMap) where {T} = UniformScalingMap{T}(A1[1,1], size(A2, 1)) - A2
41+
Base.:(+)(A1::LinearMap, A2::UniformScaling) =
42+
A1 + UniformScalingMap(convert(promote_type(eltype(A1), eltype(A2)), A2.λ), size(A1, 1))
43+
Base.:(+)(A1::UniformScaling, A2::LinearMap) =
44+
UniformScalingMap(convert(promote_type(eltype(A1), eltype(A2)), A1.λ), size(A2, 1)) + A2
45+
Base.:(-)(A1::LinearMap, A2::UniformScaling) =
46+
A1 - UniformScalingMap(convert(promote_type(eltype(A1), eltype(A2)), A2.λ), size(A1, 1))
47+
Base.:(-)(A1::UniformScaling, A2::LinearMap) =
48+
UniformScalingMap(convert(promote_type(eltype(A1), eltype(A2)), A1.λ), size(A2, 1)) - A2

0 commit comments

Comments
 (0)