Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/RecursiveArrayTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,6 @@ export recursivecopy, recursivecopy!, recursivefill!, vecvecapply, copyat_or_pus
vecvec_to_mat, recursive_one, recursive_mean, recursive_bottom_eltype,
recursive_unitless_bottom_eltype, recursive_unitless_eltype

export ArrayPartition, NamedArrayPartition
export ArrayPartition, NamedArrayPartition, AbstractNamedArrayPartition

end # module
132 changes: 76 additions & 56 deletions src/named_array_partition.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
abstract type AbstractNamedArrayPartition{T, A, NT} <: AbstractVector{T} end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should have an AbstractArrayPartition, and that needs a docstring to describe its interface?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or is ArrayPartition an AbstractNamedArrayPartition with x being its only names?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can add a higher level called AbstractArrayPartition{T} if you want and apply the same level of abstraction to the ArrayPartition interface as well (either or both) such that the same subtyping can be performed there also. I will add to the documentation, a docstring about how you can subtype from the AbstractNamedArrayPartition


"""
NamedArrayPartition(; kwargs...)
NamedArrayPartition(x::NamedTuple)
Expand All @@ -6,137 +8,155 @@ Similar to an `ArrayPartition` but the individual arrays can be accessed via the
constructor-specified names. However, unlike `ArrayPartition`, each individual array
must have the same element type.
"""
struct NamedArrayPartition{T, A <: ArrayPartition{T}, NT <: NamedTuple} <: AbstractVector{T}
struct NamedArrayPartition{T, A <: ArrayPartition{T}, NT <: NamedTuple} <: AbstractNamedArrayPartition{T, A, NT}
array_partition::A
names_to_indices::NT
end
NamedArrayPartition(; kwargs...) = NamedArrayPartition(NamedTuple(kwargs))
function NamedArrayPartition(x::NamedTuple)
(::Type{T})(; kwargs...) where {T<:AbstractNamedArrayPartition} = T(NamedTuple(kwargs))
function (::Type{T})(x::NamedTuple) where {T<:AbstractNamedArrayPartition}
names_to_indices = NamedTuple(Pair(symbol, index)
for (index, symbol) in enumerate(keys(x)))

# enforce homogeneity of eltypes
@assert all(eltype.(values(x)) .== eltype(first(x)))
T = eltype(first(x))
R = eltype(first(x))
S = typeof(values(x))
return NamedArrayPartition(ArrayPartition{T, S}(values(x)), names_to_indices)
return T(ArrayPartition{R, S}(values(x)), names_to_indices)
end

function named_partition_constructor(X::T) where {T<:AbstractNamedArrayPartition}
getfield(parentmodule(T), nameof(T))
end

# Note: overloading `getproperty` means we cannot access `NamedArrayPartition`
# fields except through `getfield` and accessor functions.
ArrayPartition(x::NamedArrayPartition) = getfield(x, :array_partition)
ArrayPartition(x::AbstractNamedArrayPartition) = getfield(x, :array_partition)

