Skip to content

Commit b2fea1d

Browse files
authored
Reductions (#14)
1 parent f08a672 commit b2fea1d

File tree

6 files changed

+74
-4
lines changed

6 files changed

+74
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Derive"
22
uuid = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.1"
4+
version = "0.3.2"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/abstractarrayinterface.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,35 @@ end
9292
return error("Not implemented.")
9393
end
9494

95+
@interface ::AbstractArrayInterface function Base.mapreduce(
96+
f, op, as::AbstractArray...; kwargs...
97+
)
98+
return error("Not implemented.")
99+
end
100+
101+
# TODO: Generalize to multiple inputs.
102+
@interface interface::AbstractInterface function Base.reduce(f, a::AbstractArray; kwargs...)
103+
return @interface interface mapreduce(identity, f, a; kwargs...)
104+
end
105+
106+
@interface interface::AbstractArrayInterface function Base.all(a::AbstractArray)
107+
return @interface interface reduce(&, a; init=true)
108+
end
109+
110+
@interface interface::AbstractArrayInterface function Base.all(
111+
f::Function, a::AbstractArray
112+
)
113+
return @interface interface mapreduce(f, &, a; init=true)
114+
end
115+
116+
@interface interface::AbstractArrayInterface function Base.iszero(a::AbstractArray)
117+
return @interface interface all(iszero, a)
118+
end
119+
120+
@interface interface::AbstractArrayInterface function Base.isreal(a::AbstractArray)
121+
return @interface interface all(isreal, a)
122+
end
123+
95124
@interface ::AbstractArrayInterface function Base.permutedims!(
96125
a_dest::AbstractArray, a_src::AbstractArray, perm
97126
)

src/interface_function.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ This errors for debugging, but probably should be defined as:
77
call(interface, f, args...) = f(args...)
88
```
99
=#
10-
call(interface, f, args...) = error("Not implemented")
10+
call(interface, f, args...; kwargs...) = error("Not implemented")
1111

1212
# Change the behavior of a function to use a certain interface.
1313
struct InterfaceFunction{Interface,F} <: Function
1414
interface::Interface
1515
f::F
1616
end
17-
(f::InterfaceFunction)(args...) = call(f.interface, f.f, args...)
17+
(f::InterfaceFunction)(args...; kwargs...) = call(f.interface, f.f, args...; kwargs...)

src/traits.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@ function derive(::Val{:AbstractArrayOps}, type)
1919
Base.copy(::$type)
2020
Base.map(::Any, ::$type...)
2121
Base.map!(::Any, ::AbstractArray, ::$type...)
22+
Base.mapreduce(::Any, ::Any, ::$type...; kwargs...)
23+
Base.reduce(::Any, ::$type...; kwargs...)
24+
Base.all(::Function, ::$type)
25+
Base.all(::$type)
26+
Base.iszero(::$type)
27+
Base.real(::$type)
2228
Base.permutedims!(::Any, ::$type, ::Any)
2329
Broadcast.BroadcastStyle(::Type{<:$type})
2430
ArrayLayouts.MemoryLayout(::Type{<:$type})

test/basics/SparseArrayDOKs.jl

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,25 @@ function setunstoredindex!(a::AbstractArray, value, I::CartesianIndex)
1010
return setunstoredindex!(a, value, Tuple(I)...)
1111
end
1212

13+
# A view of the stored values of an array.
14+
# Similar to: `@view a[collect(eachstoredindex(a))]`, but the issue
15+
# with that is it returns a `SubArray` wrapping a sparse array, which
16+
# is then interpreted as a sparse array. Also, that involves extra
17+
# logic for determining if the indices are stored or not, but we know
18+
# the indices are stored.
19+
struct StoredValues{T,A<:AbstractArray{T},I} <: AbstractVector{T}
20+
array::A
21+
storedindices::I
22+
end
23+
StoredValues(a::AbstractArray) = StoredValues(a, collect(eachstoredindex(a)))
24+
Base.size(a::StoredValues) = size(a.storedindices)
25+
Base.getindex(a::StoredValues, I::Int) = getstoredindex(a.array, a.storedindices[I])
26+
function Base.setindex!(a::StoredValues, value, I::Int)
27+
return setstoredindex!(a.array, value, a.storedindices[I])
28+
end
29+
30+
storedvalues(a::AbstractArray) = StoredValues(a)
31+
1332
using ArrayLayouts: ArrayLayouts, MatMulMatAdd, MemoryLayout
1433
using Derive: Derive, @array_aliases, @derive, @interface, AbstractArrayInterface, interface
1534
using LinearAlgebra: LinearAlgebra
@@ -29,8 +48,8 @@ end
2948
a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N}
3049
) where {N}
3150
checkbounds(a, I...)
32-
iszero(value) && return a
3351
if !isstored(a, I...)
52+
iszero(value) && return a
3453
setunstoredindex!(a, value, I...)
3554
return a
3655
end
@@ -67,6 +86,13 @@ end
6786
return a_dest
6887
end
6988

89+
@interface ::SparseArrayInterface function Base.mapreduce(
90+
f, op, a::AbstractArray; kwargs...
91+
)
92+
# TODO: Need to select a better `init`.
93+
return mapreduce(f, op, storedvalues(a); kwargs...)
94+
end
95+
7096
# ArrayLayouts functionality.
7197

7298
function ArrayLayouts.sub_materialize(::SparseLayout, a::AbstractArray, axes::Tuple)

test/basics/test_basics.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,13 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
7272
@test b isa SparseArrayDOK{elt,1}
7373
@test b == [12, 0]
7474
@test storedlength(b) == 1
75+
76+
a = SparseArrayDOK{elt}(2, 2)
77+
@test iszero(a)
78+
a[2, 1] = 21
79+
a[1, 2] = 12
80+
@test !iszero(a)
81+
@test isreal(a)
82+
@test sum(a) == 33
83+
@test mapreduce(x -> 2x, +, a) == 66
7584
end

0 commit comments

Comments
 (0)