Skip to content

Commit e3e171b

Browse files
authored
Overload reshape, Add special broadcasting for number and Zeros, add \ and / (#77)
* Add special broadcasting for number and Zeros, add \ and / * Overload reshape * v0.8
1 parent a406175 commit e3e171b

File tree

4 files changed

+61
-27
lines changed

4 files changed

+61
-27
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FillArrays"
22
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
3-
version = "0.7.4"
3+
version = "0.8"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/FillArrays.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using LinearAlgebra, SparseArrays
44
import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert,
55
+, -, *, /, \, diff, sum, cumsum, maximum, minimum, sort, sort!,
66
any, all, axes, isone, iterate, unique, allunique, permutedims, inv,
7-
copy, vec, setindex!, count, ==
7+
copy, vec, setindex!, count, ==, reshape, _throw_dmrs
88

99
import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!,
1010
norm2, norm1, normInf, normMinusInf, normp
@@ -159,7 +159,14 @@ end
159159
-(a::AbstractFill, b::AbstractRange) = a + (-b)
160160
-(a::AbstractRange, b::AbstractFill) = a + (-b)
161161

162+
function fill_reshape(parent, dims::Integer...)
163+
n = length(parent)
164+
prod(dims) == n || _throw_dmrs(n, "size", dims)
165+
Fill(getindex_value(parent), dims...)
166+
end
162167

168+
reshape(parent::AbstractFill, dims::Integer...) = fill_reshape(parent, dims...)
169+
reshape(parent::AbstractFill, dims::Int...) = fill_reshape(parent, dims...)
163170

164171
for (Typ, funcs, func) in ((:Zeros, :zeros, :zero), (:Ones, :ones, :one))
165172
@eval begin
@@ -211,6 +218,12 @@ for (Typ, funcs, func) in ((:Zeros, :zeros, :zero), (:Ones, :ones, :one))
211218
size(A) == size(kr) || throw(DimensionMismatch())
212219
$Typ{T}(count(kr))
213220
end
221+
222+
function fill_reshape(parent::$Typ{T}, dims::Integer...) where T
223+
n = length(parent)
224+
prod(dims) == n || _throw_dmrs(n, "size", dims)
225+
$Typ{T}(dims...)
226+
end
214227
end
215228
end
216229

@@ -464,4 +477,5 @@ include("fillbroadcast.jl")
464477
Base.replace_in_print_matrix(::Zeros, ::Integer, ::Integer, s::AbstractString) =
465478
Base.replace_with_centered_mark(s)
466479

480+
467481
end # module

src/fillbroadcast.jl

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,27 @@ broadcasted(::DefaultArrayStyle, ::typeof(+), a::Zeros, b::Ones) = _broadcasted_
2626

2727
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Zeros, b::Zeros) = _broadcasted_zeros(a, b)
2828

29-
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Zeros, b::Ones) = _broadcasted_zeros(a, b)
30-
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Zeros, b::Fill) = _broadcasted_zeros(a, b)
31-
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Zeros, b::AbstractRange) =
32-
return _broadcasted_zeros(a, b)
33-
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Zeros, b::AbstractArray) =
34-
return _broadcasted_zeros(a, b)
35-
36-
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Ones, b::Zeros) = _broadcasted_zeros(a, b)
37-
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Fill, b::Zeros) = _broadcasted_zeros(a, b)
38-
broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractRange, b::Zeros) =
39-
return _broadcasted_zeros(a, b)
40-
broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractArray, b::Zeros) =
41-
return _broadcasted_zeros(a, b)
42-
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::Zeros, b::AbstractRange) =
43-
return _broadcasted_zeros(a, b)
44-
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange, b::Zeros) =
45-
return _broadcasted_zeros(a, b)
29+
for op in (:*, :/)
30+
@eval begin
31+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Zeros, b::Ones) = _broadcasted_zeros(a, b)
32+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Zeros, b::Fill) = _broadcasted_zeros(a, b)
33+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Zeros, b::Number) = _broadcasted_zeros(a, b)
34+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Zeros, b::AbstractRange) = _broadcasted_zeros(a, b)
35+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Zeros, b::AbstractArray) = _broadcasted_zeros(a, b)
36+
broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::Zeros, b::AbstractRange) = _broadcasted_zeros(a, b)
37+
end
38+
end
39+
40+
for op in (:*, :\)
41+
@eval begin
42+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Ones, b::Zeros) = _broadcasted_zeros(a, b)
43+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Fill, b::Zeros) = _broadcasted_zeros(a, b)
44+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Number, b::Zeros) = _broadcasted_zeros(a, b)
45+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractRange, b::Zeros) = _broadcasted_zeros(a, b)
46+
broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractArray, b::Zeros) = _broadcasted_zeros(a, b)
47+
broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractRange, b::Zeros) = _broadcasted_zeros(a, b)
48+
end
49+
end
4650

