|
| 1 | +module StructArrays |
| 2 | + |
| 3 | +struct StructArray{T,N,C <: Tuple{Vararg{AbstractArray{<:Any,N}}}} <: AbstractArray{T,N} |
| 4 | + components :: C |
| 5 | + |
| 6 | + function StructArray{T,N,C}(components::C) where {T,N,C} |
| 7 | + fieldcount(T) == length(components) || throw(ArgumentError("number of components incompatible with eltype")) |
| 8 | + allequal(axes.(components)) || throw(ArgumentError("component arrays must have the same axes")) |
| 9 | + new{T,N,C}(components) |
| 10 | + end |
| 11 | +end |
| 12 | + |
| 13 | +function StructArray{T}(components::Tuple{Vararg{AbstractArray{<:Any,N}}}) where {T,N} |
| 14 | + StructArray{T,N,typeof(components)}(components) |
| 15 | +end |
| 16 | + |
| 17 | +Base.size(S::StructArray) = size(S.components[1]) |
| 18 | +Base.axes(S::StructArray) = axes(S.components[1]) |
| 19 | +function Base.getindex(S::StructArray{T,N}, inds::Vararg{Int,N}) where {T,N} |
| 20 | + vals = map(x -> x[inds...], S.components) |
| 21 | + T(vals...) |
| 22 | +end |
| 23 | +function Base.setindex!(S::StructArray{T,N}, val, inds::Vararg{Int,N}) where {T,N} |
| 24 | + vals = getfield.(Ref(convert(T, val)), fieldnames(T)) |
| 25 | + for (A,v) in zip(S.components, vals) |
| 26 | + A[inds...] = v |
| 27 | + end |
| 28 | + S |
| 29 | +end |
| 30 | + |
| 31 | +isnonemptystructtype(::Type{T}) where {T} = isstructtype(T) && fieldcount(T) != 0 |
| 32 | + |
| 33 | +function Base.similar(S::StructArray, ::Type{T}, dims::Tuple{Int, Vararg{Int}}) where {T} |
| 34 | + isnonemptystructtype(T) || return similar(S.components[1], T, dims) |
| 35 | + arrs = similar.(S.components, fieldtypes(T), Ref(dims)) |
| 36 | + StructArray{T}(arrs) |
| 37 | +end |
| 38 | + |
| 39 | +end |
0 commit comments