Skip to content

Commit 16ec4d6

Browse files
authored
Fill multiplication (#70)
* Add Fill multiplication special cases * v0.7 * Drop Julia v0.7 support * Add Eye tests * Improved ones broadcasting
1 parent 12e245e commit 16ec4d6

File tree

6 files changed

+167
-30
lines changed

6 files changed

+167
-30
lines changed

.travis.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ os:
44
- linux
55
- osx
66
julia:
7-
- 0.7
87
- 1.0
98
- 1.1
109
- 1.2

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
name = "FillArrays"
22
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
3-
version = "0.6.4"
3+
version = "0.7"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
88
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
99

1010
[compat]
11-
julia = "0.7, 1"
11+
julia = "1"
1212

1313
[extras]
1414
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

appveyor.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
environment:
22
matrix:
3-
- julia_version: 0.7
43
- julia_version: 1
54
- julia_version: 1.1
65
- julia_version: 1.2

src/fillalgebra.jl

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,72 @@ adjoint(a::Fill{T,2}) where T = Fill{T}(adjoint(a.value), reverse(a.axes))
1515

1616
## Algebraic identities
1717

18+
19+
function mult_fill(a::AbstractFill, b::AbstractFill{<:Any,2})
20+
axes(a, 2) axes(b, 1) &&
21+
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
22+
return Fill(getindex_value(a)*getindex_value(b), (axes(a, 1), axes(b, 2)))
23+
end
24+
25+
function mult_fill(a::AbstractFill, b::AbstractFill{<:Any,1})
26+
axes(a, 2) axes(b, 1) &&
27+
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
28+
return Fill(getindex_value(a)*getindex_value(b), (axes(a, 1),))
29+
end
30+
31+
function mult_ones(a, b::AbstractMatrix)
32+
axes(a, 2) axes(b, 1) &&
33+
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
34+
return Ones{promote_type(eltype(a), eltype(b))}((axes(a, 1), axes(b, 2)))
35+
end
36+
function mult_ones(a, b::AbstractVector)
37+
axes(a, 2) axes(b, 1) &&
38+
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
39+
return Ones{promote_type(eltype(a), eltype(b))}((axes(a, 1),))
40+
end
41+
1842
function mult_zeros(a, b::AbstractMatrix)
19-
size(a, 2) size(b, 1) &&
43+
axes(a, 2) axes(b, 1) &&
2044
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
21-
return Zeros{promote_type(eltype(a), eltype(b))}(size(a, 1), size(b, 2))
45+
return Zeros{promote_type(eltype(a), eltype(b))}((axes(a, 1), axes(b, 2)))
2246
end
2347
function mult_zeros(a, b::AbstractVector)
24-
size(a, 2) size(b, 1) &&
48+
axes(a, 2) axes(b, 1) &&
2549
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
26-
return Zeros{promote_type(eltype(a), eltype(b))}(size(a, 1))
50+
return Zeros{promote_type(eltype(a), eltype(b))}((axes(a, 1),))
2751
end
2852

29-
const ZerosVecOrMat{T} = Union{Zeros{T,1}, Zeros{T,2}}
30-
*(a::ZerosVecOrMat, b::AbstractMatrix) = mult_zeros(a, b)
31-
*(a::AbstractMatrix, b::ZerosVecOrMat) = mult_zeros(a, b)
32-
*(a::ZerosVecOrMat, b::AbstractVector) = mult_zeros(a, b)
33-
*(a::AbstractVector, b::ZerosVecOrMat) = mult_zeros(a, b)
34-
*(a::ZerosVecOrMat, b::ZerosVecOrMat) = mult_zeros(a, b)
53+
*(a::AbstractFill{<:Any,1}, b::AbstractFill{<:Any,2}) = mult_fill(a,b)
54+
*(a::AbstractFill{<:Any,2}, b::AbstractFill{<:Any,2}) = mult_fill(a,b)
55+
*(a::AbstractFill{<:Any,2}, b::AbstractFill{<:Any,1}) = mult_fill(a,b)
56+
57+
*(a::Ones{<:Any,1}, b::Ones{<:Any,2}) = mult_ones(a, b)
58+
*(a::Ones{<:Any,2}, b::Ones{<:Any,2}) = mult_ones(a, b)
59+
*(a::Ones{<:Any,2}, b::Ones{<:Any,1}) = mult_ones(a, b)
60+
61+
*(a::Zeros{<:Any,1}, b::Zeros{<:Any,2}) = mult_zeros(a, b)
62+
*(a::Zeros{<:Any,2}, b::Zeros{<:Any,2}) = mult_zeros(a, b)
63+
*(a::Zeros{<:Any,2}, b::Zeros{<:Any,1}) = mult_zeros(a, b)
64+
65+
*(a::Zeros{<:Any,1}, b::AbstractFill{<:Any,2}) = mult_zeros(a, b)
66+
*(a::Zeros{<:Any,2}, b::AbstractFill{<:Any,2}) = mult_zeros(a, b)
67+
*(a::Zeros{<:Any,2}, b::AbstractFill{<:Any,1}) = mult_zeros(a, b)
68+
*(a::AbstractFill{<:Any,1}, b::Zeros{<:Any,2}) = mult_zeros(a,b)
69+
*(a::AbstractFill{<:Any,2}, b::Zeros{<:Any,2}) = mult_zeros(a,b)
70+
*(a::AbstractFill{<:Any,2}, b::Zeros{<:Any,1}) = mult_zeros(a,b)
71+
72+
*(a::Zeros{<:Any,1}, b::AbstractMatrix) = mult_zeros(a, b)
73+
*(a::Zeros{<:Any,2}, b::AbstractMatrix) = mult_zeros(a, b)
74+
*(a::AbstractMatrix, b::Zeros{<:Any,1}) = mult_zeros(a, b)
75+
*(a::AbstractMatrix, b::Zeros{<:Any,2}) = mult_zeros(a, b)
76+
*(a::Zeros{<:Any,1}, b::AbstractVector) = mult_zeros(a, b)
77+
*(a::Zeros{<:Any,2}, b::AbstractVector) = mult_zeros(a, b)
78+
*(a::AbstractVector, b::Zeros{<:Any,2}) = mult_zeros(a, b)
79+
80+
*(a::Zeros{<:Any,1}, b::Diagonal) = mult_zeros(a, b)
81+
*(a::Zeros{<:Any,2}, b::Diagonal) = mult_zeros(a, b)
82+
*(a::Diagonal, b::Zeros{<:Any,1}) = mult_zeros(a, b)
83+
*(a::Diagonal, b::Zeros{<:Any,2}) = mult_zeros(a, b)
3584

3685

3786
function *(a::Adjoint{T, <:AbstractVector{T}}, b::Zeros{S, 1}) where {T, S}

src/fillbroadcast.jl

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,34 +28,49 @@ broadcasted(::DefaultArrayStyle, ::typeof(*), a::Zeros, b::Zeros) = _broadcasted
2828

2929
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Zeros, b::Ones) = _broadcasted_zeros(a, b)
3030
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Zeros, b::Fill) = _broadcasted_zeros(a, b)
31-
function broadcasted(::DefaultArrayStyle, ::typeof(*), a::Zeros, b::AbstractRange)
31+
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Zeros, b::AbstractRange) =
3232
return _broadcasted_zeros(a, b)
33-
end
34-
function broadcasted(::DefaultArrayStyle, ::typeof(*), a::Zeros, b::AbstractArray)
33+
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Zeros, b::AbstractArray) =
3534
return _broadcasted_zeros(a, b)
36-
end
3735