function Base.similar(A::NamedArrayPartition)
NamedArrayPartition(
# With new type structure this function does the same as Base.similar(x::AbstractNamedArrayPartition{T, S, NT}) where {T, S, NT}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this one commented?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because I couldn't manage in either the old or the new method to call the less specific function as providing an NamedArrayPartition will have the type structure X{N, A, NT} and therefore a later version would have instead been called. In the original code the two functions are:

function Base.similar(A::NamedArrayPartition)
    NamedArrayPartition(
        similar(getfield(A, :array_partition)), getfield(A, :names_to_indices))
end

function Base.similar(x::NamedArrayPartition{T, S, NT}) where {T, S, NT}
    NamedArrayPartition{T, S, NT}(
        similar(ArrayPartition(x)), getfield(x, :names_to_indices))
end

But in all cases I tested the second always overwrites the first as any NamedArrayPartition has the structure of the second

#= function Base.similar(A::T) where {T<:AbstractNamedArrayPartition}
Tconstr = named_partition_constructor(A)
Tconstr(
similar(getfield(A, :array_partition)), getfield(A, :names_to_indices))
end
end =#

# return ArrayPartition when possible, otherwise next best thing of the correct size
function Base.similar(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N}
NamedArrayPartition(
function Base.similar(A::T, dims::NTuple{N, Int}) where {T<:AbstractNamedArrayPartition, N}
Tconstr = named_partition_constructor(A)
Tconstr(
similar(getfield(A, :array_partition), dims), getfield(A, :names_to_indices))
end

# similar array partition of common type
@inline function Base.similar(A::NamedArrayPartition, ::Type{T}) where {T}
NamedArrayPartition(
@inline function Base.similar(A::S, ::Type{T}) where {S<:AbstractNamedArrayPartition, T}
Tconstr = named_partition_constructor(A)
Tconstr(
similar(getfield(A, :array_partition), T), getfield(A, :names_to_indices))
end

# return ArrayPartition when possible, otherwise next best thing of the correct size
function Base.similar(A::NamedArrayPartition, ::Type{T}, dims::NTuple{N, Int}) where {T, N}
NamedArrayPartition(
function Base.similar(A::S, ::Type{T}, dims::NTuple{N, Int}) where {T, N, S<:AbstractNamedArrayPartition}
Tconstr = named_partition_constructor(A)
Tconstr(
similar(getfield(A, :array_partition), T, dims), getfield(A, :names_to_indices))
end

# similar array partition with different types
function Base.similar(
A::NamedArrayPartition, ::Type{T}, ::Type{S}, R::DataType...) where {T, S}
NamedArrayPartition(
A::U, ::Type{T}, ::Type{S}, R::DataType...) where {T, S, U<:AbstractNamedArrayPartition}
Tconstr = named_partition_constructor(A)
Tconstr(
similar(getfield(A, :array_partition), T, S, R), getfield(A, :names_to_indices))
end

Base.Array(x::NamedArrayPartition) = Array(ArrayPartition(x))
Base.Array(x::AbstractNamedArrayPartition) = Array(ArrayPartition(x))

function Base.zero(x::NamedArrayPartition{T, S, TN}) where {T, S, TN}
NamedArrayPartition{T, S, TN}(zero(ArrayPartition(x)), getfield(x, :names_to_indices))
function Base.zero(x::R) where {R <: AbstractNamedArrayPartition}
R(zero(ArrayPartition(x)), getfield(x, :names_to_indices))
end
Base.zero(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N} = zero(A) # ignore dims since named array partitions are vectors
Base.zero(A::AbstractNamedArrayPartition, dims::NTuple{N, Int}) where {N} = zero(A) # ignore dims since named array partitions are vectors

Base.propertynames(x::NamedArrayPartition) = propertynames(getfield(x, :names_to_indices))
function Base.getproperty(x::NamedArrayPartition, s::Symbol)
Base.propertynames(x::AbstractNamedArrayPartition) = propertynames(getfield(x, :names_to_indices))
function Base.getproperty(x::AbstractNamedArrayPartition, s::Symbol)
getindex(ArrayPartition(x).x, getproperty(getfield(x, :names_to_indices), s))
end

# this enables x.s = some_array.
@inline function Base.setproperty!(x::NamedArrayPartition, s::Symbol, v)
@inline function Base.setproperty!(x::AbstractNamedArrayPartition, s::Symbol, v)
index = getproperty(getfield(x, :names_to_indices), s)
ArrayPartition(x).x[index] .= v
end

# print out NamedArrayPartition as a NamedTuple
Base.summary(x::NamedArrayPartition) = string(typeof(x), " with arrays:")
function Base.show(io::IO, m::MIME"text/plain", x::NamedArrayPartition)
Base.summary(x::AbstractNamedArrayPartition) = string(typeof(x), " with arrays:")
function Base.show(io::IO, m::MIME"text/plain", x::AbstractNamedArrayPartition)
show(
io, m, NamedTuple(Pair.(keys(getfield(x, :names_to_indices)), ArrayPartition(x).x)))
end

Base.size(x::NamedArrayPartition) = size(ArrayPartition(x))
Base.length(x::NamedArrayPartition) = length(ArrayPartition(x))
Base.getindex(x::NamedArrayPartition, args...) = getindex(ArrayPartition(x), args...)
Base.size(x::AbstractNamedArrayPartition) = size(ArrayPartition(x))
Base.length(x::AbstractNamedArrayPartition) = length(ArrayPartition(x))
Base.getindex(x::AbstractNamedArrayPartition, args...) = getindex(ArrayPartition(x), args...)

Base.setindex!(x::NamedArrayPartition, args...) = setindex!(ArrayPartition(x), args...)
function Base.map(f, x::NamedArrayPartition)
NamedArrayPartition(map(f, ArrayPartition(x)), getfield(x, :names_to_indices))
Base.setindex!(x::AbstractNamedArrayPartition, args...) = setindex!(ArrayPartition(x), args...)
function Base.map(f, x::T) where {T<:AbstractNamedArrayPartition}
Tconstr = named_partition_constructor(x)
Tconstr(map(f, ArrayPartition(x)), getfield(x, :names_to_indices))
end
Base.mapreduce(f, op, x::NamedArrayPartition) = mapreduce(f, op, ArrayPartition(x))
# Base.filter(f, x::NamedArrayPartition) = filter(f, ArrayPartition(x))
Base.mapreduce(f, op, x::AbstractNamedArrayPartition) = mapreduce(f, op, ArrayPartition(x))
# Base.filter(f, x::AbstractNamedArrayPartition) = filter(f, ArrayPartition(x))

function Base.similar(x::NamedArrayPartition{T, S, NT}) where {T, S, NT}
NamedArrayPartition{T, S, NT}(
similar(ArrayPartition(x)), getfield(x, :names_to_indices))
end
function Base.similar(x::AbstractNamedArrayPartition{T, A, NT}) where {T, A, NT}
# Safely extract the concrete type parameters

Tconstr = named_partition_constructor(x)
return Tconstr{T, A, NT}(
similar(getfield(x, :array_partition)),
getfield(x, :names_to_indices)
)
end
# broadcasting
function Base.BroadcastStyle(::Type{<:NamedArrayPartition})
Broadcast.ArrayStyle{NamedArrayPartition}()
function Base.BroadcastStyle(::Type{T}) where{T<:AbstractNamedArrayPartition}
Broadcast.ArrayStyle{T}()
end
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}},
::Type{ElType}) where {ElType}
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{T}},
::Type{ElType}) where {ElType, T<:AbstractNamedArrayPartition}
x = find_NamedArrayPartition(bc)
return NamedArrayPartition(similar(ArrayPartition(x)), getfield(x, :names_to_indices))
Tconstr = named_partition_constructor(x)
return Tconstr(similar(ArrayPartition(x)), getfield(x, :names_to_indices))
end

