|  | 
|  | 1 | +module StructArraysStaticArraysExt | 
|  | 2 | + | 
|  | 3 | +using StructArrays | 
|  | 4 | +using StaticArrays: StaticArray, FieldArray, tuple_prod | 
|  | 5 | + | 
|  | 6 | +""" | 
|  | 7 | +    StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T} | 
|  | 8 | +
 | 
|  | 9 | +The `staticschema` of a `StaticArray` element type is the `staticschema` of the underlying `Tuple`. | 
|  | 10 | +```julia | 
|  | 11 | +julia> StructArrays.staticschema(SVector{2, Float64}) | 
|  | 12 | +Tuple{Float64, Float64} | 
|  | 13 | +``` | 
|  | 14 | +The one exception to this rule is `<:StaticArrays.FieldArray`, since `FieldArray` is based on a  | 
|  | 15 | +struct. In this case, `staticschema(<:FieldArray)` returns the `staticschema` for the struct  | 
|  | 16 | +which subtypes `FieldArray`.  | 
|  | 17 | +""" | 
|  | 18 | +@generated function StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T} | 
|  | 19 | +    return quote | 
|  | 20 | +        Base.@_inline_meta | 
|  | 21 | +        return NTuple{$(tuple_prod(S)), T} | 
|  | 22 | +    end | 
|  | 23 | +end | 
|  | 24 | +StructArrays.createinstance(::Type{T}, args...) where {T<:StaticArray} = T(args) | 
|  | 25 | +StructArrays.component(s::StaticArray, i) = getindex(s, i) | 
|  | 26 | + | 
|  | 27 | +# invoke general fallbacks for a `FieldArray` type. | 
|  | 28 | +@inline function StructArrays.staticschema(T::Type{<:FieldArray}) | 
|  | 29 | +    invoke(StructArrays.staticschema, Tuple{Type{<:Any}}, T) | 
|  | 30 | +end | 
|  | 31 | +StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i) | 
|  | 32 | +StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(StructArrays.createinstance, Tuple{Type{<:Any}, Vararg}, T, args...) | 
|  | 33 | + | 
|  | 34 | +# Broadcast overload | 
|  | 35 | +using StaticArrays: StaticArrayStyle, similar_type, Size, SOneTo | 
|  | 36 | +using StaticArrays: broadcast_flatten, broadcast_sizes, first_statictype | 
|  | 37 | +using StructArrays: isnonemptystructtype | 
|  | 38 | +using Base.Broadcast: Broadcasted, _broadcast_getindex | 
|  | 39 | + | 
|  | 40 | +# StaticArrayStyle has no similar defined. | 
|  | 41 | +# Overload `try_struct_copy` instead. | 
|  | 42 | +@inline function StructArrays.try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M} | 
|  | 43 | +    flat = broadcast_flatten(bc); as = flat.args; f = flat.f | 
|  | 44 | +    argsizes = broadcast_sizes(as...) | 
|  | 45 | +    ax = axes(bc) | 
|  | 46 | +    ax isa Tuple{Vararg{SOneTo}} || error("Dimension is not static. Please file a bug at `StaticArrays.jl`.") | 
|  | 47 | +    return _broadcast(f, Size(map(length, ax)), argsizes, as...) | 
|  | 48 | +end | 
|  | 49 | + | 
|  | 50 | +# A functor generates the ith component of StructStaticBroadcast. | 
|  | 51 | +struct Similar_ith{SA, E<:Tuple} | 
|  | 52 | +    elements::E | 
|  | 53 | +    Similar_ith{SA}(elements::Tuple) where {SA} = new{SA, typeof(elements)}(elements) | 
|  | 54 | +end | 
|  | 55 | +function (s::Similar_ith{SA})(i::Int) where {SA} | 
|  | 56 | +    ith_elements = ntuple(Val(length(s.elements))) do j | 
|  | 57 | +        getfield(s.elements[j], i) | 
|  | 58 | +    end | 
|  | 59 | +    ith_SA = similar_type(SA, fieldtype(eltype(SA), i)) | 
|  | 60 | +    return @inbounds ith_SA(ith_elements) | 
|  | 61 | +end | 
|  | 62 | + | 
|  | 63 | +@inline function _broadcast(f, sz::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where {newsize} | 
|  | 64 | +    first_staticarray = first_statictype(a...) | 
|  | 65 | +    elements, ET = if prod(newsize) == 0 | 
|  | 66 | +        # Use inference to get eltype in empty case (following StaticBroadcast defined in StaticArrays.jl) | 
|  | 67 | +        eltys = Tuple{map(eltype, a)...} | 
|  | 68 | +        (), Core.Compiler.return_type(f, eltys) | 
|  | 69 | +    else | 
|  | 70 | +        temp = __broadcast(f, sz, s, a...) | 
|  | 71 | +        temp, eltype(temp) | 
|  | 72 | +    end | 
|  | 73 | +    if isnonemptystructtype(ET) | 
|  | 74 | +        SA = similar_type(first_staticarray, ET, sz) | 
|  | 75 | +        arrs = ntuple(Similar_ith{SA}(elements), Val(fieldcount(ET))) | 
|  | 76 | +        return StructArray{ET}(arrs) | 
|  | 77 | +    else | 
|  | 78 | +        @inbounds return similar_type(first_staticarray, ET, sz)(elements) | 
|  | 79 | +    end | 
|  | 80 | +end | 
|  | 81 | + | 
|  | 82 | +# The `__broadcast` kernal is copied from `StaticArrays.jl`. | 
|  | 83 | +# see https://github.com/JuliaArrays/StaticArrays.jl/blob/master/src/broadcast.jl | 
|  | 84 | +@generated function __broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize | 
|  | 85 | +    sizes = [sz.parameters[1] for sz ∈ s.parameters] | 
|  | 86 | + | 
|  | 87 | +    indices = CartesianIndices(newsize) | 
|  | 88 | +    exprs = similar(indices, Expr) | 
|  | 89 | +    for (j, current_ind) ∈ enumerate(indices) | 
|  | 90 | +        exprs_vals = (broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes)) | 
|  | 91 | +        exprs[j] = :(f($(exprs_vals...))) | 
|  | 92 | +    end | 
|  | 93 | + | 
|  | 94 | +    return quote | 
|  | 95 | +        Base.@_inline_meta | 
|  | 96 | +        return tuple($(exprs...)) | 
|  | 97 | +    end | 
|  | 98 | +end | 
|  | 99 | + | 
|  | 100 | +broadcast_getindex(::Tuple{}, i::Int, I::CartesianIndex) = return :(_broadcast_getindex(a[$i], $I)) | 
|  | 101 | +function broadcast_getindex(oldsize::Tuple, i::Int, newindex::CartesianIndex) | 
|  | 102 | +    li = LinearIndices(oldsize) | 
|  | 103 | +    ind = _broadcast_getindex(li, newindex) | 
|  | 104 | +    return :(a[$i][$ind]) | 
|  | 105 | +end | 
|  | 106 | + | 
|  | 107 | +end | 
0 commit comments