Skip to content

Commit 2915481

Browse files
fix tests
1 parent ba49d16 commit 2915481

File tree

1 file changed

+70
-14
lines changed

1 file changed

+70
-14
lines changed

src/fillalgebra.jl

Lines changed: 70 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
const FillVector{F,A} = Fill{F,1,A}
2+
const FillMatrix{F,A} = Fill{F,2,A}
3+
const OnesVector{F,A} = Ones{F,1,A}
4+
const OnesMatrix{F,A} = Ones{F,2,A}
5+
const ZerosVector{F,A} = Zeros{F,1,A}
6+
const ZerosMatrix{F,A} = Zeros{F,2,A}
7+
18
## vec
29

310
vec(a::Ones{T}) where T = Ones{T}(length(a))
@@ -87,11 +94,22 @@ end
8794
*(a::Zeros{<:Any,2}, b::Diagonal) = mult_zeros(a, b)
8895
*(a::Diagonal, b::Zeros{<:Any,1}) = mult_zeros(a, b)
8996
*(a::Diagonal, b::Zeros{<:Any,2}) = mult_zeros(a, b)
90-
function *(a::Diagonal, b::AbstractFill{<:Any,2})
97+
98+
# Cannot unify following methods for Diagonal
99+
# due to ambiguity with general array mult. with fill
100+
function *(a::Diagonal, b::FillMatrix)
101+
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
102+
a.diag .* b # use special broadcast
103+
end
104+
function *(a::FillMatrix, b::Diagonal)
105+
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
106+
a .* permutedims(b.diag) # use special broadcast
107+
end
108+
function *(a::Diagonal, b::OnesMatrix)
91109
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
92110
a.diag .* b # use special broadcast
93111
end
94-
function *(a::AbstractFill{<:Any,2}, b::Diagonal)
112+
function *(a::OnesMatrix, b::Diagonal)
95113
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
96114
a .* permutedims(b.diag) # use special broadcast
97115
end
@@ -100,23 +118,61 @@ end
100118
*(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = reshape(sum(parent(a); dims=1) .* b.value, size(parent(a), 2))
101119
*(a::StridedMatrix{T}, b::Fill{T, 1}) where T = reshape(sum(a; dims=2) .* b.value, size(a, 1))
102120

103-
function *(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T
104-
fB = similar(parent(a), size(b, 1), size(b, 2))
105-
fill!(fB, b.value)
106-
return a*fB
121+
function *(x::AbstractMatrix, f::FillMatrix)
122+
axes(x, 2) axes(f, 1) &&
123+
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
124+
m = size(f, 2)
125+
repeat(sum(x, dims=2) * f.value, 1, m)
126+
end
127+
128+
function *(f::FillMatrix, x::AbstractMatrix)
129+
axes(f, 2) axes(x, 1) &&
130+
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
131+
m = size(f, 1)
132+
repeat(sum(x, dims=1) * f.value, m, 1)
107133
end
108134

109-
function *(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T
110-
fB = similar(parent(a), size(b, 1), size(b, 2))
111-
fill!(fB, b.value)
112-
return a*fB
135+
function *(x::AbstractMatrix, f::OnesMatrix)
136+
axes(x, 2) axes(f, 1) &&
137+
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
138+
m = size(f, 2)
139+
repeat(sum(x, dims=2) * one(eltype(f)), 1, m)
113140
end
114141

115-
function *(a::StridedMatrix{T}, b::Fill{T, 2}) where T
116-
fB = similar(a, size(b, 1), size(b, 2))
117-
fill!(fB, b.value)
118-
return a*fB
142+
function *(f::OnesMatrix, x::AbstractMatrix)
143+
axes(f, 2) axes(x, 1) &&
144+
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
145+
m = size(f, 1)
146+
repeat(sum(x, dims=1) * one(eltype(f)), m, 1)
119147
end
148+
149+
*(x::FillMatrix, y::FillMatrix) = mult_fill(x, y)
150+
*(x::FillMatrix, y::OnesMatrix) = mult_fill(x, y)
151+
*(x::OnesMatrix, y::FillMatrix) = mult_fill(x, y)
152+
*(x::OnesMatrix, y::OnesMatrix) = mult_fill(x, y)
153+
*(x::ZerosMatrix, y::OnesMatrix) = mult_zeros(x, y)
154+
*(x::ZerosMatrix, y::FillMatrix) = mult_zeros(x, y)
155+
*(x::FillMatrix, y::ZerosMatrix) = mult_zeros(x, y)
156+
*(x::OnesMatrix, y::ZerosMatrix) = mult_zeros(x, y)
157+
158+
# function *(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T
159+
# fB = similar(parent(a), size(b, 1), size(b, 2))
160+
# fill!(fB, b.value)
161+
# return a*fB
162+
# end
163+
164+
# function *(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T
165+
# fB = similar(parent(a), size(b, 1), size(b, 2))
166+
# fill!(fB, b.value)
167+
# return a*fB
168+
# end
169+
170+
# function *(a::StridedMatrix{T}, b::Fill{T, 2}) where T
171+
# fB = similar(a, size(b, 1), size(b, 2))
172+
# fill!(fB, b.value)
173+
# return a*fB
174+
# end
175+
120176
function _adjvec_mul_zeros(a::Adjoint{T}, b::Zeros{S, 1}) where {T, S}
121177
la, lb = length(a), length(b)
122178
if la lb

0 commit comments

Comments
 (0)