- 
                Notifications
    You must be signed in to change notification settings 
- Fork 44
          Move StaticArrays support to extension
          #265
        
          New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c81c308
              7171bab
              b0d509b
              58a54f9
              bc629da
              0542152
              02257b9
              b4de96b
              60a8c8c
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| module StructArraysAdaptExt | ||
| # Use Adapt allows for automatic conversion of CPU to GPU StructArrays | ||
| using Adapt, StructArrays | ||
| Adapt.adapt_structure(to, s::StructArray) = replace_storage(adapt(to), s) | ||
|         
                  N5N3 marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| end | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| module StructArraysGPUArraysCoreExt | ||
|  | ||
| using StructArrays | ||
| using StructArrays: map_params, array_types | ||
|  | ||
| using Base: tail | ||
|  | ||
| import GPUArraysCore | ||
|  | ||
| # for GPU broadcast | ||
| import GPUArraysCore | ||
| function GPUArraysCore.backend(::Type{T}) where {T<:StructArray} | ||
| backends = map_params(GPUArraysCore.backend, array_types(T)) | ||
| backend, others = backends[1], tail(backends) | ||
| isconsistent = mapfoldl(isequal(backend), &, others; init=true) | ||
| isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend")) | ||
| return backend | ||
| end | ||
| StructArrays.always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true | ||
|  | ||
| end # module | 
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,107 @@ | ||
| module StructArraysStaticArraysExt | ||
|  | ||
| using StructArrays | ||
| using StaticArrays: StaticArray, FieldArray, tuple_prod | ||
|  | ||
| """ | ||
| StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T} | ||
|  | ||
| The `staticschema` of a `StaticArray` element type is the `staticschema` of the underlying `Tuple`. | ||
| ```julia | ||
| julia> StructArrays.staticschema(SVector{2, Float64}) | ||
| Tuple{Float64, Float64} | ||
| ``` | ||
| The one exception to this rule is `<:StaticArrays.FieldArray`, since `FieldArray` is based on a | ||
| struct. In this case, `staticschema(<:FieldArray)` returns the `staticschema` for the struct | ||
| which subtypes `FieldArray`. | ||
| """ | ||
| @generated function StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T} | ||
| return quote | ||
| Base.@_inline_meta | ||
| return NTuple{$(tuple_prod(S)), T} | ||
| end | ||
| end | ||
| StructArrays.createinstance(::Type{T}, args...) where {T<:StaticArray} = T(args) | ||
| StructArrays.component(s::StaticArray, i) = getindex(s, i) | ||
|  | ||
| # invoke general fallbacks for a `FieldArray` type. | ||
| @inline function StructArrays.staticschema(T::Type{<:FieldArray}) | ||
| invoke(StructArrays.staticschema, Tuple{Type{<:Any}}, T) | ||
| end | ||
| StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i) | ||
| StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(StructArrays.createinstance, Tuple{Type{<:Any}, Vararg}, T, args...) | ||
|  | ||
| # Broadcast overload | ||
| using StaticArrays: StaticArrayStyle, similar_type, Size, SOneTo | ||
| using StaticArrays: broadcast_flatten, broadcast_sizes, first_statictype | ||
| using StructArrays: isnonemptystructtype | ||
| using Base.Broadcast: Broadcasted, _broadcast_getindex | ||
|  | ||
| # StaticArrayStyle has no similar defined. | ||
| # Overload `try_struct_copy` instead. | ||
| @inline function StructArrays.try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M} | ||
| flat = broadcast_flatten(bc); as = flat.args; f = flat.f | ||
| argsizes = broadcast_sizes(as...) | ||
| ax = axes(bc) | ||
| ax isa Tuple{Vararg{SOneTo}} || error("Dimension is not static. Please file a bug at `StaticArrays.jl`.") | ||
| return _broadcast(f, Size(map(length, ax)), argsizes, as...) | ||
| end | ||
|  | ||
| # A functor generates the ith component of StructStaticBroadcast. | ||
| struct Similar_ith{SA, E<:Tuple} | ||
| elements::E | ||
| Similar_ith{SA}(elements::Tuple) where {SA} = new{SA, typeof(elements)}(elements) | ||
| end | ||
| function (s::Similar_ith{SA})(i::Int) where {SA} | ||
| ith_elements = ntuple(Val(length(s.elements))) do j | ||
| getfield(s.elements[j], i) | ||
| end | ||
| ith_SA = similar_type(SA, fieldtype(eltype(SA), i)) | ||
| return @inbounds ith_SA(ith_elements) | ||
| end | ||
|  | ||
| @inline function _broadcast(f, sz::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where {newsize} | ||
|         
                  N5N3 marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| first_staticarray = first_statictype(a...) | ||
| elements, ET = if prod(newsize) == 0 | ||
|         
                  N5N3 marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| # Use inference to get eltype in empty case (following StaticBroadcast defined in StaticArrays.jl) | ||
| eltys = Tuple{map(eltype, a)...} | ||
| (), Core.Compiler.return_type(f, eltys) | ||
| else | ||
| temp = __broadcast(f, sz, s, a...) | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This part worries me a little bit, we are using something explicitly marked as internal in StaticArrays. Is there no way to achieve this using only public methods? Or maybe we could check over at StaticArrays if they can offer some solution (maybe add a public method that does what we need). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well  There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah I see, thanks for pointing out that discussion. In that case, maybe one could just add a small docstring in  There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That makes sense too. | ||
| temp, eltype(temp) | ||
| end | ||
| if isnonemptystructtype(ET) | ||
| SA = similar_type(first_staticarray, ET, sz) | ||
| arrs = ntuple(Similar_ith{SA}(elements), Val(fieldcount(ET))) | ||
| return StructArray{ET}(arrs) | ||
| else | ||
| @inbounds return similar_type(first_staticarray, ET, sz)(elements) | ||
| end | ||
| end | ||
|  | ||
| # The `__broadcast` kernal is copied from `StaticArrays.jl`. | ||
| # see https://github.com/JuliaArrays/StaticArrays.jl/blob/master/src/broadcast.jl | ||
| @generated function __broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize | ||
| sizes = [sz.parameters[1] for sz ∈ s.parameters] | ||
|  | ||
| indices = CartesianIndices(newsize) | ||
| exprs = similar(indices, Expr) | ||
| for (j, current_ind) ∈ enumerate(indices) | ||
| exprs_vals = (broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes)) | ||
| exprs[j] = :(f($(exprs_vals...))) | ||
| end | ||
|  | ||
| return quote | ||
| Base.@_inline_meta | ||
| return tuple($(exprs...)) | ||
| end | ||
| end | ||
|  | ||
| broadcast_getindex(::Tuple{}, i::Int, I::CartesianIndex) = return :(_broadcast_getindex(a[$i], $I)) | ||
| function broadcast_getindex(oldsize::Tuple, i::Int, newindex::CartesianIndex) | ||
| li = LinearIndices(oldsize) | ||
| ind = _broadcast_getindex(li, newindex) | ||
| return :(a[$i][$ind]) | ||
| end | ||
|  | ||
| end | ||
Uh oh!
There was an error while loading. Please reload this page.