Skip to content

Commit a81daf1

Browse files
authored
Matrix multiplication with array-type elements (#280)
* Matrix multiplication with array-type elements * tests for Vector{Any} and use oneton instead of randn * avoid 2D contatenation * Diagonal test with different-sized blocks * test with higher nesting * zero element test with arrays
1 parent be9386c commit a81daf1

File tree

2 files changed

+136
-38
lines changed

2 files changed

+136
-38
lines changed

src/fillalgebra.jl

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,10 @@ function _mult_fill(a::AbstractFill, b::AbstractFill, ax, ::Type{Fill})
3636
return Fill(val, ax)
3737
end
3838

39-
function _mult_fill(a, b, ax, ::Type{OnesZeros}) where {OnesZeros}
40-
ElType = promote_type(eltype(a), eltype(b))
39+
function _mult_fill(a, b, ax, ::Type{OnesZeros}) where {OnesZeros<:Union{Ones,Zeros}}
40+
# This is currently only used in contexts where zero is defined
41+
# might need a rethink
42+
ElType = typeof(zero(eltype(a)) * zero(eltype(b)))
4143
return OnesZeros{ElType}(ax)
4244
end
4345

@@ -48,48 +50,49 @@ function mult_fill(a, b, T::Type = Fill)
4850
ax_result = (axes(a, 1), axes(b)[2:end]...)
4951
_mult_fill(a, b, ax_result, T)
5052
end
51-
mult_zeros(a, b) = mult_fill(a, b, Zeros)
53+
# for arrays of numbers, we assume that zero is defined for the result
54+
# in this case, we may express the result as a Zeros
55+
mult_zeros(a::AbstractArray{<:Number}, b::AbstractArray{<:Number}) = mult_fill(a, b, Zeros)
56+
# In general, we create a Fill that doesn't assume anything about the
57+
# properties of the element type
58+
mult_zeros(a, b) = mult_fill(a, b, Fill)
5259
mult_ones(a, b) = mult_fill(a, b, Ones)
5360

54-
*(a::AbstractFillVector, b::AbstractFillMatrix) = mult_fill(a,b)
5561
*(a::AbstractFillMatrix, b::AbstractFillMatrix) = mult_fill(a,b)
5662
*(a::AbstractFillMatrix, b::AbstractFillVector) = mult_fill(a,b)
5763

64+
# this treats a size (n,) vector as a nx1 matrix, so b needs to have 1 row
65+
# special cased, as OnesMatrix * OnesMatrix isn't a Ones
5866
*(a::OnesVector, b::OnesMatrix) = mult_ones(a, b)
5967

60-
*(a::ZerosVector, b::ZerosMatrix) = mult_zeros(a, b)
6168
*(a::ZerosMatrix, b::ZerosMatrix) = mult_zeros(a, b)
6269
*(a::ZerosMatrix, b::ZerosVector) = mult_zeros(a, b)
6370

64-
*(a::ZerosVector, b::AbstractFillMatrix) = mult_zeros(a, b)
6571
*(a::ZerosMatrix, b::AbstractFillMatrix) = mult_zeros(a, b)
6672
*(a::ZerosMatrix, b::AbstractFillVector) = mult_zeros(a, b)
67-
*(a::AbstractFillVector, b::ZerosMatrix) = mult_zeros(a, b)
6873
*(a::AbstractFillMatrix, b::ZerosMatrix) = mult_zeros(a, b)
6974
*(a::AbstractFillMatrix, b::ZerosVector) = mult_zeros(a, b)
7075

71-
*(a::ZerosVector, b::AbstractMatrix) = mult_zeros(a, b)
7276
*(a::ZerosMatrix, b::AbstractMatrix) = mult_zeros(a, b)
7377
*(a::AbstractMatrix, b::ZerosVector) = mult_zeros(a, b)
7478
*(a::AbstractMatrix, b::ZerosMatrix) = mult_zeros(a, b)
7579
*(a::ZerosMatrix, b::AbstractVector) = mult_zeros(a, b)
76-
*(a::AbstractVector, b::ZerosMatrix) = mult_zeros(a, b)
7780

78-
*(a::ZerosVector, b::AdjOrTransAbsVec) = mult_zeros(a, b)
79-
80-
*(a::ZerosVector, b::Diagonal) = mult_zeros(a, b)
81-
*(a::ZerosMatrix, b::Diagonal) = mult_zeros(a, b)
82-
*(a::Diagonal, b::ZerosVector) = mult_zeros(a, b)
83-
*(a::Diagonal, b::ZerosMatrix) = mult_zeros(a, b)
84-
function *(a::Diagonal, b::AbstractFillMatrix)
81+
function lmul_diag(a::Diagonal, b)
8582
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
8683
parent(a) .* b # use special broadcast
8784
end
88-
function *(a::AbstractFillMatrix, b::Diagonal)
85+
function rmul_diag(a, b::Diagonal)
8986
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
9087
a .* permutedims(parent(b)) # use special broadcast
9188
end
9289

90+
*(a::ZerosMatrix, b::Diagonal) = rmul_diag(a, b)
91+
*(a::Diagonal, b::ZerosVector) = lmul_diag(a, b)
92+
*(a::Diagonal, b::ZerosMatrix) = lmul_diag(a, b)
93+
*(a::Diagonal, b::AbstractFillMatrix) = lmul_diag(a, b)
94+
*(a::AbstractFillMatrix, b::Diagonal) = rmul_diag(a, b)
95+
9396
@noinline function check_matmul_sizes(A::AbstractMatrix, x::AbstractVector)
9497
Base.require_one_based_indexing(A, x)
9598
size(A,2) == size(x,1) ||
@@ -253,7 +256,18 @@ function _adjvec_mul_zeros(a, b)
253256
if la lb
254257
throw(DimensionMismatch("dot product arguments have lengths $la and $lb"))
255258
end
256-
return zero(Base.promote_op(*, eltype(a), eltype(b)))
259+
# ensure that all the elements of `a` are of the same size,
260+
# so that ∑ᵢaᵢbᵢ = b₁∑ᵢaᵢ makes sense
261+
if la == 0
262+
# this errors if a is a nested array, and zero isn't well-defined
263+
return zero(eltype(a)) * zero(eltype(b))
264+
end
265+
a1 = a[1]
266+
sza1 = size(a1)
267+
all(x -> size(x) == sza1, a) || throw(DimensionMismatch("not all elements of A are of size $sza1"))
268+
# we replace b₁∑ᵢaᵢ by b₁a₁, as we know that b₁ is zero.
269+
# Each term in the summation is zero, so the sum is equal to the first term
270+
return a1 * b[1]
257271
end
258272

259273
*(a::AdjointAbsVec{<:Any,<:ZerosVector}, b::AbstractMatrix) = (b' * a')'

test/runtests.jl

Lines changed: 104 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ end
88

99
include("infinitearrays.jl")
1010

11+
# we may use this instead of rand(n) to generate deterministic arrays
12+
oneton(T::Type, sz...) = reshape(T.(1:prod(sz)), sz)
13+
oneton(sz...) = oneton(Float64, sz...)
14+
1115
@testset "fill array constructors and convert" begin
1216
for (Typ, funcs) in ((:Zeros, :zeros), (:Ones, :ones))
1317
@eval begin
@@ -592,36 +596,74 @@ end
592596
@test [1,2,3]*Zeros(1,3) Zeros(3,3)
593597
@test_throws MethodError [1,2,3]*Zeros(3) # Not defined for [1,2,3]*[0,0,0] either
594598

599+
@testset "Matrix multiplication with array elements" begin
600+
x = [1 2; 3 4]
601+
z = zero(SVector{2,Int})
602+
ZV = Zeros{typeof(z)}(2)
603+
A = Fill(x, 3, 2) * ZV
604+
@test A isa Fill
605+
@test size(A) == (3,)
606+
@test A[1] == x * z
607+
@test_throws DimensionMismatch Fill(x, 1, 1) * ZV
608+
@test_throws DimensionMismatch Fill(oneton(1,1), 1, length(ZV)) * ZV
609+
610+
@test_throws DimensionMismatch Ones(SMatrix{3,3,Int,9},2) * Ones(SMatrix{2,2,Int,4},1,2)
611+
end
612+
595613
@testset "Check multiplication by Adjoint vectors works as expected." begin
596-
@test randn(4, 3)' * Zeros(4) Zeros(3)
597-
@test randn(4)' * Zeros(4) transpose(randn(4)) * Zeros(4) zero(Float64)
614+
@test @inferred(oneton(4, 3)' * Zeros(4)) Zeros(3)
615+
@test @inferred(oneton(4)' * Zeros(4)) @inferred(transpose(oneton(4)) * Zeros(4)) == 0.0
598616
@test [1, 2, 3]' * Zeros{Int}(3) zero(Int)
599617
@test [SVector(1,2)', SVector(2,3)', SVector(3,4)']' * Zeros{Int}(3) === SVector(0,0)
600-
@test_throws DimensionMismatch randn(4)' * Zeros(3)
601-
@test Zeros(5)' * randn(5,3) Zeros(5)'*Zeros(5,3) Zeros(5)'*Ones(5,3) Zeros(3)'
602-
@test abs(Zeros(5)' * randn(5)) abs(Zeros(5)' * Zeros(5)) abs(Zeros(5)' * Ones(5)) 0.0
618+
@test_throws DimensionMismatch oneton(4)' * Zeros(3)
619+
@test Zeros(5)' * oneton(5,3) Zeros(5)'*Zeros(5,3) Zeros(5)'*Ones(5,3) Zeros(3)'
620+
@test abs(Zeros(5)' * oneton(5)) == abs(Zeros(5)' * Zeros(5)) abs(Zeros(5)' * Ones(5)) == 0.0
603621
@test Zeros(5) * Zeros(6)' Zeros(5,1) * Zeros(6)' Zeros(5,6)
604-
@test randn(5) * Zeros(6)' randn(5,1) * Zeros(6)' Zeros(5,6)
605-
@test Zeros(5) * randn(6)' Zeros(5,6)
622+
@test oneton(5) * Zeros(6)' oneton(5,1) * Zeros(6)' Zeros(5,6)
623+
@test Zeros(5) * oneton(6)' Zeros(5,6)
606624

607-
@test ([[1,2]])' * Zeros{SVector{2,Int}}(1) 0
608-
@test_broken ([[1,2,3]])' * Zeros{SVector{2,Int}}(1)
625+
@test @inferred(Zeros{Int}(0)' * Zeros{Int}(0)) === zero(Int)
626+
627+
@test Any[1,2.0]' * Zeros{Int}(2) == 0
628+
@test Real[1,2.0]' * Zeros{Int}(2) == 0
629+
630+
@test @inferred(([[1,2]])' * Zeros{SVector{2,Int}}(1)) 0
631+
@test ([[1,2], [1,2]])' * Zeros{SVector{2,Int}}(2) 0
632+
@test_throws DimensionMismatch ([[1,2,3]])' * Zeros{SVector{2,Int}}(1)
633+
@test_throws DimensionMismatch ([[1,2,3], [1,2]])' * Zeros{SVector{2,Int}}(2)
634+
635+
A = SMatrix{2,1,Int,2}[]'
636+
B = Zeros(SVector{2,Int},0)
637+
C = collect(B)
638+
@test @inferred(A * B) == @inferred(A * C)
609639
end
610640

611641
@testset "Check multiplication by Transpose-d vectors works as expected." begin
612-
@test transpose(randn(4, 3)) * Zeros(4) === Zeros(3)
613-
@test transpose(randn(4)) * Zeros(4) === zero(Float64)
642+
@test transpose(oneton(4, 3)) * Zeros(4) === Zeros(3)
643+
@test transpose(oneton(4)) * Zeros(4) == 0.0
614644
@test transpose([1, 2, 3]) * Zeros{Int}(3) === zero(Int)
615-
@test_throws DimensionMismatch transpose(randn(4)) * Zeros(3)
616-
@test transpose(Zeros(5)) * randn(5,3) transpose(Zeros(5))*Zeros(5,3) transpose(Zeros(5))*Ones(5,3) transpose(Zeros(3))
617-
@test abs(transpose(Zeros(5)) * randn(5)) abs(transpose(Zeros(5)) * Zeros(5)) abs(transpose(Zeros(5)) * Ones(5)) 0.0
618-
@test randn(5) * transpose(Zeros(6)) randn(5,1) * transpose(Zeros(6)) Zeros(5,6)
619-
@test Zeros(5) * transpose(randn(6)) Zeros(5,6)
620-
@test transpose(randn(5)) * Zeros(5) 0.0
621-
@test transpose(randn(5) .+ im) * Zeros(5) 0.0 + 0im
645+
@test_throws DimensionMismatch transpose(oneton(4)) * Zeros(3)
646+
@test transpose(Zeros(5)) * oneton(5,3) transpose(Zeros(5))*Zeros(5,3) transpose(Zeros(5))*Ones(5,3) transpose(Zeros(3))
647+
@test abs(transpose(Zeros(5)) * oneton(5)) abs(transpose(Zeros(5)) * Zeros(5)) abs(transpose(Zeros(5)) * Ones(5)) 0.0
648+
@test oneton(5) * transpose(Zeros(6)) oneton(5,1) * transpose(Zeros(6)) Zeros(5,6)
649+
@test Zeros(5) * transpose(oneton(6)) Zeros(5,6)
650+
@test transpose(oneton(5)) * Zeros(5) == 0.0
651+
@test transpose(oneton(5) .+ im) * Zeros(5) == 0.0 + 0im
652+
653+
@test @inferred(transpose(Zeros{Int}(0)) * Zeros{Int}(0)) === zero(Int)
622654

623-
@test transpose([[1,2]]) * Zeros{SVector{2,Int}}(1) 0
624-
@test_broken transpose([[1,2,3]]) * Zeros{SVector{2,Int}}(1)
655+
@test transpose(Any[1,2.0]) * Zeros{Int}(2) == 0
656+
@test transpose(Real[1,2.0]) * Zeros{Int}(2) == 0
657+
658+
@test @inferred(transpose([[1,2]]) * Zeros{SVector{2,Int}}(1)) 0
659+
@test transpose([[1,2], [1,2]]) * Zeros{SVector{2,Int}}(2) 0
660+
@test_throws DimensionMismatch transpose([[1,2,3]]) * Zeros{SVector{2,Int}}(1)
661+
@test_throws DimensionMismatch transpose([[1,2,3], [1,2]]) * Zeros{SVector{2,Int}}(2)
662+
663+
A = transpose(SMatrix{2,1,Int,2}[])
664+
B = Zeros(SVector{2,Int},0)
665+
C = collect(B)
666+
@test @inferred(A * B) == @inferred(A * C)
625667

626668
@testset "Diagonal mul introduced in v1.9" begin
627669
@test Zeros(5)'*Diagonal(1:5) Zeros(5)'
@@ -1386,6 +1428,48 @@ end
13861428
f = Zeros((Base.IdentityUnitRange(3:4), Base.IdentityUnitRange(3:4)))
13871429
@test_throws ArgumentError f * f
13881430

1431+
@testset "Arrays as elements" begin
1432+
SMT = SMatrix{2,2,Int,4}
1433+
SVT = SVector{2,Int}
1434+
@test @inferred(Zeros{SMT}(0,0) * Fill([1 2; 3 4], 0, 0)) == Zeros{SMT}(0,0)
1435+
@test @inferred(Zeros{SMT}(4,2) * Fill([1 2; 3 4], 2, 3)) == Zeros{SMT}(4,3)
1436+
@test @inferred(Fill([1 2; 3 4], 2, 3) * Zeros{SMT}(3, 4)) == Zeros{SMT}(2,4)
1437+
@test @inferred(Zeros{SMT}(4,2) * Fill([1, 2], 2, 3)) == Zeros{SVT}(4,3)
1438+
@test @inferred(Fill([1 2], 2, 3) * Zeros{SMT}(3,4)) == Zeros{SMatrix{1,2,Int,2}}(2,4)
1439+
1440+
TSM = SMatrix{2,2,Int,4}
1441+
s = TSM(1:4)
1442+
for n in 0:3
1443+
v = fill(s, 1)
1444+
z = zeros(TSM, n)
1445+
A = @inferred Zeros{TSM}(n) * Diagonal(v)
1446+
B = z * Diagonal(v)
1447+
@test A == B
1448+
1449+
w = fill(s, n)
1450+
A = @inferred Diagonal(w) * Zeros{TSM}(n)
1451+
B = Diagonal(w) * z
1452+
@test A == B
1453+
1454+
A = @inferred Zeros{TSM}(2n, n) * Diagonal(w)
1455+
B = zeros(TSM, 2n, n) * Diagonal(w)
1456+
@test A == B
1457+
1458+
A = @inferred Diagonal(w) * Zeros{TSM}(n, 2n)
1459+
B = Diagonal(w) * zeros(TSM, n, 2n)
1460+
@test A == B
1461+
end
1462+
1463+
D = Diagonal([[1 2; 3 4], [1 2 3; 4 5 6]])
1464+
@test @inferred(Zeros(TSM, 2,2) * D) == zeros(TSM, 2,2) * D
1465+
1466+
# doubly nested
1467+
A = [[[1,2]]]'
1468+
Z = Zeros(SMatrix{1,1,SMatrix{2,2,Int,4},1},1)
1469+
Z2 = zeros(SMatrix{1,1,SMatrix{2,2,Int,4},1},1)
1470+
@test A * Z == A * Z2
1471+
end
1472+
13891473
for W in (zeros(3,4), @MMatrix zeros(3,4))
13901474
mW, nW = size(W)
13911475
@test mul!(W, Fill(2,mW,5), Fill(3,5,nW)) Fill(30,mW,nW) fill(2,mW,5) * fill(3,5,nW)

0 commit comments

Comments
 (0)