- 
                Notifications
    You must be signed in to change notification settings 
- Fork 44
          Move StaticArrays support to extension
          #265
        
          New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
c81c308
              7171bab
              b0d509b
              58a54f9
              bc629da
              0542152
              02257b9
              b4de96b
              60a8c8c
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| module StructArraysAdaptExt | ||
| # Use Adapt allows for automatic conversion of CPU to GPU StructArrays | ||
| using Adapt, StructArrays | ||
| @static if !applicable(Adapt.adapt, Int) | ||
| # Adapt.jl has curried support, implement it ourself | ||
|         
                  N5N3 marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| adpat(to) = Base.Fix1(Adapt.adapt, to) | ||
|         
                  N5N3 marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| if VERSION < v"1.9.0-DEV.857" | ||
| @eval function adapt(to::Type{T}) where {T} | ||
| (@isdefined T) || return Base.Fix1(Adapt.adapt, to) | ||
| AT = Base.Fix1{typeof(Adapt.adapt),Type{T}} | ||
| return $(Expr(:new, :AT, :(Adapt.adapt), :to)) | ||
| end | ||
| end | ||
| end | ||
| Adapt.adapt_structure(to, s::StructArray) = replace_storage(adapt(to), s) | ||
|         
                  N5N3 marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| end | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| module StructArraysGPUArraysCoreExt | ||
|  | ||
| using StructArrays | ||
| using StructArrays: map_params, array_types | ||
|  | ||
| using Base: tail | ||
|  | ||
| import GPUArraysCore | ||
|  | ||
| # for GPU broadcast | ||
| import GPUArraysCore | ||
| function GPUArraysCore.backend(::Type{T}) where {T<:StructArray} | ||
| backends = map_params(GPUArraysCore.backend, array_types(T)) | ||
| backend, others = backends[1], tail(backends) | ||
| isconsistent = mapfoldl(isequal(backend), &, others; init=true) | ||
| isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend")) | ||
| return backend | ||
| end | ||
| StructArrays.always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true | ||
|  | ||
| end # module | 
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| module StructArraysStaticArraysExt | ||
|  | ||
| using StructArrays | ||
| using StaticArrays: StaticArray, FieldArray, tuple_prod | ||
|  | ||
| """ | ||
| StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T} | ||
| The `staticschema` of a `StaticArray` element type is the `staticschema` of the underlying `Tuple`. | ||
| ```julia | ||
| julia> StructArrays.staticschema(SVector{2, Float64}) | ||
| Tuple{Float64, Float64} | ||
| ``` | ||
| The one exception to this rule is `<:StaticArrays.FieldArray`, since `FieldArray` is based on a | ||
| struct. In this case, `staticschema(<:FieldArray)` returns the `staticschema` for the struct | ||
| which subtypes `FieldArray`. | ||
| """ | ||
| @generated function StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T} | ||
| return quote | ||
| Base.@_inline_meta | ||
| return NTuple{$(tuple_prod(S)), T} | ||
| end | ||
| end | ||
| StructArrays.createinstance(::Type{T}, args...) where {T<:StaticArray} = T(args) | ||
| StructArrays.component(s::StaticArray, i) = getindex(s, i) | ||
|  | ||
| # invoke general fallbacks for a `FieldArray` type. | ||
| @inline function StructArrays.staticschema(T::Type{<:FieldArray}) | ||
| invoke(StructArrays.staticschema, Tuple{Type{<:Any}}, T) | ||
| end | ||
| StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i) | ||
| StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(StructArrays.createinstance, Tuple{Type{<:Any}, Vararg}, T, args...) | ||
|  | ||
| # Broadcast overload | ||
| using StaticArrays: StaticArrayStyle, similar_type, Size, SOneTo | ||
| using StaticArrays: broadcast_flatten, broadcast_sizes, first_statictype, __broadcast | ||
| using StructArrays: isnonemptystructtype | ||
| using Base.Broadcast: Broadcasted | ||
|  | ||
| # StaticArrayStyle has no similar defined. | ||
| # Overload `try_struct_copy` instead. | ||
| @inline function StructArrays.try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M} | ||
| flat = broadcast_flatten(bc); as = flat.args; f = flat.f | ||
| argsizes = broadcast_sizes(as...) | ||
| ax = axes(bc) | ||
| ax isa Tuple{Vararg{SOneTo}} || error("Dimension is not static. Please file a bug at `StaticArrays.jl`.") | ||
| return _broadcast(f, Size(map(length, ax)), argsizes, as...) | ||
| end | ||
|  | ||
| @inline function _broadcast(f, sz::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where {newsize} | ||
|         
                  N5N3 marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| first_staticarray = first_statictype(a...) | ||
| elements, ET = if prod(newsize) == 0 | ||
|         
                  N5N3 marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| # Use inference to get eltype in empty case (see also comments in _map) | ||
|         
                  N5N3 marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| eltys = Tuple{map(eltype, a)...} | ||
| (), Core.Compiler.return_type(f, eltys) | ||
| else | ||
| temp = __broadcast(f, sz, s, a...) | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This part worries me a little bit, we are using something explicitly marked as internal in StaticArrays. Is there no way to achieve this using only public methods? Or maybe we could check over at StaticArrays if they can offer some solution (maybe add a public method that does what we need). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well  There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah I see, thanks for pointing out that discussion. In that case, maybe one could just add a small docstring in  There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That makes sense too. | ||
| temp, eltype(temp) | ||
| end | ||
| if isnonemptystructtype(ET) | ||
| @static if VERSION >= v"1.7" | ||
| arrs = ntuple(Val(fieldcount(ET))) do i | ||
| @inbounds similar_type(first_staticarray, fieldtype(ET, i), sz)(_getfields(elements, i)) | ||
| end | ||
| else | ||
| similarET(::Type{SA}, ::Type{T}) where {SA, T} = i -> @inbounds similar_type(SA, fieldtype(T, i), sz)(_getfields(elements, i)) | ||
| arrs = ntuple(similarET(first_staticarray, ET), Val(fieldcount(ET))) | ||
| end | ||
|         
                  N5N3 marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| return StructArray{ET}(arrs) | ||
| end | ||
| @inbounds return similar_type(first_staticarray, ET, sz)(elements) | ||
|         
                  N5N3 marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| end | ||
|  | ||
| @inline function _getfields(x::Tuple, i::Int) | ||
|         
                  N5N3 marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| if @generated | ||
| return Expr(:tuple, (:(getfield(x[$j], i)) for j in 1:fieldcount(x))...) | ||
| else | ||
| return map(Base.Fix2(getfield, i), x) | ||
| end | ||
| end | ||
|  | ||
| end | ||
Uh oh!
There was an error while loading. Please reload this page.