Skip to content

Commit f30e57d

Browse files
authored
Fill, zero, etc. (#17)
1 parent 61786d0 commit f30e57d

File tree

6 files changed

+101
-5
lines changed

6 files changed

+101
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Derive"
22
uuid = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.3"
4+
version = "0.3.4"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/abstractarrayinterface.jl

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,20 @@ using BroadcastMapConversion: map_function, map_args
6767
# TODO: Look into `SparseArrays.capturescalars`:
6868
# https://github.com/JuliaSparse/SparseArrays.jl/blob/1beb0e4a4618b0399907b0000c43d9f66d34accc/src/higherorderfns.jl#L1092-L1102
6969
@interface interface::AbstractArrayInterface function Base.copyto!(
70-
dest::AbstractArray, bc::Broadcast.Broadcasted
70+
a_dest::AbstractArray, bc::Broadcast.Broadcasted
7171
)
72-
return @interface interface map!(map_function(bc), dest, map_args(bc)...)
72+
return @interface interface map!(map_function(bc), a_dest, map_args(bc)...)
73+
end
74+
75+
# This captures broadcast expressions such as `a .= 2`.
76+
# Ideally this would be handled by `map!(f, a_dest)` but that isn't defined yet:
77+
# https://github.com/JuliaLang/julia/issues/31677
78+
# https://github.com/JuliaLang/julia/pull/40632
79+
@interface interface::AbstractArrayInterface function Base.copyto!(
80+
a_dest::AbstractArray, bc::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}}
81+
)
82+
isempty(map_args(bc)) || error("Bad broadcast expression.")
83+
return @interface interface map!(map_function(bc), a_dest, a_dest)
7384
end
7485

7586
# This is defined in this way so we can rely on the Broadcast logic
@@ -86,11 +97,41 @@ end
8697
# `invoke(Base.map!, Tuple{Any,AbstractArray,Vararg{AbstractArray}}, f, dest, as...)`.
8798
# TODO: Use `MethodError`?
8899
@interface ::AbstractArrayInterface function Base.map!(
89-
f, dest::AbstractArray, as::AbstractArray...
100+
f, a_dest::AbstractArray, a_srcs::AbstractArray...
90101
)
91102
return error("Not implemented.")
92103
end
93104

105+
@interface interface::AbstractArrayInterface function Base.fill!(a::AbstractArray, value)
106+
@interface interface map!(Returns(value), a, a)
107+
end
108+
109+
using ArrayLayouts: zero!
110+
111+
# `zero!` isn't defined in `Base`, but it is defined in `ArrayLayouts`
112+
# and is useful for sparse array logic, since it can be used to empty
113+
# the sparse array storage.
114+
# We use a single function definition to minimize method ambiguities.
115+
@interface interface::AbstractArrayInterface function ArrayLayouts.zero!(a::AbstractArray)
116+
# More generally, the first codepath could be taking if `zero(eltype(a))`
117+
# is defined and the elements are immutable.
118+
f = eltype(a) <: Number ? Returns(zero(eltype(a))) : zero!
119+
return @interface interface map!(f, a, a)
120+
end
121+
122+
# Specialized version of `Base.zero` written in terms of `ArrayLayouts.zero!`.
123+
# This is friendlier for sparse arrays since `ArrayLayouts.zero!` makes it easier
124+
# to handle the logic of dropping all elements of the sparse array when possible.
125+
# We use a single function definition to minimize method ambiguities.
126+
@interface interface::AbstractArrayInterface function Base.zero(a::AbstractArray)
127+
# More generally, the first codepath could be taking if `zero(eltype(a))`
128+
# is defined and the elements are immutable.
129+
if eltype(a) <: Number
130+
return @interface interface zero!(similar(a))
131+
end
132+
return @interface interface map(interface(zero), a)
133+
end
134+
94135
@interface ::AbstractArrayInterface function Base.mapreduce(
95136
f, op, as::AbstractArray...; kwargs...
96137
)

