Skip to content

Commit 8f5ed36

Browse files
authored
Match LinearAlgebra's muladd order (#189)
* Match LinearAlgebra's muladd order * Bump version to 1.4.4 * Don't add _fill_rmul * Add non-BLAS muladd tests with mixed types * Add larger array tests * Tests for _fill_lmul! * diagonal tests * Fix test on v1.6 * Fill MulAdd * Don't use inplace setindex * relax allocations test * Non-commutative tests
1 parent 74af367 commit 8f5ed36

File tree

5 files changed

+164
-29
lines changed

5 files changed

+164
-29
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ArrayLayouts"
22
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
33
authors = ["Sheehan Olver <[email protected]>"]
4-
version = "1.4.3"
4+
version = "1.4.4"
55

66
[deps]
77
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
@@ -18,6 +18,7 @@ ArrayLayoutsSparseArraysExt = "SparseArrays"
1818
Aqua = "0.8"
1919
FillArrays = "1.2.1"
2020
LinearAlgebra = "1.6"
21+
Quaternions = "0.7"
2122
Random = "1.6"
2223
SparseArrays = "1.6"
2324
StableRNGs = "1"
@@ -26,10 +27,11 @@ julia = "1.6"
2627

2728
[extras]
2829
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
30+
Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0"
2931
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3032
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3133
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
3234
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3335

3436
[targets]
35-
test = ["Aqua", "Random", "StableRNGs", "SparseArrays", "Test"]
37+
test = ["Aqua", "Random", "StableRNGs", "SparseArrays", "Test", "Quaternions"]

src/ArrayLayouts.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ LinearAlgebra.norm(A::LayoutArray, p::Real=2) = _norm(MemoryLayout(A), A, p)
299299
LinearAlgebra.norm(A::SubArray{<:Any,N,<:LayoutArray}, p::Real=2) where N = _norm(MemoryLayout(A), A, p)
300300

301301

302-
_fill_lmul!(β, A::AbstractArray{T}) where T = iszero(β) ? zero!(A) : lmul!(β, A)
302+
_fill_lmul!(β, A::AbstractArray) = iszero(β) ? zero!(A) : lmul!(β, A)
303303

304304

305305
# Elementary reflection similar to LAPACK. The reflector is not Hermitian but

