Skip to content

Commit 209d6e2

Browse files
authored
Reduce UniformScalingMaps under Kronecker products, perf improvements (#142)
1 parent b53903d commit 209d6e2

File tree

4 files changed

+42
-25
lines changed

4 files changed

+42
-25
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "LinearMaps"
22
uuid = "7a12625a-238d-50fd-b39a-03d52299707e"
3-
version = "3.2.3"
3+
version = "3.2.4"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/kronecker.jl

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ Base.kron(A::KroneckerMap, B::KroneckerMap) =
5656
Base.kron(A::ScaledMap, B::LinearMap) = A.λ * kron(A.lmap, B)
5757
Base.kron(A::LinearMap, B::ScaledMap) = kron(A, B.lmap) * B.λ
5858
Base.kron(A::ScaledMap, B::ScaledMap) = (A.λ * B.λ) * kron(A.lmap, B.lmap)
59+
# reduce UniformScalingMaps
60+
Base.kron(A::UniformScalingMap, B::UniformScalingMap) = UniformScalingMap(A.λ * B.λ, A.M * B.M)
5961
# disambiguation
6062
Base.kron(A::ScaledMap, B::KroneckerMap) = A.λ * kron(A.lmap, B)
6163
Base.kron(A::KroneckerMap, B::ScaledMap) = kron(A, B.lmap) * B.λ
@@ -112,44 +114,46 @@ Base.:(==)(A::KroneckerMap, B::KroneckerMap) = (eltype(A) == eltype(B) && A.maps
112114
# multiplication helper functions
113115
#################
114116

115-
@inline function _kronmul!(y, B, x, At, T)
116-
na, ma = size(At)
117+
@inline function _kronmul!(y, B, x, A, T)
118+
ma, na = size(A)
117119
mb, nb = size(B)
118120
X = reshape(x, (nb, na))
119-
v = zeros(T, ma)
120-
temp1 = similar(y, na)
121-
temp2 = similar(y, nb)
122-
@views @inbounds for i in 1:ma
123-
v[i] = one(T)
124-
_unsafe_mul!(temp1, At, v)
125-
_unsafe_mul!(temp2, X, temp1)
126-
_unsafe_mul!(y[((i-1)*mb+1):i*mb], B, temp2)
127-
v[i] = zero(T)
121+
Y = reshape(y, (mb, ma))
122+
if B isa UniformScalingMap
123+
_unsafe_mul!(Y, X, transpose(A))
124+
lmul!(B.λ, y)
125+
else
126+
temp = similar(Y, (ma, nb))
127+
_unsafe_mul!(temp, A, copy(transpose(X)))
128+
_unsafe_mul!(Y, B, transpose(temp))
128129
end
129130
return y
130131
end
131-
@inline function _kronmul!(y, B, x, At::UniformScalingMap, _)
132-
na, ma = size(At)
132+
@inline function _kronmul!(y, B, x, A::UniformScalingMap, _)
133+
ma, na = size(A)
133134
mb, nb = size(B)
135+
iszero(A.λ) && return fill!(y, zero(eltype(y)))
134136
X = reshape(x, (nb, na))
135137
Y = reshape(y, (mb, ma))
136-
_unsafe_mul!(Y, B, X, At.λ, false)
138+
_unsafe_mul!(Y, B, X)
139+
!isone(A.λ) && rmul!(y, A.λ)
137140
return y
138141
end
139-
@inline function _kronmul!(y, B, x, At::MatrixMap, _)
140-
na, ma = size(At)
142+
@inline function _kronmul!(y, B, x, A::MatrixMap, _)
143+
ma, na = size(A)
141144
mb, nb = size(B)
142145
X = reshape(x, (nb, na))
143146
Y = reshape(y, (mb, ma))
147+
At = transpose(A.lmap)
144148
if B isa UniformScalingMap
145-
# the following is (maybe due to the reshape?) faster than
146-
# _unsafe_mul!(Y, B * X, At.lmap)
147-
_unsafe_mul!(Y, X, At.lmap)
149+
# the following is (perhaps due to the reshape?) faster than
150+
# _unsafe_mul!(Y, B * X, At)
151+
_unsafe_mul!(Y, X, At)
148152
lmul!(B.λ, y)
149153
elseif nb*ma <= mb*na
150-
_unsafe_mul!(Y, B, X * At.lmap)
154+
_unsafe_mul!(Y, B, X * At)
151155
else
152-
_unsafe_mul!(Y, Matrix(B*X), At.lmap)
156+
_unsafe_mul!(Y, Matrix(B*X), At)
153157
end
154158
return y
155159
end
@@ -163,14 +167,14 @@ const KroneckerMap2{T} = KroneckerMap{T, <:Tuple{LinearMap, LinearMap}}
163167
function _unsafe_mul!(y::AbstractVecOrMat, L::KroneckerMap2, x::AbstractVector)
164168
require_one_based_indexing(y)
165169
A, B = L.maps
166-
_kronmul!(y, B, x, transpose(A), eltype(L))
170+
_kronmul!(y, B, x, A, eltype(L))
167171
return y
168172
end
169173
function _unsafe_mul!(y::AbstractVecOrMat, L::KroneckerMap, x::AbstractVector)
170174
require_one_based_indexing(y)
171175
A = first(L.maps)
172176
B = kron(Base.tail(L.maps)...)
173-
_kronmul!(y, B, x, transpose(A), eltype(L))
177+
_kronmul!(y, B, x, A, eltype(L))
174178
return y
175179
end
176180
# mixed-product rule, prefer the right if possible

test/composition.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,4 +124,12 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays
124124
@test u1 == u2
125125
@test w1 == w2
126126
end
127+
L1 = LinearMap(rand(2,3))
128+
L2 = LinearMap(rand(4,2))
129+
L3 = LinearMap(rand(3, 4))
130+
L4 = LinearMap(rand(5, 3))
131+
Ls = L4*L3*L2*L1
132+
X = rand(size(Ls, 2), 10)
133+
Y = similar(X, (size(Ls, 1), size(X, 2)))
134+
@test mul!(Y, Ls, X) ≈ L4.lmap * L3.lmap * L2.lmap * L1.lmap * X
127135
end

test/kronecker.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,17 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays
6969

7070
m = 3
7171
A = rand(m, m)
72+
F = LinearMap(x -> A*x, m, m)
7273
S = sparse(I, m, m)
7374
J = LinearMap(I, m)
75+
@test kron(J, J) == LinearMap(I, m*m)
7476
v = rand(m^3)
7577
for (K, M) in ((⊗(A, J, J), kron(A, S, S)),
7678
(⊗(J, A, J), kron(S, A, S)),
77-
(⊗(J, J, A), kron(S, S, A)))
79+
(⊗(J, J, A), kron(S, S, A)),
80+
(⊗(F, J, J), kron(A, S, S)),
81+
(⊗(J, F, J), kron(S, A, S)),
82+
(⊗(J, J, F), kron(S, S, A)))
7883
@test K * v ≈ M * v
7984
@test Matrix(K) ≈ M
8085
end

0 commit comments

Comments
 (0)