3836
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Ones, b::Zeros) = _broadcasted_zeros(a, b)
3937
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Fill, b::Zeros) = _broadcasted_zeros(a, b)
40-
function broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractRange, b::Zeros)
38+
broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractRange, b::Zeros) =
4139
return _broadcasted_zeros(a, b)
42-
end
43-
function broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractArray, b::Zeros)
40+
broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractArray, b::Zeros) =
4441
return _broadcasted_zeros(a, b)
45-
end
42+
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::Zeros, b::AbstractRange) =
43+
return _broadcasted_zeros(a, b)
44+
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange, b::Zeros) =
45+
return _broadcasted_zeros(a, b)
46+
4647

4748
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Ones, b::Ones) = _broadcasted_ones(a, b)
4849
broadcasted(::DefaultArrayStyle, ::typeof(/), a::Ones, b::Ones) = _broadcasted_ones(a, b)
4950
broadcasted(::DefaultArrayStyle, ::typeof(\), a::Ones, b::Ones) = _broadcasted_ones(a, b)
5051

52+
# special case due to missing converts for ranges
53+
_range_convert(::Type{AbstractVector{T}}, a::AbstractRange{T}) where T = a
54+
_range_convert(::Type{AbstractVector{T}}, a::AbstractUnitRange) where T = convert(T,first(a)):convert(T,last(a))
55+
_range_convert(::Type{AbstractVector{T}}, a::AbstractRange) where T = convert(T,first(a)):step(a):convert(T,last(a))
56+
57+
function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::Ones{T}, b::AbstractRange{V}) where {T,V}
58+
broadcast_shape(size(a), size(b)) # Check sizes are compatible.
59+
return _range_convert(AbstractVector{promote_type(T,V)}, b)
60+
end
61+
62+
function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange{V}, b::Ones{T}) where {T,V}
63+
broadcast_shape(size(a), size(b)) # Check sizes are compatible.
64+
return _range_convert(AbstractVector{promote_type(T,V)}, a)
65+
end
5166