src/muladd.jl

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ function tiled_blasmul!(tile_size, α, A::AbstractMatrix{T}, B::AbstractMatrix{S
9494
size(C) == (mA, nB) || throw(DimensionMismatch("Dimensions must match"))
9595

9696
(iszero(mA) || iszero(nB)) && return C
97-
iszero(nA) && return lmul!(β, C)
97+
iszero(nA) && return rmul!(C, β)
9898

9999
@inbounds begin
100100
sz = (tile_size, tile_size)
@@ -116,7 +116,7 @@ function tiled_blasmul!(tile_size, α, A::AbstractMatrix{T}, B::AbstractMatrix{S
116116
for k = 1:nA
117117
s += Atile[aoff+k] * Btile[boff+k]
118118
end
119-
C[i,j] = α*s + β*C[i,j]
119+
C[i,j] = s * α + C[i,j] * β
120120
end
121121
end
122122
else
@@ -142,7 +142,7 @@ function tiled_blasmul!(tile_size, α, A::AbstractMatrix{T}, B::AbstractMatrix{S
142142
for k = 1:klen
143143
s += Atile[aoff+k] * Btile[bcoff+k]
144144
end
145-
Ctile[bcoff+i] += α*s
145+
Ctile[bcoff+i] += s * α
146146
end
147147
end
148148
end
@@ -161,7 +161,7 @@ end
161161
@simd for ν = rowsupport(A,k) ∩ colsupport(B,j)
162162
Ctmp = @inbounds muladd(A[k, ν],B[ν, j],Ctmp)
163163
end
164-
@inbounds C[k,j] = muladd(α,Ctmp, C[k,j])
164+
@inbounds C[k,j] = muladd(Ctmp, α, C[k,j])
165165
end
166166

167167
function default_blasmul!(α, A::AbstractMatrix, B::AbstractMatrix, β, C::AbstractMatrix)
@@ -170,7 +170,7 @@ function default_blasmul!(α, A::AbstractMatrix, B::AbstractMatrix, β, C::Abstr
170170
nA == mB || throw(DimensionMismatch("Dimensions must match"))
171171
size(C) == (mA, nB) || throw(DimensionMismatch("Dimensions must match"))
172172

173-
lmul!(β, C)
173+
rmul!(C, β)
174174

175175
(iszero(mA) || iszero(nB)) && return C
176176
iszero(nA) && return C
@@ -196,7 +196,7 @@ function default_blasmul!(α, A::AbstractVector, B::AbstractMatrix, β, C::Abstr
196196
1 == mB || throw(DimensionMismatch("Dimensions must match"))
197197
size(C) == (mA, nB) || throw(DimensionMismatch("Dimensions must match"))
198198

199-
lmul!(β, C)
199+
rmul!(C, β)
200200

201201
(iszero(mA) || iszero(nB)) && return C
202202

@@ -213,7 +213,7 @@ function _default_blasmul!(::IndexLinear, α, A::AbstractMatrix, B::AbstractVect
213213
nA == mB || throw(DimensionMismatch("Dimensions must match"))
214214
length(C) == mA || throw(DimensionMismatch("Dimensions must match"))
215215

216-
lmul!(β, C)
216+
rmul!(C, β)
217217
(nA == 0 || mB == 0) && return C
218218

219219
z = zero(A[1]*B[1] + A[1]*B[1])
@@ -223,7 +223,7 @@ function _default_blasmul!(::IndexLinear, α, A::AbstractMatrix, B::AbstractVect
223223
aoffs = (k-1)*Astride
224224
b = B[k]
225225
for i = 1:mA
226-
C[i] += α * A[aoffs + i] * b
226+
C[i] += A[aoffs + i] * b * α
227227
end
228228
end
229229

@@ -236,13 +236,13 @@ function _default_blasmul!(::IndexCartesian, α, A::AbstractMatrix, B::AbstractV
236236
nA == mB || throw(DimensionMismatch("Dimensions must match"))
237237
length(C) == mA || throw(DimensionMismatch("Dimensions must match"))
238238

239-
lmul!(β, C)
239+
rmul!(C, β)
240240
(nA == 0 || mB == 0) && return C
241241

242242
z = zero(A[1,1]*B[1] + A[1,1]*B[1])
243243

244244
@inbounds for k in colsupport(B,1)
245-
b = α * B[k]
245+
b = B[k] * α
246246
for i = colsupport(A,k)
247247
C[i] += A[i,k] * b
248248
end
@@ -386,16 +386,29 @@ similar(M::MulAdd{<:DiagonalLayout,<:DiagonalLayout}, ::Type{T}, axes) where T =
386386
similar(M::MulAdd{<:DiagonalLayout}, ::Type{T}, axes) where T = similar(M.B, T, axes)
387387
similar(M::MulAdd{<:Any,<:DiagonalLayout}, ::Type{T}, axes) where T = similar(M.A, T, axes)
388388
# equivalent to rescaling
389-
function materialize!(M::MulAdd{<:DiagonalLayout{<:AbstractFillLayout}})
390-
checkdimensions(M)
391-
M.C .= (M.α * getindex_value(M.A.diag)) .* M.B .+ M.β .* M.C
392-
M.C
389+
for MatMulT in (:MatMulMatAdd, :MatMulVecAdd, :MulAdd)
390+
@eval function materialize!(M::$MatMulT{<:DiagonalLayout{<:AbstractFillLayout}})
391+
checkdimensions(M)
392+
if iszero(M.β)
393+
M.C .= Ref(getindex_value(M.A.diag)) .* M.B .* M.α
394+
else
395+
M.C .= Ref(getindex_value(M.A.diag)) .* M.B .* M.α .+ M.C .* M.β
396+
end
397+
M.C
398+
end
393399
end
394400

395-
function materialize!(M::MulAdd{<:Any,<:DiagonalLayout{<:AbstractFillLayout}})
396-
checkdimensions(M)
397-
M.C .= M.α .* M.A .* getindex_value(M.B.diag) .+ M.β .* M.C
398-
M.C
401+
for MatMulT in (:MulAdd, :VecMulMatAdd)
402+
@eval function materialize!(M::$MatMulT{<:Any,<:DiagonalLayout{<:AbstractFillLayout}})
403+
checkdimensions(M)
404+
Bα = Ref(getindex_value(M.B.diag) * M.α)
405+
if iszero(M.β)
406+
M.C .= M.A .* Bα
407+
else
408+
M.C .= M.A .* Bα .+ M.C .* M.β
409+
end
410+
M.C
411+
end
399412
end
400413

401414

@@ -432,7 +445,13 @@ mulzeros(::Type{T}, M) where T<:AbstractArray = _mulzeros!(similar(Array{T}, axe
432445
# Fill
433446
###
434447

435-
copy(M::MulAdd{<:AbstractFillLayout,<:AbstractFillLayout,<:AbstractFillLayout}) = M.α*M.A*M.B + M.β*M.C
448+
function copy(M::MulAdd{<:AbstractFillLayout,<:AbstractFillLayout,<:AbstractFillLayout})
449+
if iszero(M.β)
450+
M.A * M.B * M.α
451+
else
452+
M.A * M.B * M.α + M.C * M.β
453+
end
454+
end
436455

437456
###
438457
# DualLayout

test/test_muladd.jl

Lines changed: 114 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ArrayLayouts, FillArrays, Random, StableRNGs, LinearAlgebra, Test
1+
using ArrayLayouts, FillArrays, Random, StableRNGs, LinearAlgebra, Test, Quaternions
22
using ArrayLayouts: DenseColumnMajor, AbstractStridedLayout, AbstractColumnMajor, DiagonalLayout, mul, Mul, zero!
33

44
Random.seed!(0)
@@ -89,6 +89,23 @@ Random.seed!(0)
8989
@test mul(A,X) == A*X
9090
@test mul(X,A) == X*A
9191
end
92+
93+
@testset "Diagonal Fill" begin
94+
for (A, B) in (([1:4;], [3:6;]), (reshape([1:16;],4,4), reshape(2 .* [1:16;],4,4)))
95+
D = Diagonal(Fill(3, 4))
96+
M = MulAdd(2, D, A, 3, B)
97+
@test copy(M) == mul!(B, D, A, 2, 3)
98+
M = MulAdd(1, D, A, 0, B)
99+
@test copy(M) == mul!(B, D, A)
100+
end
101+
102+
A, B = [1:4;], reshape([3:6;], 4, 1)
103+
D = Diagonal(Fill(3, 1))
104+
M = MulAdd(2, A, D, 3, B)
105+
@test copy(M) == (VERSION >= v"1.9" ? mul!(B, A, D, 2, 3) : 2 * A * D + 3 * B)
106+
M = MulAdd(1, A, D, 0, B)
107+
@test copy(M) == (VERSION >= v"1.9" ? mul!(B, A, D) : A * D)
108+
end
92109
end
93110

94111
@testset "Matrix * Matrix" begin
@@ -98,17 +115,28 @@ Random.seed!(0)
98115
B in (randn(5,5), view(randn(5,5),:,:), view(randn(5,5),1:5,:),
99116
view(randn(5,5),1:5,1:5), view(randn(5,5),:,1:5))
100117
C = similar(B);
118+
D = similar(C);
101119

102120
C .= MulAdd(1.0,A,B,0.0,C)
103-
@test C == BLAS.gemm!('N', 'N', 1.0, A, B, 0.0, similar(C))
121+
@test C == BLAS.gemm!('N', 'N', 1.0, A, B, 0.0, D)
104122

105123
C .= MulAdd(2.0,A,B,0.0,C)
106-
@test C == BLAS.gemm!('N', 'N', 2.0, A, B, 0.0, similar(C))
124+
@test C == BLAS.gemm!('N', 'N', 2.0, A, B, 0.0, D)
107125

108126
C = copy(B)
109127
C .= MulAdd(2.0,A,B,1.0,C)
110128
@test C == BLAS.gemm!('N', 'N', 2.0, A, B, 1.0, copy(B))
111129
end
130+
131+
A, B = ones(100, 100), ones(100, 100)
132+
C = ones(100, 100)
133+
C .= MulAdd(2,A,B,1,C)
134+
@test C ≈ BLAS.gemm!('N', 'N', 2.0, A, B, 1.0, copy(B))
135+
136+
A, B = Float64[i+j for i in 1:100, j in 1:100], Float64[i+j for i in 1:100, j in 1:100]
137+
C = ones(100, 100)
138+
C .= MulAdd(2,A,B,1,C)
139+
@test_broken C ≈ BLAS.gemm!('N', 'N', 2.0, A, B, 1.0, copy(B))
112140
end
113141

114142
@testset "gemm Complex" begin
@@ -276,7 +304,8 @@ Random.seed!(0)
276304
vx = view(x,1:2)
277305
vy = view(y,:)
278306
muladd!(2.0, VA, vx, 3.0, vy)
279-
@test @allocated(muladd!(2.0, VA, vx, 3.0, vy)) == 0
307+
# spurious allocations in tests
308+
@test @allocated(muladd!(2.0, VA, vx, 3.0, vy)) < 100
280309
end
281310

282311
@testset "BigFloat" begin
@@ -680,7 +709,7 @@ Random.seed!(0)
680709
b = randn(5)
681710
c = randn(5) + im*randn(5)
682711
d = randn(5) + im*randn(5)
683-
712+
684713
@test ArrayLayouts.dot(a,b) ≈ ArrayLayouts.dotu(a,b) ≈ mul(a',b)
685714
@test ArrayLayouts.dot(a,b) ≈ dot(a,b)
686715
@test eltype(Dot(a,1:5)) == Float64
@@ -693,7 +722,7 @@ Random.seed!(0)
693722
@test ArrayLayouts.dot(c,b) == mul(c',b)
694723
@test ArrayLayouts.dotu(c,b) == mul(transpose(c),b)
695724
@test ArrayLayouts.dot(c,b) ≈ dot(c,b)
696-
725+
697726
@test ArrayLayouts.dot(a,d) == mul(a',d)
698727
@test ArrayLayouts.dotu(a,d) == mul(transpose(a),d)
699728
@test ArrayLayouts.dot(a,d) ≈ dot(a,d)
@@ -730,9 +759,88 @@ Random.seed!(0)
730759
X = randn(rng, ComplexF64, 8, 4)
731760
Y = randn(rng, 8, 2)
732761
@test mul(Y',X) ≈ Y'X
762+
763+
for A in (randn(5,5), view(randn(5,5),:,:), view(randn(5,5),1:5,:),
764+
view(randn(5,5),1:5,1:5), view(randn(5,5),:,1:5)),
765+
B in (randn(5,5), view(randn(5,5),:,:), view(randn(5,5),1:5,:),
766+
view(randn(5,5),1:5,1:5), view(randn(5,5),:,1:5))
767+
C = similar(B);
768+
D = similar(C);
769+
770+
C .= MulAdd(1,A,B,0,C)
771+
@test C ≈ BLAS.gemm!('N', 'N', 1.0, A, B, 0.0, D)
772+
773+
C = copy(B)
774+
C .= MulAdd(2,A,B,1,C)
775+
@test C ≈ BLAS.gemm!('N', 'N', 2.0, A, B, 1.0, copy(B))
776+
end
733777
end
734778

735779
@testset "Vec * Adj" begin
736780
@test ArrayLayouts.mul(1:5, (1:4)') == (1:5) * (1:4)'
737781
end
782+
783+
@testset "Fill" begin
784+
mutable struct MFillMat{T} <: FillArrays.AbstractFill{T,2,NTuple{2,Base.OneTo{Int}}}
785+
x :: T
786+
sz :: NTuple{2,Int}
787+
end
788+
MFillMat(x::T, sz::NTuple{2,Int}) where {T} = MFillMat{T}(x, sz)
789+
MFillMat(x::T, sz::Vararg{Int,2}) where {T} = MFillMat{T}(x, sz)
790+
Base.size(M::MFillMat) = M.sz
791+
FillArrays.getindex_value(M::MFillMat) = M.x
792+
Base.copyto!(M::MFillMat, A::Broadcast.Broadcasted) = (M.x = only(unique(A)); M)
793+
Base.copyto!(M::MFillMat, A::Broadcast.Broadcasted{<:Base.Broadcast.AbstractArrayStyle{0}}) = (M.x = only(unique(A)); M)
794+
795+
M = MulAdd(1, Fill(2,4,4), Fill(3,4,4), 2, MFillMat(2,4,4))
796+
X = copy(M)
797+
@test X == Fill(28,4,4)
798+
799+
M = MulAdd(1, Fill(2,4,4), Fill(3,4,4), 0, MFillMat(2,4,4))
800+
X = copy(M)
801+
@test X == Fill(24,4,4)
802+
end
803+
804+
@testset "non-commutative" begin
805+
A = [quat(rand(4)...) for i in 1:4, j in 1:4]
806+
B = [quat(rand(4)...) for i in 1:4, j in 1:4]
807+
C = [quat(rand(4)...) for i in 1:4, j in 1:4]
808+
α, β = quat(0,0,0,1), quat(0,1,0,0)
809+
M = MulAdd(α, A, B, β, C)
810+
@test copy(M) ≈ mul!(copy(C), A, B, α, β) ≈ A * B * α + C * β
811+
812+
SA = Symmetric(A)
813+
M = MulAdd(α, SA, B, β, C)
814+
@test copy(M) ≈ mul!(copy(C), SA, B, α, β) ≈ SA * B * α + C * β
815+
816+
B = [quat(rand(4)...) for i in 1:4]
817+
C = [quat(rand(4)...) for i in 1:4]
818+
M = MulAdd(α, A, B, β, C)
819+
@test copy(M) ≈ mul!(copy(C), A, B, α, β) ≈ A * B * α + C * β
820+
821+
M = MulAdd(α, SA, B, β, C)
822+
@test copy(M) ≈ mul!(copy(C), SA, B, α, β) ≈ SA * B * α + C * β
823+
824+
A = [quat(rand(4)...) for i in 1:4]
825+
B = [quat(rand(4)...) for i in 1:1, j in 1:1]
826+
C = [quat(rand(4)...) for i in 1:4, j in 1:1]
827+
M = MulAdd(α, A, B, β, C)
828+
@test copy(M) ≈ mul!(copy(C), A, B, α, β) ≈ A * B * α + C * β
829+
830+
D = Diagonal(Fill(quat(rand(4)...), 4))
831+
b = [quat(rand(4)...) for i in 1:4]
832+
c = [quat(rand(4)...) for i in 1:4]
833+
M = MulAdd(α, D, b, β, c)
834+
@test copy(M) ≈ mul!(copy(c), D, b, α, β) ≈ D * b * α + c * β
835+
836+
D = Diagonal(Fill(quat(rand(4)...), 1))
837+
b = [quat(rand(4)...) for i in 1:4]
838+
c = [quat(rand(4)...) for i in 1:4, j in 1:1]
839+
M = MulAdd(α, b, D, β, c)
840+
if VERSION >= v"1.9"
841+
@test copy(M) ≈ mul!(copy(c), b, D, α, β) ≈ b * D * α + c * β
842+
else
843+
@test copy(M) ≈ b * D * α + c * β
844+
end
845+
end
738846
end

test/test_utils.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,5 +34,11 @@ using ArrayLayouts, LinearAlgebra, FillArrays, Test
3434
@test ArrayLayouts._copy_oftype(v, S) !== v
3535
end
3636
end
37+
38+
A = [1 3; 2 4]
39+
ArrayLayouts._fill_lmul!(2.0, A)
40+
@test A == 2 * [1 3; 2 4]
41+
ArrayLayouts._fill_lmul!(0, A)
42+
@test all(==(0), A)
3743
end
38-
end
44+
end

0 commit comments

Comments
 (0)