|
| 1 | +""" |
| 2 | + NamedArrayPartition(; kwargs...) |
| 3 | + NamedArrayPartition(x::NamedTuple) |
| 4 | +
|
| 5 | +Similar to an `ArrayPartition` but the individual arrays can be accessed via the |
| 6 | +constructor-specified names. However, unlike `ArrayPartition`, each individual array |
| 7 | +must have the same element type. |
| 8 | +""" |
| 9 | +struct NamedArrayPartition{T, A<:ArrayPartition{T}, NT<:NamedTuple} <: AbstractVector{T} |
| 10 | + array_partition::A |
| 11 | + names_to_indices::NT |
| 12 | +end |
| 13 | +NamedArrayPartition(; kwargs...) = NamedArrayPartition(NamedTuple(kwargs)) |
| 14 | +function NamedArrayPartition(x::NamedTuple) |
| 15 | + names_to_indices = NamedTuple(Pair(symbol, index) for (index, symbol) in enumerate(keys(x))) |
| 16 | + |
| 17 | + # enforce homogeneity of eltypes |
| 18 | + @assert all(eltype.(values(x)) .== eltype(first(x))) |
| 19 | + T = eltype(first(x)) |
| 20 | + S = typeof(values(x)) |
| 21 | + return NamedArrayPartition(ArrayPartition{T, S}(values(x)), names_to_indices) |
| 22 | +end |
| 23 | + |
| 24 | +# Note: overloading `getproperty` means we cannot access `NamedArrayPartition` |
| 25 | +# fields except through `getfield` and accessor functions. |
| 26 | +ArrayPartition(x::NamedArrayPartition) = getfield(x, :array_partition) |
| 27 | + |
| 28 | +Base.Array(x::NamedArrayPartition) = Array(ArrayPartition(x)) |
| 29 | + |
| 30 | +Base.zero(x::NamedArrayPartition{T, S, TN}) where {T, S, TN} = |
| 31 | + NamedArrayPartition{T, S, TN}(zero(ArrayPartition(x)), getfield(x, :names_to_indices)) |
| 32 | +Base.zero(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N} = zero(A) # ignore dims since named array partitions are vectors |
| 33 | + |
| 34 | + |
| 35 | +Base.propertynames(x::NamedArrayPartition) = propertynames(getfield(x, :names_to_indices)) |
| 36 | +Base.getproperty(x::NamedArrayPartition, s::Symbol) = |
| 37 | + getindex(ArrayPartition(x).x, getproperty(getfield(x, :names_to_indices), s)) |
| 38 | + |
| 39 | +# this enables x.s = some_array. |
| 40 | +@inline function Base.setproperty!(x::NamedArrayPartition, s::Symbol, v) |
| 41 | + index = getproperty(getfield(x, :names_to_indices), s) |
| 42 | + ArrayPartition(x).x[index] .= v |
| 43 | +end |
| 44 | + |
| 45 | +# print out NamedArrayPartition as a NamedTuple |
| 46 | +Base.summary(x::NamedArrayPartition) = string(typeof(x), " with arrays:") |
| 47 | +Base.show(io::IO, m::MIME"text/plain", x::NamedArrayPartition) = |
| 48 | + show(io, m, NamedTuple(Pair.(keys(getfield(x, :names_to_indices)), ArrayPartition(x).x))) |
| 49 | + |
| 50 | +Base.size(x::NamedArrayPartition) = size(ArrayPartition(x)) |
| 51 | +Base.length(x::NamedArrayPartition) = length(ArrayPartition(x)) |
| 52 | +Base.getindex(x::NamedArrayPartition, args...) = getindex(ArrayPartition(x), args...) |
| 53 | + |
| 54 | +Base.setindex!(x::NamedArrayPartition, args...) = setindex!(ArrayPartition(x), args...) |
| 55 | +Base.map(f, x::NamedArrayPartition) = NamedArrayPartition(map(f, ArrayPartition(x)), getfield(x, :names_to_indices)) |
| 56 | +Base.mapreduce(f, op, x::NamedArrayPartition) = mapreduce(f, op, ArrayPartition(x)) |
| 57 | +# Base.filter(f, x::NamedArrayPartition) = filter(f, ArrayPartition(x)) |
| 58 | + |
| 59 | +Base.similar(x::NamedArrayPartition{T, S, NT}) where {T, S, NT} = |
| 60 | + NamedArrayPartition{T, S, NT}(similar(ArrayPartition(x)), getfield(x, :names_to_indices)) |
| 61 | + |
| 62 | +# broadcasting |
| 63 | +Base.BroadcastStyle(::Type{<:NamedArrayPartition}) = Broadcast.ArrayStyle{NamedArrayPartition}() |
| 64 | +function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}, |
| 65 | + ::Type{ElType}) where {ElType} |
| 66 | + x = find_NamedArrayPartition(bc) |
| 67 | + return NamedArrayPartition(similar(ArrayPartition(x)), getfield(x, :names_to_indices)) |
| 68 | +end |
| 69 | + |
| 70 | +# when broadcasting with ArrayPartition + another array type, the output is the other array tupe |
| 71 | +Base.BroadcastStyle(::Broadcast.ArrayStyle{NamedArrayPartition}, ::Broadcast.DefaultArrayStyle{1}) = |
| 72 | + Broadcast.DefaultArrayStyle{1}() |
| 73 | + |
| 74 | +# hook into ArrayPartition broadcasting routines |
| 75 | +@inline RecursiveArrayTools.npartitions(x::NamedArrayPartition) = npartitions(ArrayPartition(x)) |
| 76 | +@inline RecursiveArrayTools.unpack(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}, i) = |
| 77 | + Broadcast.Broadcasted(bc.f, RecursiveArrayTools.unpack_args(i, bc.args)) |
| 78 | +@inline RecursiveArrayTools.unpack(x::NamedArrayPartition, i) = unpack(ArrayPartition(x), i) |
| 79 | + |
| 80 | +Base.copy(A::NamedArrayPartition{T,S,NT}) where {T,S,NT} = |
| 81 | + NamedArrayPartition{T,S,NT}(copy(ArrayPartition(A)), getfield(A, :names_to_indices)) |
| 82 | + |
| 83 | +@inline NamedArrayPartition(f::F, N, names_to_indices) where F<:Function = |
| 84 | + NamedArrayPartition(ArrayPartition(ntuple(f, Val(N))), names_to_indices) |
| 85 | + |
| 86 | +@inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}) |
| 87 | + N = npartitions(bc) |
| 88 | + @inline function f(i) |
| 89 | + copy(unpack(bc, i)) |
| 90 | + end |
| 91 | + x = find_NamedArrayPartition(bc) |
| 92 | + NamedArrayPartition(f, N, getfield(x, :names_to_indices)) |
| 93 | +end |
| 94 | + |
| 95 | +@inline function Base.copyto!(dest::NamedArrayPartition, |
| 96 | + bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}) |
| 97 | + N = npartitions(dest, bc) |
| 98 | + @inline function f(i) |
| 99 | + copyto!(ArrayPartition(dest).x[i], unpack(bc, i)) |
| 100 | + end |
| 101 | + ntuple(f, Val(N)) |
| 102 | + return dest |
| 103 | +end |
| 104 | + |
| 105 | +# `x = find_NamedArrayPartition(x)` returns the first `NamedArrayPartition` among broadcast arguments. |
| 106 | +find_NamedArrayPartition(bc::Base.Broadcast.Broadcasted) = find_NamedArrayPartition(bc.args) |
| 107 | +find_NamedArrayPartition(args::Tuple) = |
| 108 | + find_NamedArrayPartition(find_NamedArrayPartition(args[1]), Base.tail(args)) |
| 109 | +find_NamedArrayPartition(x) = x |
| 110 | +find_NamedArrayPartition(::Tuple{}) = nothing |
| 111 | +find_NamedArrayPartition(x::NamedArrayPartition, rest) = x |
| 112 | +find_NamedArrayPartition(::Any, rest) = find_NamedArrayPartition(rest) |
| 113 | + |
| 114 | + |
0 commit comments