|
| 1 | +import ParallelStencil |
| 2 | +using ParallelStencil.ParallelKernel.Exceptions: @KeywordArgumentError |
| 3 | + |
| 4 | +const PK = ParallelStencil.ParallelKernel |
| 5 | + |
| 6 | +# Helper to map KernelAbstractions allocations to the runtime backend. |
| 7 | +function runtime_allocator_symbol(kind::Symbol, hardware::Union{Symbol,Nothing}) |
| 8 | + target_package, _, _ = PK.resolve_runtime_backend(PK.PKG_KERNELABSTRACTIONS, hardware) |
| 9 | + suffix = PK.allocator_suffix_for(target_package) |
| 10 | + if suffix === nothing |
| 11 | + @KeywordArgumentError("$(PK.ERRMSG_UNSUPPORTED_PACKAGE) (obtained: $target_package).") |
| 12 | + end |
| 13 | + return PK.allocator_function_symbol(kind, suffix) |
| 14 | +end |
| 15 | + |
| 16 | +invoke_runtime_allocator(kind::Symbol, hardware::Union{Symbol,Nothing}, args...) = |
| 17 | + getfield(PK, runtime_allocator_symbol(kind, hardware))(args...) |
| 18 | + |
| 19 | + |
| 20 | +## RUNTIME ALLOCATOR FUNCTIONS |
| 21 | + |
| 22 | +function PK.zeros_kernelabstractions(::Type{T}, blocklength, args...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Number} |
| 23 | + return invoke_runtime_allocator(:zeros, hardware, T, blocklength, args...) |
| 24 | +end |
| 25 | + |
| 26 | +function PK.ones_kernelabstractions(::Type{T}, blocklength, args...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Number} |
| 27 | + return invoke_runtime_allocator(:ones, hardware, T, blocklength, args...) |
| 28 | +end |
| 29 | + |
| 30 | +function PK.rand_kernelabstractions(::Type{T}, blocklength, args...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Union{Number,Enum}} |
| 31 | + return invoke_runtime_allocator(:rand, hardware, T, blocklength, args...) |
| 32 | +end |
| 33 | + |
| 34 | +function PK.falses_kernelabstractions(::Type{T}, blocklength, args...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Bool} |
| 35 | + return invoke_runtime_allocator(:falses, hardware, T, blocklength, args...) |
| 36 | +end |
| 37 | + |
| 38 | +function PK.trues_kernelabstractions(::Type{T}, blocklength, args...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Bool} |
| 39 | + return invoke_runtime_allocator(:trues, hardware, T, blocklength, args...) |
| 40 | +end |
| 41 | + |
| 42 | +function PK.fill_kernelabstractions(::Type{T}, blocklength, args...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Union{Number,Enum}} |
| 43 | + return invoke_runtime_allocator(:fill, hardware, T, blocklength, args...) |
| 44 | +end |
| 45 | + |
| 46 | +function PK.zeros_kernelabstractions(::Type{T}, blocklength, args...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Union{SArray,FieldArray}} |
| 47 | + return invoke_runtime_allocator(:zeros, hardware, T, blocklength, args...) |
| 48 | +end |
| 49 | + |
| 50 | +function PK.ones_kernelabstractions(::Type{T}, blocklength, args...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Union{SArray,FieldArray}} |
| 51 | + return invoke_runtime_allocator(:ones, hardware, T, blocklength, args...) |
| 52 | +end |
| 53 | + |
| 54 | +function PK.rand_kernelabstractions(::Type{T}, blocklength::Val{B}, dims; hardware::Union{Symbol,Nothing}=nothing) where {T<:Union{SArray,FieldArray},B} |
| 55 | + return invoke_runtime_allocator(:rand, hardware, T, blocklength, dims) |
| 56 | +end |
| 57 | + |
| 58 | +function PK.rand_kernelabstractions(::Type{T}, blocklength, dims...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Union{SArray,FieldArray}} |
| 59 | + return PK.rand_kernelabstractions(T, blocklength, dims; hardware=hardware) |
| 60 | +end |
| 61 | + |
| 62 | +function PK.falses_kernelabstractions(::Type{T}, blocklength, args...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Union{SArray,FieldArray}} |
| 63 | + return invoke_runtime_allocator(:falses, hardware, T, blocklength, args...) |
| 64 | +end |
| 65 | + |
| 66 | +function PK.trues_kernelabstractions(::Type{T}, blocklength, args...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Union{SArray,FieldArray}} |
| 67 | + return invoke_runtime_allocator(:trues, hardware, T, blocklength, args...) |
| 68 | +end |
| 69 | + |
| 70 | +function PK.fill_kernelabstractions(::Type{T}, blocklength::Val{B}, x, args...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Union{SArray,FieldArray},B} |
| 71 | + return invoke_runtime_allocator(:fill, hardware, T, blocklength, x, args...) |
| 72 | +end |
| 73 | + |
| 74 | +function PK.fill_kernelabstractions!(A, x; hardware::Union{Symbol,Nothing}=nothing) |
| 75 | + return invoke_runtime_allocator(:fill!, hardware, A, x) |
| 76 | +end |
0 commit comments