5267

53-
function broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractFill, b::AbstractRange)
68+
function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractFill, b::AbstractRange)
5469
broadcast_shape(size(a), size(b)) # Check sizes are compatible.
5570
return broadcasted(*, getindex_value(a), b)
5671
end
5772

58-
function broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractRange, b::AbstractFill)
73+
function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange, b::AbstractFill)
5974
broadcast_shape(size(a), size(b)) # Check sizes are compatible.
6075
return broadcasted(*, a, getindex_value(b))
6176
end

test/runtests.jl

Lines changed: 81 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,6 @@ import FillArrays: AbstractFill, RectDiagonal
150150
end
151151
end
152152

153-
154153
@testset "RectDiagonal" begin
155154
data = 1:3
156155
expected_size = (5, 3)
@@ -189,7 +188,6 @@ end
189188
@test_throws ArgumentError mut[2, 1] = 9
190189
end
191190

192-
193191
# Check that all pair-wise combinations of + / - elements of As and Bs yield the correct
194192
# type, and produce numerically correct results.
195193
function test_addition_and_subtraction(As, Bs, Tout::Type)
@@ -313,7 +311,6 @@ end
313311

314312
end
315313

316-
317314
@testset "IndexStyle" begin
318315
@test IndexStyle(Zeros(5,5)) == IndexStyle(typeof(Zeros(5,5))) == IndexLinear()
319316
end
@@ -326,9 +323,9 @@ end
326323
@test Zeros(3, 4) * randn(4) == Zeros(3, 4) * Zeros(4) == Zeros(3)
327324
@test Zeros(3, 4) * Zeros(4, 5) === Zeros(3, 5)
328325

329-
@test [1,2,3]*Zeros(1) Zeros(3)
326+
@test_throws MethodError [1,2,3]*Zeros(1) # Not defined for [1,2,3]*[0] either
330327
@test [1,2,3]*Zeros(1,3) Zeros(3,3)
331-
@test_throws DimensionMismatch [1,2,3]*Zeros(3)
328+
@test_throws MethodError [1,2,3]*Zeros(3) # Not defined for [1,2,3]*[0,0,0] either
332329

333330
# Check multiplication by Adjoint vectors works as expected.
334331
@test randn(4, 3)' * Zeros(4) === Zeros(3)
@@ -503,6 +500,17 @@ end
503500
@test Fill(1,10) .- 1 Fill(1,10) .- Ref(1) Fill(1,10) .- Ref(1I)
504501
@test Fill([1 2; 3 4],10) .- Ref(1I) == Fill([0 2; 3 3],10)
505502
@test Ref(1I) .+ Fill([1 2; 3 4],10) == Fill([2 2; 3 5],10)
503+
504+
@testset "Special Ones" begin
505+
@test Ones{Int}(5) .* (1:5) (1:5) .* Ones{Int}(5) 1:5
506+
@test Ones(5) .* (1:5) (1:5) .* Ones(5) 1.0:5
507+
@test Ones{Int}(5) .* Ones{Int}(5) Ones{Int}(5)
508+
@test Ones{Int}(5,2) .* (1:5) == Array(Ones{Int}(5,2)) .* Array(1:5)
509+
@test (1:5) .* Ones{Int}(5,2) == Array(1:5) .* Array(Ones{Int}(5,2))
510+
@test_throws DimensionMismatch Ones{Int}(6) .* (1:5)
511+
@test_throws DimensionMismatch (1:5) .* Ones{Int}(6)
512+
@test_throws DimensionMismatch Ones{Int}(5) .* Ones{Int}(6)
513+
end
506514
end
507515