4751

4852
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Ones, b::Ones) = _broadcasted_ones(a, b)

test/runtests.jl

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -499,15 +499,23 @@ end
499499
@test broadcast(*, rnge, Fill(5.0, 10)) == broadcast(*, rnge, 5.0)
500500
@test_throws DimensionMismatch broadcast(*, rnge, Fill(5.0, 11))
501501

502-
@test Zeros(5) .* Ones(5) == Zeros(5)
503-
@test Zeros(5) .* Fill(5.0, 5) == Zeros(5)
504-
@test Ones(5) .* Zeros(5) == Zeros(5)
505-
@test Fill(5.0, 5) .* Zeros(5) == Zeros(5)
502+
@testset "Special zeros" begin
503+
@test Zeros(5) .* Ones(5) Zeros(5) .* 1 Zeros(5)
504+
@test Zeros(5) .* Fill(5.0, 5) Zeros(5) .* 5.0 Zeros(5)
505+
@test Ones(5) .* Zeros(5) 1 .* Zeros(5) Zeros(5)
506+
@test Fill(5.0, 5) .* Zeros(5) 5.0 .* Zeros(5) Zeros(5)
507+
508+
@test Zeros(5) ./ Ones(5) Zeros(5) ./ 1 Zeros(5)
509+
@test Zeros(5) ./ Fill(5.0, 5) Zeros(5) ./ 5.0 Zeros(5)
510+
@test Ones(5) .\ Zeros(5) 1 .\ Zeros(5) Zeros(5)
511+
@test Fill(5.0, 5) .\ Zeros(5) 5.0 .\ Zeros(5) Zeros(5)
512+
end
506513

507-
# support Ref
508-
@test Fill(1,10) .- 1 Fill(1,10) .- Ref(1) Fill(1,10) .- Ref(1I)
509-
@test Fill([1 2; 3 4],10) .- Ref(1I) == Fill([0 2; 3 3],10)
510-
@test Ref(1I) .+ Fill([1 2; 3 4],10) == Fill([2 2; 3 5],10)
514+
@testset "support Ref" begin
515+
@test Fill(1,10) .- 1 Fill(1,10) .- Ref(1) Fill(1,10) .- Ref(1I)
516+
@test Fill([1 2; 3 4],10) .- Ref(1I) == Fill([0 2; 3 3],10)
517+
@test Ref(1I) .+ Fill([1 2; 3 4],10) == Fill([2 2; 3 5],10)
518+
end
511519

512520
@testset "Special Ones" begin
513521
@test Ones{Int}(5) .* (1:5) (1:5) .* Ones{Int}(5) 1:5
@@ -877,4 +885,12 @@ end
877885

878886
@testset "print" begin
879887
@test stringmime("text/plain", Zeros(3)) == "3-element Zeros{Float64,1,Tuple{Base.OneTo{$Int}}}:\n\n\n"
888+
end
889+
890+
@testset "reshape" begin
891+
@test reshape(Fill(2,6),2,3) Fill(2,2,3)
892+
@test reshape(Fill(2,6),big(2),3) == Fill(2,big(2),3)
893+
@test_throws DimensionMismatch reshape(Fill(2,6),2,4)
894+
@test reshape(Zeros(6),2,3) Zeros(2,3)
895+
@test reshape(Zeros(6),big(2),3) == Zeros(big(2),3)
880896
end

0 commit comments

Comments
 (0)