Skip to content

Commit 2766105

Browse files
faster matrix-fillmatrix multiplication
cleanup
1 parent 2915481 commit 2766105

File tree

2 files changed

+36
-41
lines changed

2 files changed

+36
-41
lines changed

src/fillalgebra.jl

Lines changed: 23 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,10 @@ end
8484

8585
*(a::Zeros{<:Any,1}, b::AbstractMatrix) = mult_zeros(a, b)
8686
*(a::Zeros{<:Any,2}, b::AbstractMatrix) = mult_zeros(a, b)
87+
*(a::Zeros{<:Any,2}, b::AbstractTriangular) = mult_zeros(a, b)
8788
*(a::AbstractMatrix, b::Zeros{<:Any,1}) = mult_zeros(a, b)
8889
*(a::AbstractMatrix, b::Zeros{<:Any,2}) = mult_zeros(a, b)
90+
*(a::AbstractTriangular, b::Zeros{<:Any,2}) = mult_zeros(a, b)
8991
*(a::Zeros{<:Any,1}, b::AbstractVector) = mult_zeros(a, b)
9092
*(a::Zeros{<:Any,2}, b::AbstractVector) = mult_zeros(a, b)
9193
*(a::AbstractVector, b::Zeros{<:Any,2}) = mult_zeros(a, b)
@@ -95,66 +97,36 @@ end
9597
*(a::Diagonal, b::Zeros{<:Any,1}) = mult_zeros(a, b)
9698
*(a::Diagonal, b::Zeros{<:Any,2}) = mult_zeros(a, b)
9799

98-
# Cannot unify following methods for Diagonal
99-
# due to ambiguity with general array mult. with fill
100-
function *(a::Diagonal, b::FillMatrix)
100+
function *(a::Diagonal, b::AbstractFill{T,2}) where T
101101
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
102102
a.diag .* b # use special broadcast
103103
end
104-
function *(a::FillMatrix, b::Diagonal)
104+
function *(a::AbstractFill{T,2}, b::Diagonal) where T
105105
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
106106
a .* permutedims(b.diag) # use special broadcast
107107
end
108-
function *(a::Diagonal, b::OnesMatrix)
109-
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
110-
a.diag .* b # use special broadcast
111-
end
112-
function *(a::OnesMatrix, b::Diagonal)
113-
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
114-
a .* permutedims(b.diag) # use special broadcast
115-
end
116-
117-
*(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = reshape(sum(conj.(parent(a)); dims=1) .* b.value, size(parent(a), 2))
118-
*(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = reshape(sum(parent(a); dims=1) .* b.value, size(parent(a), 2))
119-
*(a::StridedMatrix{T}, b::Fill{T, 1}) where T = reshape(sum(a; dims=2) .* b.value, size(a, 1))
120108

121-
function *(x::AbstractMatrix, f::FillMatrix)
109+
function mult_sum2(x::AbstractMatrix, f::AbstractFill{T,2}) where T
122110
axes(x, 2) axes(f, 1) &&
123111
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
124112
m = size(f, 2)
125-
repeat(sum(x, dims=2) * f.value, 1, m)
113+
repeat(sum(x, dims=2) * getindex_value(f), 1, m)
126114
end
127115

128-
function *(f::FillMatrix, x::AbstractMatrix)
116+
function mult_sum1(f::AbstractFill{T,2}, x::AbstractMatrix) where T
129117
axes(f, 2) axes(x, 1) &&
130118
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
131119
m = size(f, 1)
132-
repeat(sum(x, dims=1) * f.value, m, 1)
120+
repeat(sum(x, dims=1) * getindex_value(f), m, 1)
133121
end
134122

135-
function *(x::AbstractMatrix, f::OnesMatrix)
136-
axes(x, 2) axes(f, 1) &&
137-
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
138-
m = size(f, 2)
139-
repeat(sum(x, dims=2) * one(eltype(f)), 1, m)
140-
end
123+
*(x::AbstractMatrix, y::AbstractFill{<:Any,2}) = mult_sum2(x, y)
124+
*(x::AbstractTriangular, y::AbstractFill{<:Any,2}) = mult_sum2(x, y)
125+
*(x::AbstractFill{<:Any,2}, y::AbstractMatrix) = mult_sum1(x, y)
126+
*(x::AbstractFill{<:Any,2}, y::AbstractTriangular) = mult_sum1(x, y)
141127

142-
function *(f::OnesMatrix, x::AbstractMatrix)
143-
axes(f, 2) axes(x, 1) &&
144-
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
145-
m = size(f, 1)
146-
repeat(sum(x, dims=1) * one(eltype(f)), m, 1)
147-
end
148-
149-
*(x::FillMatrix, y::FillMatrix) = mult_fill(x, y)
150-
*(x::FillMatrix, y::OnesMatrix) = mult_fill(x, y)
151-
*(x::OnesMatrix, y::FillMatrix) = mult_fill(x, y)
152-
*(x::OnesMatrix, y::OnesMatrix) = mult_fill(x, y)
153-
*(x::ZerosMatrix, y::OnesMatrix) = mult_zeros(x, y)
154-
*(x::ZerosMatrix, y::FillMatrix) = mult_zeros(x, y)
155-
*(x::FillMatrix, y::ZerosMatrix) = mult_zeros(x, y)
156-
*(x::OnesMatrix, y::ZerosMatrix) = mult_zeros(x, y)
157128

129+
### These methods are faster for small n #############
158130
# function *(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T
159131
# fB = similar(parent(a), size(b, 1), size(b, 2))
160132
# fill!(fB, b.value)
@@ -173,6 +145,16 @@ end
173145
# return a*fB
174146
# end
175147

148+
## Matrix-Vector multiplication
149+
150+
*(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T =
151+
reshape(sum(conj.(parent(a)); dims=1) .* b.value, size(parent(a), 2))
152+
*(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T =
153+
reshape(sum(parent(a); dims=1) .* b.value, size(parent(a), 2))
154+
*(a::StridedMatrix{T}, b::Fill{T, 1}) where T =
155+
reshape(sum(a; dims=2) .* b.value, size(a, 1))
156+
157+
176158
function _adjvec_mul_zeros(a::Adjoint{T}, b::Zeros{S, 1}) where {T, S}
177159
la, lb = length(a), length(b)
178160
if la lb

test/runtests.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,6 +1028,19 @@ end
10281028
@test E*(1:5) 1.0:5.0
10291029
@test (1:5)'E == (1.0:5)'
10301030
@test E*E E
1031+
1032+
# Adjoint / Transpose / Triangular / Symmetric
1033+
for x in [transpose(rand(2, 2)),
1034+
adjoint(rand(2,2)),
1035+
UpperTriangular(rand(2,2)),
1036+
Symmetric(rand(2,2))]
1037+
@test x * Ones(2, 2) isa Matrix
1038+
@test Ones(2, 2) * x isa Matrix
1039+
@test x * Zeros(2, 2) isa Zeros
1040+
@test Zeros(2, 2) * x isa Zeros
1041+
@test x * Fill(1., 2, 2) isa Matrix
1042+
@test Fill(1., 2, 2) * x isa Matrix
1043+
end
10311044
end
10321045

10331046
@testset "count" begin

0 commit comments

Comments
 (0)