Skip to content

Commit 5b02d4d

Browse files
authored
Diagonal * AbstractFill matrix special case (#108)
* Diagonal * AbstractFill matrix special case * Fix Fill * Fill * Test on 1.5 * Fixes for Julia v1.5 * Adjoint of StaticArray test
1 parent 5a73c7c commit 5b02d4d

File tree

5 files changed

+41
-25
lines changed

5 files changed

+41
-25
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FillArrays"
22
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
3-
version = "0.9"
3+
version = "0.9.1"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -12,7 +12,8 @@ julia = "1"
1212

1313
[extras]
1414
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
15+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1516
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1617

1718
[targets]
18-
test = ["Test", "Base64"]
19+
test = ["Test", "Base64", "StaticArrays"]

src/FillArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert,
99
show
1010

1111
import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!,
12-
norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AbstractTriangular
12+
norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AbstractTriangular, AdjointAbsVec
1313

1414
import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape
1515

src/fillalgebra.jl

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,25 +36,20 @@ end
3636
function mult_fill(a::AbstractFill, b::AbstractFill{<:Any,2})
3737
axes(a, 2) axes(b, 1) &&
3838
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
39-
return Fill(getindex_value(a)*getindex_value(b), (axes(a, 1), axes(b, 2)))
39+
return Fill(getindex_value(a)*getindex_value(b)*size(a,2), (axes(a, 1), axes(b, 2)))
4040
end
4141

4242
function mult_fill(a::AbstractFill, b::AbstractFill{<:Any,1})
4343
axes(a, 2) axes(b, 1) &&
4444
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
45-
return Fill(getindex_value(a)*getindex_value(b), (axes(a, 1),))
45+
return Fill(getindex_value(a)*getindex_value(b)*size(a,2), (axes(a, 1),))
4646
end
4747

48-
function mult_ones(a, b::AbstractMatrix)
48+
function mult_ones(a::AbstractVector, b::AbstractMatrix)
4949
axes(a, 2) axes(b, 1) &&
5050
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
5151
return Ones{promote_type(eltype(a), eltype(b))}((axes(a, 1), axes(b, 2)))
5252
end
53-
function mult_ones(a, b::AbstractVector)
54-
axes(a, 2) axes(b, 1) &&
55-
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
56-
return Ones{promote_type(eltype(a), eltype(b))}((axes(a, 1),))
57-
end
5853

5954
function mult_zeros(a, b::AbstractMatrix)
6055
axes(a, 2) axes(b, 1) &&
@@ -72,8 +67,6 @@ end
7267
*(a::AbstractFill{<:Any,2}, b::AbstractFill{<:Any,1}) = mult_fill(a,b)
7368

7469
*(a::Ones{<:Any,1}, b::Ones{<:Any,2}) = mult_ones(a, b)
75-
*(a::Ones{<:Any,2}, b::Ones{<:Any,2}) = mult_ones(a, b)
76-
*(a::Ones{<:Any,2}, b::Ones{<:Any,1}) = mult_ones(a, b)
7770

7871
*(a::Zeros{<:Any,1}, b::Zeros{<:Any,2}) = mult_zeros(a, b)
7972
*(a::Zeros{<:Any,2}, b::Zeros{<:Any,2}) = mult_zeros(a, b)
@@ -98,6 +91,14 @@ end
9891
*(a::Zeros{<:Any,2}, b::Diagonal) = mult_zeros(a, b)
9992
*(a::Diagonal, b::Zeros{<:Any,1}) = mult_zeros(a, b)
10093
*(a::Diagonal, b::Zeros{<:Any,2}) = mult_zeros(a, b)
94+
function *(a::Diagonal, b::AbstractFill{<:Any,2})
95+
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
96+
a.diag .* b # use special broadcast
97+
end
98+
function *(a::AbstractFill{<:Any,2}, b::Diagonal)
99+
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
100+
a .* permutedims(b.diag) # use special broadcast
101+
end
101102

102103
*(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = reshape(sum(conj.(parent(a)); dims=1) .* b.value, size(parent(a), 2))
103104
*(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = reshape(sum(parent(a); dims=1) .* b.value, size(parent(a), 2))
@@ -120,13 +121,16 @@ function *(a::StridedMatrix{T}, b::Fill{T, 2}) where T
120121
fill!(fB, b.value)
121122
return a*fB
122123
end
123-
function *(a::Adjoint{T, <:AbstractVector{T}}, b::Zeros{S, 1}) where {T, S}
124+
function _adjvec_mul_zeros(a::Adjoint{T}, b::Zeros{S, 1}) where {T, S}
124125
la, lb = length(a), length(b)
125126
if la lb
126127
throw(DimensionMismatch("dot product arguments have lengths $la and $lb"))
127128
end
128-
return zero(promote_type(T, S))
129+
return zero(Base.promote_op(*, T, S))
129130
end
131+
132+
*(a::AdjointAbsVec, b::Zeros{<:Any, 1}) = _adjvec_mul_zeros(a, b)
133+
*(a::AdjointAbsVec{<:Number}, b::Zeros{<:Number, 1}) = _adjvec_mul_zeros(a, b)
130134
*(a::Adjoint{T, <:AbstractMatrix{T}} where T, b::Zeros{<:Any, 1}) = mult_zeros(a, b)
131135

132136
function *(a::Transpose{T, <:AbstractVector{T}}, b::Zeros{T, 1}) where T<:Real

src/fillbroadcast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,4 @@ end
100100
broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Number) where {T,N} = Fill(op(getindex_value(r),x), size(r))
101101
broadcasted(::DefaultArrayStyle{N}, op, x::Number, r::AbstractFill{T,N}) where {T,N} = Fill(op(x, getindex_value(r)), size(r))
102102
broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Ref) where {T,N} = Fill(op(getindex_value(r),x[]), size(r))
103-
broadcasted(::DefaultArrayStyle{N}, op, x::Ref, r::AbstractFill{T,N}) where {T,N} = Fill(op(x[], getindex_value(r)), size(r))
103+
broadcasted(::DefaultArrayStyle{N}, op, x::Ref, r::AbstractFill{T,N}) where {T,N} = Fill(op(x[], getindex_value(r)), size(r))

test/runtests.jl

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using FillArrays, LinearAlgebra, SparseArrays, Random, Base64, Test
1+
using FillArrays, LinearAlgebra, SparseArrays, StaticArrays, Random, Base64, Test
22
import FillArrays: AbstractFill, RectDiagonal, SquareEye
33

44
@testset "fill array constructors and convert" begin
@@ -354,6 +354,7 @@ end
354354
@test randn(4, 3)' * Zeros(4) === Zeros(3)
355355
@test randn(4)' * Zeros(4) === zero(Float64)
356356
@test [1, 2, 3]' * Zeros{Int}(3) === zero(Int)
357+
@test [SVector(1,2)', SVector(2,3)', SVector(3,4)']' * Zeros{Int}(3) === SVector(0,0)
357358
@test_throws DimensionMismatch randn(4)' * Zeros(3)
358359

359360
# Check multiplication by Transpose-d vectors works as expected.
@@ -866,22 +867,22 @@ end
866867
@testset "mult" begin
867868
@test Fill(2,10)*Fill(3,1,12) == Vector(Fill(2,10))*Matrix(Fill(3,1,12))
868869
@test Fill(2,10)*Fill(3,1,12) Fill(6,10,12)
869-
@test Fill(2,3,10)*Fill(3,10,12) Fill(6,3,12)
870-
@test Fill(2,3,10)*Fill(3,10) Fill(6,3)
870+
@test Fill(2,3,10)*Fill(3,10,12) Fill(60,3,12)
871+
@test Fill(2,3,10)*Fill(3,10) Fill(60,3)
871872
@test_throws DimensionMismatch Fill(2,10)*Fill(3,2,12)
872873
@test_throws DimensionMismatch Fill(2,3,10)*Fill(3,2,12)
873874

874875
@test Ones(10)*Fill(3,1,12) Fill(3.0,10,12)
875-
@test Ones(10,3)*Fill(3,3,12) Fill(3.0,10,12)
876-
@test Ones(10,3)*Fill(3,3) Fill(3.0,10)
876+
@test Ones(10,3)*Fill(3,3,12) Fill(9.0,10,12)
877+
@test Ones(10,3)*Fill(3,3) Fill(9.0,10)
877878

878879
@test Fill(2,10)*Ones(1,12) Fill(2.0,10,12)
879-
@test Fill(2,3,10)*Ones(10,12) Fill(2.0,3,12)
880-
@test Fill(2,3,10)*Ones(10) Fill(2.0,3)
880+
@test Fill(2,3,10)*Ones(10,12) Fill(20.0,3,12)
881+
@test Fill(2,3,10)*Ones(10) Fill(20.0,3)
881882

882883
@test Ones(10)*Ones(1,12) Ones(10,12)
883-
@test Ones(3,10)*Ones(10,12) Ones(3,12)
884-
@test Ones(3,10)*Ones(10) Ones(3)
884+
@test Ones(3,10)*Ones(10,12) Fill(10.0,3,12)
885+
@test Ones(3,10)*Ones(10) Fill(10.0,3)
885886

886887
@test Zeros(10)*Fill(3,1,12) Zeros(10,12)
887888
@test Zeros(10,3)*Fill(3,3,12) Zeros(10,12)
@@ -924,6 +925,16 @@ end
924925
@test D*Zeros(1,1) Zeros(1,1)
925926
@test D*Zeros(1) Zeros(1)
926927

928+
D = Diagonal(Fill(2,10))
929+
@test D * Ones(10) Fill(2.0,10)
930+
@test D * Ones(10,5) Fill(2.0,10,5)
931+
@test Ones(5,10) * D Fill(2.0,5,10)
932+
933+
# following test is broken in Base as of Julia v1.5
934+
@test_skip @test_throws DimensionMismatch Diagonal(Fill(1,1)) * Ones(10)
935+
@test_throws DimensionMismatch Diagonal(Fill(1,1)) * Ones(10,5)
936+
@test_throws DimensionMismatch Ones(5,10) * Diagonal(Fill(1,1))
937+
927938
E = Eye(5)
928939
@test E*(1:5) 1.0:5.0
929940
@test (1:5)'E == (1.0:5)'

0 commit comments

Comments
 (0)