Skip to content

Commit 752a157

Browse files
authored
Improve specific kron case (#139)
1 parent fc5d2aa commit 752a157

File tree

7 files changed

+60
-19
lines changed

7 files changed

+60
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "LinearMaps"
22
uuid = "7a12625a-238d-50fd-b39a-03d52299707e"
3-
version = "3.2.2"
3+
version = "3.2.3"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/kronecker.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,12 @@ end
141141
mb, nb = size(B)
142142
X = reshape(x, (nb, na))
143143
Y = reshape(y, (mb, ma))
144-
if nb*ma < mb*na
144+
if B isa UniformScalingMap
145+
# the following is (maybe due to the reshape?) faster than
146+
# _unsafe_mul!(Y, B * X, At.lmap)
147+
_unsafe_mul!(Y, X, At.lmap)
148+
lmul!(B.λ, y)
149+
elseif nb*ma <= mb*na
145150
_unsafe_mul!(Y, B, X * At.lmap)
146151
else
147152
_unsafe_mul!(Y, Matrix(B*X), At.lmap)

src/uniformscalingmap.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,16 @@ Base.:(*)(α::RealOrComplex, J::UniformScalingMap) = UniformScalingMap(α * J.λ
3939
Base.:(*)(J::UniformScalingMap, α::RealOrComplex) = UniformScalingMap(J.λ * α, size(J))
4040

4141
# multiplication with vector
42-
Base.:(*)(A::UniformScalingMap, x::AbstractVector) =
43-
length(x) == A.M ? A.λ * x : throw(DimensionMismatch("*"))
42+
Base.:(*)(J::UniformScalingMap, x::AbstractVector) =
43+
length(x) == J.M ? J.λ * x : throw(DimensionMismatch("*"))
44+
# multiplication with matrix
45+
Base.:(*)(J::UniformScalingMap, B::AbstractMatrix) =
46+
size(B, 1) == J.M ? J.λ * LinearMap(B) : throw(DimensionMismatch("*"))
47+
Base.:(*)(A::AbstractMatrix, J::UniformScalingMap) =
48+
size(A, 2) == J.M ? LinearMap(A) * J.λ : throw(DimensionMismatch("*"))
49+
# disambiguation
50+
Base.:(*)(xc::LinearAlgebra.AdjointAbsVec, J::UniformScalingMap) = xc * J.λ
51+
Base.:(*)(xt::LinearAlgebra.TransposeAbsVec, J::UniformScalingMap) = xt * J.λ
4452

4553
# multiplication with vector/matrix
4654
for (In, Out) in ((AbstractVector, AbstractVecOrMat), (AbstractMatrix, AbstractMatrix))
@@ -49,6 +57,11 @@ for (In, Out) in ((AbstractVector, AbstractVecOrMat), (AbstractMatrix, AbstractM
4957
_scaling!(y, J.λ, x, true, false)
5058
return y
5159
end
60+
function _unsafe_mul!(y::$Out, J::UniformScalingMap{<:RealOrComplex}, x::$In{<:RealOrComplex},
61+
α::RealOrComplex, β::Number)
62+
_scaling!(y, J.λ * α, x, true, β)
63+
return y
64+
end
5265
function _unsafe_mul!(y::$Out, J::UniformScalingMap, x::$In,
5366
α::Number, β::Number)
5467
_scaling!(y, J.λ, x, α, β)

test/kronecker.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,18 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays
6666
@test Matrix(@inferred K*K) kron(A, B)*kron(A, B)
6767
A = rand(3, 2); B = rand(4, 3)
6868
@test Matrix(kron(LinearMap(A), B, [A A])*kron(LinearMap(A), B, A')) kron(A, B, [A A])*kron(A, B, A')
69+
70+
m = 3
71+
A = rand(m, m)
72+
S = sparse(I, m, m)
73+
J = LinearMap(I, m)
74+
v = rand(m^3)
75+
for (K, M) in (((A, J, J), kron(A, S, S)),
76+
((J, A, J), kron(S, A, S)),
77+
((J, J, A), kron(S, S, A)))
78+
@test K * v M * v
79+
@test Matrix(K) M
80+
end
6981
end
7082

7183
@testset "Kronecker sum" begin

test/left.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,7 @@ end
7070
@test left_tester(W)
7171
@test left_tester(W')
7272
@test left_tester(transpose(W))
73+
74+
J = LinearMap(1.0I, 5) # UniformScalingMap
75+
@test left_tester(J)
7376
end

test/numbertypes.jl

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,30 +35,30 @@ Base.:(+)(q::Quaternion, z::Complex) = q + quat(z)
3535
@test mul!(copy(C), transpose(A), M, γ, λ) transpose(A)*A*γ + C*λ
3636
@test mul!(copy(C), adjoint(A), M, γ, λ) A'*A*γ + C*λ
3737
end
38-
@test Array((α * F')') (α * A')' A * conj(α)
38+
@test Array((α * F')') (γ * A')' A * conj(γ)
3939
@test L * x A * x
4040
@test L' * x A' * x
41-
@test α * (L * x) α * (A * x)
42-
@test α * L * x α * A * x
43-
@test L * α * x A * α * x
41+
@test α * (L * x) γ * (A * x)
42+
@test α * L * x γ * A * x
43+
@test L * α * x A * γ * x
4444
@test 3L * x 3A * x
4545
@test 3L' * x 3A' * x
4646
@test λ*L isa LinearMaps.CompositeMap
4747
@test γ ** LinearMap(B)) isa LinearMaps.CompositeMap
4848
@test* LinearMap(B)) * γ isa LinearMaps.CompositeMap
4949
@test λ*L * x λ*A * x
5050
@test λ*L' * x λ*A' * x
51-
@test α * (3L * x) α * (3A * x)
52-
@test (@inferred α * 3L) * x α * 3A * x
53-
@test (@inferred 3L * α) * x 3A * α * x
54-
@test* L') * x (α * A') * x
55-
@test* L')' * x (α * A')' * x
56-
@test* L')' * v (α * A')' * v
57-
@test Array(@inferred adjoint* L * β)) conj(β) * A' * conj(α)
58-
@test Array(@inferred transpose* L * β)) β * transpose(A) * α
51+
@test α * (3L * x) γ * (3A * x)
52+
@test (@inferred α * 3L) * x γ * 3A * x
53+
@test (@inferred 3L * α) * x 3A * γ * x
54+
@test* L') * x (γ * A') * x
55+
@test* L')' * x (γ * A')' * x
56+
@test* L')' * v (γ * A')' * v
57+
@test Array(@inferred adjoint* L * β)) conj(β) * A' * conj(γ)
58+
@test Array(@inferred transpose* L * β)) β * transpose(A) * γ
5959
J = LinearMap(α, 10)
60-
@test* J) * x LinearMap*α, 10) * x β*α*x
61-
@test (J * β) * x LinearMap*β, 10) * x α*β*x
60+
@test* J) * x LinearMap*α, 10) * x β*γ*x
61+
@test (J * β) * x LinearMap*β, 10) * x γ*β*x
6262
M = β.λ ** L * L)
6363
@test M == β ** L * L)
6464
@test length(M.maps) == 3
@@ -71,7 +71,11 @@ Base.:(+)(q::Quaternion, z::Complex) = q + quat(z)
7171
@test length(M.maps) == 3
7272
@test M.maps[1].λ == γ*β.λ
7373
@test γ*FillMap(γ, (3, 4)) == FillMap^2, (3, 4)) == FillMap(γ, (3, 4))*γ
74-
74+
U = LinearMap(quat(1.0)*I, 10)
75+
for β in (0, 1, rand())
76+
@test mul!(copy(x), J, x, γ, β) == γ * x * γ + x * β
77+
@test mul!(copy(x), U, x, γ, β) == x * γ + x * β
78+
end
7579
# exercise non-RealOrComplex scalar operations
7680
@test Array* (L'*L)) γ * (A'*A) # CompositeMap
7781
@test Array((L'*L) * γ) (A'*A) * γ

test/uniformscalingmap.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,8 @@ using Test, LinearMaps, LinearAlgebra, BenchmarkTools
4848
end
4949
X = rand(10, 10); Y = similar(X)
5050
@test mul!(Y, Id, X) == X
51+
@test Id*X isa LinearMap
52+
@test X*Id isa LinearMap
53+
@test Matrix(Id*X) == X
54+
@test Matrix(X*Id) == X
5155
end

0 commit comments

Comments
 (0)