src/traits.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,12 @@ function derive(::Val{:AbstractArrayOps}, type)
2727
Base.all(::$type)
2828
Base.iszero(::$type)
2929
Base.real(::$type)
30+
Base.fill!(::$type, ::Any)
31+
ArrayLayouts.zero!(::$type)
32+
Base.zero(::$type)
3033
Base.permutedims!(::Any, ::$type, ::Any)
3134
Broadcast.BroadcastStyle(::Type{<:$type})
35+
Base.copyto!(::$type, ::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}})
3236
ArrayLayouts.MemoryLayout(::Type{<:$type})
3337
LinearAlgebra.mul!(::AbstractMatrix, ::$type, ::$type, ::Number, ::Number)
3438
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
33
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
4+
BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2"
45
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
56
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
67
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"

test/basics/SparseArrayDOKs.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,24 @@ end
8080
@interface ::SparseArrayInterface function Base.map!(
8181
f, a_dest::AbstractArray, as::AbstractArray...
8282
)
83+
# TODO: Define a function `preserves_unstored(a_dest, f, as...)`
84+
# to determine if a function preserves the stored values
85+
# of the destination sparse array.
86+
# The current code may be inefficient since it actually
87+
# accesses an unstored element, which in the case of a
88+
# sparse array of arrays can allocate an array.
89+
# Sparse arrays could be expected to define a cheap
90+
# unstored element allocator, for example
91+
# `get_prototypical_unstored(a::AbstractArray)`.
92+
I = first(eachindex(as...))
93+
preserves_unstored = iszero(f(map(a -> getunstoredindex(a, I), as)...))
94+
if !preserves_unstored
95+
# Doesn't preserve unstored values, loop over all elements.
96+
for I in eachindex(as...)
97+
a_dest[I] = map(f, map(a -> a[I], as)...)
98+
end
99+
end
100+
# TODO: Define `eachstoredindex(as...)`.
83101
for I in union(eachstoredindex.(as)...)
84102
a_dest[I] = map(f, map(a -> a[I], as)...)
85103
end
@@ -230,6 +248,11 @@ end
230248
eachstoredindex(a::SparseArrayDOK) = keys(storage(a))
231249
storedlength(a::SparseArrayDOK) = length(eachstoredindex(a))
232250

251+
function ArrayLayouts.zero!(a::SparseArrayDOK)
252+
empty!(storage(a))
253+
return a
254+
end
255+
233256
# Specify the interface the type adheres to.
234257
Derive.interface(::Type{<:SparseArrayDOK}) = SparseArrayInterface()
235258

test/basics/test_basics.jl

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
using Test: @test, @testset
1+
using ArrayLayouts: zero!
22
include("SparseArrayDOKs.jl")
33
using .SparseArrayDOKs: SparseArrayDOK, storedlength
4+
using Test: @test, @testset
45

56
elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
67
@testset "Derive" for elt in elts
@@ -89,4 +90,30 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
8990
@test b == a
9091
@test b[1, 2] == 12
9192
@test storedlength(b) == 1
93+
94+
a = SparseArrayDOK{elt}(2, 2)
95+
a .= 2
96+
@test storedlength(a) == length(a)
97+
for I in eachindex(a)
98+
@test a[I] == 2
99+
end
100+
101+
a = SparseArrayDOK{elt}(2, 2)
102+
fill!(a, 2)
103+
@test storedlength(a) == length(a)
104+
for I in eachindex(a)
105+
@test a[I] == 2
106+
end
107+
108+
a = SparseArrayDOK{elt}(2, 2)
109+
a[1, 2] = 12
110+
zero!(a)
111+
@test iszero(a)
112+
@test iszero(storedlength(a))
113+
114+
a = SparseArrayDOK{elt}(2, 2)
115+
a[1, 2] = 12
116+
b = zero(a)
117+
@test iszero(b)
118+
@test iszero(storedlength(b))
92119
end

0 commit comments

Comments
 (0)