|
84 | 84 |
|
85 | 85 | *(a::Zeros{<:Any,1}, b::AbstractMatrix) = mult_zeros(a, b)
|
86 | 86 | *(a::Zeros{<:Any,2}, b::AbstractMatrix) = mult_zeros(a, b)
|
| 87 | +*(a::Zeros{<:Any,2}, b::AbstractTriangular) = mult_zeros(a, b) |
87 | 88 | *(a::AbstractMatrix, b::Zeros{<:Any,1}) = mult_zeros(a, b)
|
88 | 89 | *(a::AbstractMatrix, b::Zeros{<:Any,2}) = mult_zeros(a, b)
|
| 90 | +*(a::AbstractTriangular, b::Zeros{<:Any,2}) = mult_zeros(a, b) |
89 | 91 | *(a::Zeros{<:Any,1}, b::AbstractVector) = mult_zeros(a, b)
|
90 | 92 | *(a::Zeros{<:Any,2}, b::AbstractVector) = mult_zeros(a, b)
|
91 | 93 | *(a::AbstractVector, b::Zeros{<:Any,2}) = mult_zeros(a, b)
|
|
95 | 97 | *(a::Diagonal, b::Zeros{<:Any,1}) = mult_zeros(a, b)
|
96 | 98 | *(a::Diagonal, b::Zeros{<:Any,2}) = mult_zeros(a, b)
|
97 | 99 |
|
98 |
| -# Cannot unify following methods for Diagonal |
99 |
| -# due to ambiguity with general array mult. with fill |
100 |
| -function *(a::Diagonal, b::FillMatrix) |
| 100 | +function *(a::Diagonal, b::AbstractFill{T,2}) where T |
101 | 101 | size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
|
102 | 102 | a.diag .* b # use special broadcast
|
103 | 103 | end
|
104 |
| -function *(a::FillMatrix, b::Diagonal) |
| 104 | +function *(a::AbstractFill{T,2}, b::Diagonal) where T |
105 | 105 | size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
|
106 | 106 | a .* permutedims(b.diag) # use special broadcast
|
107 | 107 | end
|
108 |
| -function *(a::Diagonal, b::OnesMatrix) |
109 |
| - size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))")) |
110 |
| - a.diag .* b # use special broadcast |
111 |
| -end |
112 |
| -function *(a::OnesMatrix, b::Diagonal) |
113 |
| - size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))")) |
114 |
| - a .* permutedims(b.diag) # use special broadcast |
115 |
| -end |
116 |
| - |
117 |
| -*(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = reshape(sum(conj.(parent(a)); dims=1) .* b.value, size(parent(a), 2)) |
118 |
| -*(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = reshape(sum(parent(a); dims=1) .* b.value, size(parent(a), 2)) |
119 |
| -*(a::StridedMatrix{T}, b::Fill{T, 1}) where T = reshape(sum(a; dims=2) .* b.value, size(a, 1)) |
120 | 108 |
|
121 |
| -function *(x::AbstractMatrix, f::FillMatrix) |
| 109 | +function mult_sum2(x::AbstractMatrix, f::AbstractFill{T,2}) where T |
122 | 110 | axes(x, 2) ≠ axes(f, 1) &&
|
123 | 111 | throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
|
124 | 112 | m = size(f, 2)
|
125 |
| - repeat(sum(x, dims=2) * f.value, 1, m) |
| 113 | + repeat(sum(x, dims=2) * getindex_value(f), 1, m) |
126 | 114 | end
|
127 | 115 |
|
128 |
| -function *(f::FillMatrix, x::AbstractMatrix) |
| 116 | +function mult_sum1(f::AbstractFill{T,2}, x::AbstractMatrix) where T |
129 | 117 | axes(f, 2) ≠ axes(x, 1) &&
|
130 | 118 | throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
|
131 | 119 | m = size(f, 1)
|
132 |
| - repeat(sum(x, dims=1) * f.value, m, 1) |
| 120 | + repeat(sum(x, dims=1) * getindex_value(f), m, 1) |
133 | 121 | end
|
134 | 122 |
|
135 |
| -function *(x::AbstractMatrix, f::OnesMatrix) |
136 |
| - axes(x, 2) ≠ axes(f, 1) && |
137 |
| - throw(DimensionMismatch("Incompatible matrix multiplication dimensions")) |
138 |
| - m = size(f, 2) |
139 |
| - repeat(sum(x, dims=2) * one(eltype(f)), 1, m) |
140 |
| -end |
| 123 | +*(x::AbstractMatrix, y::AbstractFill{<:Any,2}) = mult_sum2(x, y) |
| 124 | +*(x::AbstractTriangular, y::AbstractFill{<:Any,2}) = mult_sum2(x, y) |
| 125 | +*(x::AbstractFill{<:Any,2}, y::AbstractMatrix) = mult_sum1(x, y) |
| 126 | +*(x::AbstractFill{<:Any,2}, y::AbstractTriangular) = mult_sum1(x, y) |
141 | 127 |
|
142 |
| -function *(f::OnesMatrix, x::AbstractMatrix) |
143 |
| - axes(f, 2) ≠ axes(x, 1) && |
144 |
| - throw(DimensionMismatch("Incompatible matrix multiplication dimensions")) |
145 |
| - m = size(f, 1) |
146 |
| - repeat(sum(x, dims=1) * one(eltype(f)), m, 1) |
147 |
| -end |
148 |
| - |
149 |
| -*(x::FillMatrix, y::FillMatrix) = mult_fill(x, y) |
150 |
| -*(x::FillMatrix, y::OnesMatrix) = mult_fill(x, y) |
151 |
| -*(x::OnesMatrix, y::FillMatrix) = mult_fill(x, y) |
152 |
| -*(x::OnesMatrix, y::OnesMatrix) = mult_fill(x, y) |
153 |
| -*(x::ZerosMatrix, y::OnesMatrix) = mult_zeros(x, y) |
154 |
| -*(x::ZerosMatrix, y::FillMatrix) = mult_zeros(x, y) |
155 |
| -*(x::FillMatrix, y::ZerosMatrix) = mult_zeros(x, y) |
156 |
| -*(x::OnesMatrix, y::ZerosMatrix) = mult_zeros(x, y) |
157 | 128 |
|
| 129 | +### These methods are faster for small n ############# |
158 | 130 | # function *(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T
|
159 | 131 | # fB = similar(parent(a), size(b, 1), size(b, 2))
|
160 | 132 | # fill!(fB, b.value)
|
|
173 | 145 | # return a*fB
|
174 | 146 | # end
|
175 | 147 |
|
| 148 | +## Matrix-Vector multiplication |
| 149 | + |
| 150 | +*(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = |
| 151 | + reshape(sum(conj.(parent(a)); dims=1) .* b.value, size(parent(a), 2)) |
| 152 | +*(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = |
| 153 | + reshape(sum(parent(a); dims=1) .* b.value, size(parent(a), 2)) |
| 154 | +*(a::StridedMatrix{T}, b::Fill{T, 1}) where T = |
| 155 | + reshape(sum(a; dims=2) .* b.value, size(a, 1)) |
| 156 | + |
| 157 | + |
176 | 158 | function _adjvec_mul_zeros(a::Adjoint{T}, b::Zeros{S, 1}) where {T, S}
|
177 | 159 | la, lb = length(a), length(b)
|
178 | 160 | if la ≠ lb
|
|
0 commit comments