# when broadcasting with ArrayPartition + another array type, the output is the other array tupe
function Base.BroadcastStyle(
::Broadcast.ArrayStyle{NamedArrayPartition}, ::Broadcast.DefaultArrayStyle{1})
::Broadcast.ArrayStyle{<:AbstractNamedArrayPartition}, ::Broadcast.DefaultArrayStyle{1})
Broadcast.DefaultArrayStyle{1}()
end

# hook into ArrayPartition broadcasting routines
@inline RecursiveArrayTools.npartitions(x::NamedArrayPartition) = npartitions(ArrayPartition(x))
@inline RecursiveArrayTools.unpack(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}, i) = Broadcast.Broadcasted(
@inline RecursiveArrayTools.npartitions(x::AbstractNamedArrayPartition) = npartitions(ArrayPartition(x))
@inline RecursiveArrayTools.unpack(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{<:AbstractNamedArrayPartition}}, i) = Broadcast.Broadcasted(
bc.f, RecursiveArrayTools.unpack_args(i, bc.args))
@inline RecursiveArrayTools.unpack(x::NamedArrayPartition, i) = unpack(ArrayPartition(x), i)
@inline RecursiveArrayTools.unpack(x::AbstractNamedArrayPartition, i) = unpack(ArrayPartition(x), i)

function Base.copy(A::NamedArrayPartition{T, S, NT}) where {T, S, NT}
NamedArrayPartition{T, S, NT}(copy(ArrayPartition(A)), getfield(A, :names_to_indices))
function Base.copy(A::AbstractNamedArrayPartition{T, S, NT}) where {T, S, NT}
Tconstr = named_partition_constructor(A)
Tconstr{T, S, NT}(copy(ArrayPartition(A)), getfield(A, :names_to_indices))
end

@inline NamedArrayPartition(f::F, N, names_to_indices) where {F <: Function} = NamedArrayPartition(
@inline (::Type{T})(f::F, N, names_to_indices) where {F <: Function, T<:AbstractNamedArrayPartition} = T(
ArrayPartition(ntuple(f, Val(N))), names_to_indices)

@inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}})
@inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{T}}) where {T<:AbstractNamedArrayPartition}
N = npartitions(bc)
@inline function f(i)
copy(unpack(bc, i))
end
x = find_NamedArrayPartition(bc)
NamedArrayPartition(f, N, getfield(x, :names_to_indices))
Tconstr = named_partition_constructor(x)
Tconstr(f, N, getfield(x, :names_to_indices))
end

@inline function Base.copyto!(dest::NamedArrayPartition,
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}})
@inline function Base.copyto!(dest::AbstractNamedArrayPartition,
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{<:AbstractNamedArrayPartition}})
N = npartitions(dest, bc)
@inline function f(i)
copyto!(ArrayPartition(dest).x[i], unpack(bc, i))
Expand All @@ -146,7 +166,7 @@ end
end

#Overwrite ArrayInterface zeromatrix to work with NamedArrayPartitions & implicit solvers within OrdinaryDiffEq
function ArrayInterface.zeromatrix(A::NamedArrayPartition)
function ArrayInterface.zeromatrix(A::AbstractNamedArrayPartition)
B = ArrayPartition(A)
x = reduce(vcat,vec.(B.x))
x .* x' .* false
Expand All @@ -159,5 +179,5 @@ function find_NamedArrayPartition(args::Tuple)
end
find_NamedArrayPartition(x) = x
find_NamedArrayPartition(::Tuple{}) = nothing
find_NamedArrayPartition(x::NamedArrayPartition, rest) = x
find_NamedArrayPartition(x::AbstractNamedArrayPartition, rest) = x
find_NamedArrayPartition(::Any, rest) = find_NamedArrayPartition(rest)
2 changes: 1 addition & 1 deletion test/named_array_partition_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using RecursiveArrayTools, Test
using RecursiveArrayTools, Test, ArrayInterface

@testset "NamedArrayPartition tests" begin
x = NamedArrayPartition(a = ones(10), b = rand(20))
Expand Down
Loading