Skip to content

Commit 1f72917

Browse files
jishnubdlfivefifty
andauthored
Fix matrix multiplication with non-commutative elements (#321)
* Fix matrix multiplication with array elements * fix fillmatrix * stridedvec * Fix for adjoint * Tests with non-zero beta * Bump version to v1.9.3 * Fix 5-term mul order and test against Quaternions * Move imports to one line --------- Co-authored-by: Sheehan Olver <[email protected]>
1 parent 373183f commit 1f72917

File tree

3 files changed

+74
-28
lines changed

3 files changed

+74
-28
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ Documenter = "1"
2626
Infinities = "0.1"
2727
LinearAlgebra = "1.6"
2828
PDMats = "0.11.17"
29+
Quaternions = "0.7"
2930
Random = "1.6"
3031
ReverseDiff = "1"
3132
SparseArrays = "1.6"
@@ -40,11 +41,12 @@ Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
4041
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
4142
Infinities = "e1ba4f0e-776d-440f-acd9-e1d2e9742647"
4243
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
44+
Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0"
4345
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
4446
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
4547
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
4648
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
4749
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4850

4951
[targets]
50-
test = ["Aqua", "Test", "Base64", "Infinities", "PDMats", "ReverseDiff", "SparseArrays", "StaticArrays", "Statistics", "Documenter"]
52+
test = ["Aqua", "Test", "Base64", "Infinities", "PDMats", "ReverseDiff", "SparseArrays", "StaticArrays", "Statistics", "Quaternions", "Documenter"]

src/fillalgebra.jl

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -150,60 +150,56 @@ end
150150
function mul!(y::AbstractVector, A::AbstractFillMatrix, b::AbstractFillVector, alpha::Number, beta::Number)
151151
check_matmul_sizes(y, A, b)
152152

153-
αAb = alpha * getindex_value(A) * getindex_value(b) * length(b)
153+
Abα = Ref(getindex_value(A) * getindex_value(b) * alpha * length(b))
154154

155155
if iszero(beta)
156-
y .= αAb
156+
y .= Abα
157157
else
158-
y .= αAb .+ beta .* y
158+
y .= Abα .+ y .* beta
159159
end
160160
y
161161
end
162162

163163
function mul!(y::StridedVector, A::StridedMatrix, b::AbstractFillVector, alpha::Number, beta::Number)
164164
check_matmul_sizes(y, A, b)
165165

166-
αb = alpha * getindex_value(b)
166+
= Ref(getindex_value(b) * alpha)
167167

168168
if iszero(beta)
169-
y .= zero(eltype(y))
170-
for col in eachcol(A)
171-
y .+= αb .* col
172-
end
169+
y .= Ref(zero(eltype(y)))
173170
else
174-
lmul!(beta, y)
175-
for col in eachcol(A)
176-
y .+= αb .* col
177-
end
171+
rmul!(y, beta)
172+
end
173+
for Acol in eachcol(A)
174+
@. y += Acol *
178175
end
179176
y
180177
end
181178

182179
function mul!(y::StridedVector, A::AbstractFillMatrix, b::StridedVector, alpha::Number, beta::Number)
183180
check_matmul_sizes(y, A, b)
184181

185-
αA = alpha * getindex_value(A)
182+
Abα = Ref(getindex_value(A) * sum(b) * alpha)
186183

187184
if iszero(beta)
188-
y .= αA .* sum(b)
185+
y .= Abα
189186
else
190-
y .= αA .* sum(b) .+ beta .* y
187+
y .= Abα .+ y .* beta
191188
end
192189
y
193190
end
194191

195-
function _mul_adjtrans!(y::AbstractVector, A::AbstractMatrix, b::AbstractVector, alpha, beta, f)
196-
α = alpha * getindex_value(b)
197-
192+
function _mul_adjtrans!(y::AbstractVector, A::AbstractMatrix, b::AbstractFillVector, alpha, beta, f)
193+
= getindex_value(b) * alpha
198194
At = f(A)
199195

200196
if iszero(beta)
201-
for (ind, col) in zip(eachindex(y), eachcol(At))
202-
y[ind] = α .* f(sum(col))
197+
for (ind, Atcol) in zip(eachindex(y), eachcol(At))
198+
y[ind] = f(sum(Atcol)) *
203199
end
204200
else
205-
for (ind, col) in zip(eachindex(y), eachcol(At))
206-
y[ind] = α .* f(sum(col)) .+ beta .* y[ind]
201+
for (ind, Atcol) in zip(eachindex(y), eachcol(At))
202+
y[ind] = f(sum(Atcol)) *.+ y[ind] .* beta
207203
end
208204
end
209205
y
@@ -218,11 +214,11 @@ end
218214

219215
function mul!(C::AbstractMatrix, A::AbstractFillMatrix, B::AbstractFillMatrix, alpha::Number, beta::Number)
220216
check_matmul_sizes(C, A, B)
221-
αAB = alpha * getindex_value(A) * getindex_value(B) * size(B,1)
217+
ABα = getindex_value(A) * getindex_value(B) * alpha * size(B,1)
222218
if iszero(beta)
223-
C .= αAB
219+
C .= ABα
224220
else
225-
C .= αAB .+ beta .* C
221+
C .= ABα .+ C .* beta
226222
end
227223
C
228224
end
@@ -248,7 +244,7 @@ _firstcol(C::Union{Adjoint, Transpose}) = view(parent(C), 1, :)
248244
function _mulfill!(C, A, B::AbstractFillMatrix, alpha, beta)
249245
check_matmul_sizes(C, A, B)
250246
if iszero(size(B,2))
251-
return lmul!(beta, C)
247+
return rmul!(C, beta)
252248
end
253249
mul!(_firstcol(C), A, view(B, :, 1), alpha, beta)
254250
copyfirstcol!(C)

test/runtests.jl

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using FillArrays, LinearAlgebra, PDMats, SparseArrays, StaticArrays, ReverseDiff, Random, Base64, Test, Statistics
1+
using FillArrays, LinearAlgebra, PDMats, SparseArrays, StaticArrays, ReverseDiff, Random, Base64, Test, Statistics, Quaternions
22
import FillArrays: AbstractFill, RectDiagonal, SquareEye
33

44
using Documenter
@@ -1558,11 +1558,25 @@ end
15581558
Z = Zeros(SMatrix{2,3,Float64,6}, 2)
15591559
@test Z' * D == Array(Z)' * D
15601560

1561+
S = SMatrix{2,3}(1:6)
1562+
A = reshape([S,2S,3S,4S],2,2)
1563+
F = Fill(S',2,2)
1564+
@test A * F == A * fill(S',size(F))
1565+
@test mul!(A * F, A, F, 2, 1) == 3 * A * fill(S',size(F))
1566+
@test F * A == fill(S',size(F)) * A
1567+
@test mul!(F * A, F, A, 2, 1) == 3 * fill(S',size(F)) * A
1568+
15611569
# doubly nested
15621570
A = [[[1,2]]]'
15631571
Z = Zeros(SMatrix{1,1,SMatrix{2,2,Int,4},1},1)
15641572
Z2 = zeros(SMatrix{1,1,SMatrix{2,2,Int,4},1},1)
15651573
@test A * Z == A * Z2
1574+
1575+
x = [1 2 3; 4 5 6]
1576+
A = reshape([x,2x,3x,4x],2,2)
1577+
F = Fill(x,2,2)
1578+
@test A' * F == A' * fill(x,size(F))
1579+
@test mul!(A' * F, A', F, 2, 1) == 3 * A' * fill(x,size(F))
15661580
end
15671581

15681582
for W in (zeros(3,4), @MMatrix zeros(3,4))
@@ -1697,6 +1711,40 @@ end
16971711
@test adjoint(A)*fillmat adjoint(A)*Array(fillmat)
16981712
end
16991713

1714+
@testset "non-commutative" begin
1715+
A = Fill(quat(rand(4)...), 2, 2)
1716+
M = Array(A)
1717+
α, β = quat(0,1,1,0), quat(1,0,0,1)
1718+
@testset "matvec" begin
1719+
f = Fill(quat(rand(4)...), size(A,2))
1720+
v = Array(f)
1721+
D = copy(v)
1722+
exp_res = M * v * α + D * β
1723+
@test mul!(copy(D), A, f, α, β) mul!(copy(D), M, v, α, β) exp_res
1724+
@test mul!(copy(D), M, f, α, β) mul!(copy(D), M, v, α, β) exp_res
1725+
@test mul!(copy(D), A, v, α, β) mul!(copy(D), M, v, α, β) exp_res
1726+
1727+
@test mul!(copy(D), M', f, α, β) mul!(copy(D), M', v, α, β) M' * v * α + D * β
1728+
@test mul!(copy(D), transpose(M), f, α, β) mul!(copy(D), transpose(M), v, α, β) transpose(M) * v * α + D * β
1729+
end
1730+
1731+
@testset "matmat" begin
1732+
B = Fill(quat(rand(4)...), 2, 2)
1733+
N = Array(B)
1734+
D = copy(N)
1735+
exp_res = M * N * α + D * β
1736+
@test mul!(copy(D), A, B, α, β) mul!(copy(D), M, N, α, β) exp_res
1737+
@test mul!(copy(D), M, B, α, β) mul!(copy(D), M, N, α, β) exp_res
1738+
@test mul!(copy(D), A, N, α, β) mul!(copy(D), M, N, α, β) exp_res
1739+
1740+
@test mul!(copy(D), M', B, α, β) mul!(copy(D), M', N, α, β) M' * N * α + D * β
1741+
@test mul!(copy(D), transpose(M), B, α, β) mul!(copy(D), transpose(M), N, α, β) transpose(M) * N * α + D * β
1742+
1743+
@test mul!(copy(D), A, N', α, β) mul!(copy(D), M, N', α, β) M * N' * α + D * β
1744+
@test mul!(copy(D), A, transpose(N), α, β) mul!(copy(D), M, transpose(N), α, β) M * transpose(N) * α + D * β
1745+
end
1746+
end
1747+
17001748
@testset "ambiguities" begin
17011749
UT33 = UpperTriangular(ones(3,3))
17021750
UT11 = UpperTriangular(ones(1,1))

0 commit comments

Comments
 (0)