508516
@testset "Sub-arrays" begin
@@ -743,4 +751,71 @@ end
743751

744752
@test fill!(F,1) == F
745753
@test_throws ArgumentError fill!(F,2)
746-
end
754+
end
755+
756+
@testset "mult" begin
757+
@test Fill(2,10)*Fill(3,1,12) == Vector(Fill(2,10))*Matrix(Fill(3,1,12))
758+
@test Fill(2,10)*Fill(3,1,12) Fill(6,10,12)
759+
@test Fill(2,3,10)*Fill(3,10,12) Fill(6,3,12)
760+
@test Fill(2,3,10)*Fill(3,10) Fill(6,3)
761+
@test_throws DimensionMismatch Fill(2,10)*Fill(3,2,12)
762+
@test_throws DimensionMismatch Fill(2,3,10)*Fill(3,2,12)
763+
764+
@test Ones(10)*Fill(3,1,12) Fill(3.0,10,12)
765+
@test Ones(10,3)*Fill(3,3,12) Fill(3.0,10,12)
766+
@test Ones(10,3)*Fill(3,3) Fill(3.0,10)
767+
768+
@test Fill(2,10)*Ones(1,12) Fill(2.0,10,12)
769+
@test Fill(2,3,10)*Ones(10,12) Fill(2.0,3,12)
770+
@test Fill(2,3,10)*Ones(10) Fill(2.0,3)
771+
772+
@test Ones(10)*Ones(1,12) Ones(10,12)
773+
@test Ones(3,10)*Ones(10,12) Ones(3,12)
774+
@test Ones(3,10)*Ones(10) Ones(3)
775+
776+
@test Zeros(10)*Fill(3,1,12) Zeros(10,12)
777+
@test Zeros(10,3)*Fill(3,3,12) Zeros(10,12)
778+
@test Zeros(10,3)*Fill(3,3) Zeros(10)
779+
780+
@test Fill(2,10)* Zeros(1,12) Zeros(10,12)
781+
@test Fill(2,3,10)*Zeros(10,12) Zeros(3,12)
782+
@test Fill(2,3,10)*Zeros(10) Zeros(3)
783+
784+
@test Zeros(10)*Zeros(1,12) Zeros(10,12)
785+
@test Zeros(3,10)*Zeros(10,12) Zeros(3,12)
786+
@test Zeros(3,10)*Zeros(10) Zeros(3)
787+
788+
a = randn(3)
789+
A = randn(1,4)
790+
791+
@test Fill(2,3)*A == Vector(Fill(2,3))*A
792+
@test Fill(2,3,1)*A == Matrix(Fill(2,3,1))*A
793+
@test Fill(2,3,3)*a == Matrix(Fill(2,3,3))*a
794+
@test Ones(3)*A == Vector(Ones(3))*A
795+
@test Ones(3,1)*A == Matrix(Ones(3,1))*A
796+
@test Ones(3,3)*a == Matrix(Ones(3,3))*a
797+
@test Zeros(3)*A Zeros(3,4)
798+
@test Zeros(3,1)*A == Zeros(3,4)
799+
@test Zeros(3,3)*a == Zeros(3)
800+
801+
@test A*Fill(2,4) == A*Vector(Fill(2,4))
802+
@test A*Fill(2,4,1) == A*Matrix(Fill(2,4,1))
803+
@test a*Fill(2,1,3) == a*Matrix(Fill(2,1,3))
804+
@test A*Ones(4) == A*Vector(Ones(4))
805+
@test A*Ones(4,1) == A*Matrix(Ones(4,1))
806+
@test a*Ones(1,3) == a*Matrix(Ones(1,3))
807+
@test A*Zeros(4) Zeros(1)
808+
@test A*Zeros(4,1) Zeros(1,1)
809+
@test a*Zeros(1,3) Zeros(3,3)
810+
811+
D = Diagonal(randn(1))
812+
@test Zeros(1,1)*D Zeros(1,1)
813+
@test Zeros(1)*D Zeros(1,1)
814+
@test D*Zeros(1,1) Zeros(1,1)
815+
@test D*Zeros(1) Zeros(1)
816+
817+
E = Eye(5)
818+
@test E*(1:5) 1.0:5.0
819+
@test (1:5)'E == (1.0:5)'
820+
@test E*E E
821+
end

0 commit comments

Comments
 (0)