Skip to content

Commit 3e2dcdb

Browse files
Merge pull request #293 from jlchan/jc/NamedArrayPartition
adding `NamedArrayPartition` type
2 parents 67e0be0 + d663c49 commit 3e2dcdb

File tree

5 files changed

+155
-2
lines changed

5 files changed

+155
-2
lines changed

docs/src/array_types.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@ mapping and iteration functions, and more.
1313
VectorOfArray
1414
DiffEqArray
1515
ArrayPartition
16+
NamedArrayPartition
1617
```

src/RecursiveArrayTools.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ include("utils.jl")
2121
include("vector_of_array.jl")
2222
include("tabletraits.jl")
2323
include("array_partition.jl")
24+
include("named_array_partition.jl")
2425

2526
function Base.show(io::IO, x::Union{ArrayPartition, AbstractVectorOfArray})
2627
invoke(show, Tuple{typeof(io), Any}, io, x)
@@ -37,6 +38,6 @@ export recursivecopy, recursivecopy!, recursivefill!, vecvecapply, copyat_or_pus
3738
vecvec_to_mat, recursive_one, recursive_mean, recursive_bottom_eltype,
3839
recursive_unitless_bottom_eltype, recursive_unitless_eltype
3940

40-
export ArrayPartition
41+
export ArrayPartition, NamedArrayPartition
4142

4243
end # module

src/named_array_partition.jl

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
"""
2+
NamedArrayPartition(; kwargs...)
3+
NamedArrayPartition(x::NamedTuple)
4+
5+
Similar to an `ArrayPartition` but the individual arrays can be accessed via the
6+
constructor-specified names. However, unlike `ArrayPartition`, each individual array
7+
must have the same element type.
8+
"""
9+
struct NamedArrayPartition{T, A<:ArrayPartition{T}, NT<:NamedTuple} <: AbstractVector{T}
10+
array_partition::A
11+
names_to_indices::NT
12+
end
13+
NamedArrayPartition(; kwargs...) = NamedArrayPartition(NamedTuple(kwargs))
14+
function NamedArrayPartition(x::NamedTuple)
15+
names_to_indices = NamedTuple(Pair(symbol, index) for (index, symbol) in enumerate(keys(x)))
16+
17+
# enforce homogeneity of eltypes
18+
@assert all(eltype.(values(x)) .== eltype(first(x)))
19+
T = eltype(first(x))
20+
S = typeof(values(x))
21+
return NamedArrayPartition(ArrayPartition{T, S}(values(x)), names_to_indices)
22+
end
23+
24+
# Note: overloading `getproperty` means we cannot access `NamedArrayPartition`
25+
# fields except through `getfield` and accessor functions.
26+
ArrayPartition(x::NamedArrayPartition) = getfield(x, :array_partition)
27+
28+
Base.Array(x::NamedArrayPartition) = Array(ArrayPartition(x))
29+
30+
Base.zero(x::NamedArrayPartition{T, S, TN}) where {T, S, TN} =
31+
NamedArrayPartition{T, S, TN}(zero(ArrayPartition(x)), getfield(x, :names_to_indices))
32+
Base.zero(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N} = zero(A) # ignore dims since named array partitions are vectors
33+
34+
35+
Base.propertynames(x::NamedArrayPartition) = propertynames(getfield(x, :names_to_indices))
36+
Base.getproperty(x::NamedArrayPartition, s::Symbol) =
37+
getindex(ArrayPartition(x).x, getproperty(getfield(x, :names_to_indices), s))
38+
39+
# this enables x.s = some_array.
40+
@inline function Base.setproperty!(x::NamedArrayPartition, s::Symbol, v)
41+
index = getproperty(getfield(x, :names_to_indices), s)
42+
ArrayPartition(x).x[index] .= v
43+
end
44+
45+
# print out NamedArrayPartition as a NamedTuple
46+
Base.summary(x::NamedArrayPartition) = string(typeof(x), " with arrays:")
47+
Base.show(io::IO, m::MIME"text/plain", x::NamedArrayPartition) =
48+
show(io, m, NamedTuple(Pair.(keys(getfield(x, :names_to_indices)), ArrayPartition(x).x)))
49+
50+
Base.size(x::NamedArrayPartition) = size(ArrayPartition(x))
51+
Base.length(x::NamedArrayPartition) = length(ArrayPartition(x))
52+
Base.getindex(x::NamedArrayPartition, args...) = getindex(ArrayPartition(x), args...)
53+
54+
Base.setindex!(x::NamedArrayPartition, args...) = setindex!(ArrayPartition(x), args...)
55+
Base.map(f, x::NamedArrayPartition) = NamedArrayPartition(map(f, ArrayPartition(x)), getfield(x, :names_to_indices))
56+
Base.mapreduce(f, op, x::NamedArrayPartition) = mapreduce(f, op, ArrayPartition(x))
57+
# Base.filter(f, x::NamedArrayPartition) = filter(f, ArrayPartition(x))
58+
59+
Base.similar(x::NamedArrayPartition{T, S, NT}) where {T, S, NT} =
60+
NamedArrayPartition{T, S, NT}(similar(ArrayPartition(x)), getfield(x, :names_to_indices))
61+
62+
# broadcasting
63+
Base.BroadcastStyle(::Type{<:NamedArrayPartition}) = Broadcast.ArrayStyle{NamedArrayPartition}()
64+
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}},
65+
::Type{ElType}) where {ElType}
66+
x = find_NamedArrayPartition(bc)
67+
return NamedArrayPartition(similar(ArrayPartition(x)), getfield(x, :names_to_indices))
68+
end
69+
70+
# when broadcasting with ArrayPartition + another array type, the output is the other array tupe
71+
Base.BroadcastStyle(::Broadcast.ArrayStyle{NamedArrayPartition}, ::Broadcast.DefaultArrayStyle{1}) =
72+
Broadcast.DefaultArrayStyle{1}()
73+
74+
# hook into ArrayPartition broadcasting routines
75+
@inline RecursiveArrayTools.npartitions(x::NamedArrayPartition) = npartitions(ArrayPartition(x))
76+
@inline RecursiveArrayTools.unpack(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}, i) =
77+
Broadcast.Broadcasted(bc.f, RecursiveArrayTools.unpack_args(i, bc.args))
78+
@inline RecursiveArrayTools.unpack(x::NamedArrayPartition, i) = unpack(ArrayPartition(x), i)
79+
80+
Base.copy(A::NamedArrayPartition{T,S,NT}) where {T,S,NT} =
81+
NamedArrayPartition{T,S,NT}(copy(ArrayPartition(A)), getfield(A, :names_to_indices))
82+
83+
@inline NamedArrayPartition(f::F, N, names_to_indices) where F<:Function =
84+
NamedArrayPartition(ArrayPartition(ntuple(f, Val(N))), names_to_indices)
85+
86+
@inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}})
87+
N = npartitions(bc)
88+
@inline function f(i)
89+
copy(unpack(bc, i))
90+
end
91+
x = find_NamedArrayPartition(bc)
92+
NamedArrayPartition(f, N, getfield(x, :names_to_indices))
93+
end
94+
95+
@inline function Base.copyto!(dest::NamedArrayPartition,
96+
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}})
97+
N = npartitions(dest, bc)
98+
@inline function f(i)
99+
copyto!(ArrayPartition(dest).x[i], unpack(bc, i))
100+
end
101+
ntuple(f, Val(N))
102+
return dest
103+
end
104+
105+
# `x = find_NamedArrayPartition(x)` returns the first `NamedArrayPartition` among broadcast arguments.
106+
find_NamedArrayPartition(bc::Base.Broadcast.Broadcasted) = find_NamedArrayPartition(bc.args)
107+
find_NamedArrayPartition(args::Tuple) =
108+
find_NamedArrayPartition(find_NamedArrayPartition(args[1]), Base.tail(args))
109+
find_NamedArrayPartition(x) = x
110+
find_NamedArrayPartition(::Tuple{}) = nothing
111+
find_NamedArrayPartition(x::NamedArrayPartition, rest) = x
112+
find_NamedArrayPartition(::Any, rest) = find_NamedArrayPartition(rest)
113+
114+
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using RecursiveArrayTools, Test
2+
3+
@testset "NamedArrayPartition tests" begin
4+
x = NamedArrayPartition(a = ones(10), b = rand(20))
5+
@test typeof(@. sin(x * x^2 / x - 1)) <: NamedArrayPartition
6+
@test typeof(x.^2) <: NamedArrayPartition
7+
@test x.a ones(10)
8+
@test typeof(x .+ x[1:end]) <: Vector # test broadcast precedence
9+
@test all(x .== x[1:end])
10+
y = copy(x)
11+
@test zero(x, (10, 20)) == zero(x) # test that ignoring dims works
12+
@test typeof(zero(x)) <: NamedArrayPartition
13+
@test (y .*= 2).a[1] 2 # test in-place bcast
14+
15+
@test length(Array(x))==30
16+
@test typeof(Array(x)) <: Array
17+
@test propertynames(x) == (:a, :b)
18+
19+
x = NamedArrayPartition(a = ones(1), b = 2*ones(1))
20+
@test Base.summary(x) == string(typeof(x), " with arrays:")
21+
io = IOBuffer()
22+
Base.show(io, MIME"text/plain"(), x)
23+
@test String(take!(io)) == "(a = [1.0], b = [2.0])"
24+
25+
using StructArrays
26+
using StaticArrays: SVector
27+
x = NamedArrayPartition(a = StructArray{SVector{2, Float64}}((ones(5), 2*ones(5))),
28+
b = StructArray{SVector{2, Float64}}((3 * ones(2,2), 4*ones(2,2))))
29+
@test typeof(x.a) <: StructVector{<:SVector{2}}
30+
@test typeof(x.b) <: StructArray{<:SVector{2}, 2}
31+
@test typeof((x->x[1]).(x)) <: NamedArrayPartition
32+
@test typeof(map(x->x[1], x)) <: NamedArrayPartition
33+
end
34+

test/runtests.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,12 @@ end
2525
@time @safetestset "Utils Tests" begin
2626
include("utils_test.jl")
2727
end
28+
@time @safetestset "NamedArrayPartition Tests" begin
29+
include("named_array_partition_tests.jl")
30+
end
2831
@time @safetestset "Partitions Tests" begin
2932
include("partitions_test.jl")
30-
end
33+
end
3134
@time @safetestset "VecOfArr Indexing Tests" begin
3235
include("basic_indexing.jl")
3336
end

0 commit comments

Comments
 (0)