Skip to content

Commit 959d8ac

Browse files
committed
add ka extension
1 parent c8a5dc2 commit 959d8ac

File tree

3 files changed

+153
-0
lines changed

3 files changed

+153
-0
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import ParallelStencil
2+
using ParallelStencil.ParallelKernel.Exceptions: @KeywordArgumentError
3+
4+
const PK = ParallelStencil.ParallelKernel
5+
6+
# Helper to map KernelAbstractions allocations to the runtime backend.
7+
function runtime_allocator_symbol(kind::Symbol, hardware::Union{Symbol,Nothing})
8+
target_package, _, _ = PK.resolve_runtime_backend(PK.PKG_KERNELABSTRACTIONS, hardware)
9+
suffix = PK.allocator_suffix_for(target_package)
10+
if suffix === nothing
11+
@KeywordArgumentError("$(PK.ERRMSG_UNSUPPORTED_PACKAGE) (obtained: $target_package).")
12+
end
13+
return PK.allocator_function_symbol(kind, suffix)
14+
end
15+
16+
invoke_runtime_allocator(kind::Symbol, hardware::Union{Symbol,Nothing}, args...) =
17+
getfield(PK, runtime_allocator_symbol(kind, hardware))(args...)
18+
19+
20+
## RUNTIME ALLOCATOR FUNCTIONS
21+
22+
function PK.zeros_kernelabstractions(::Type{T}, blocklength, args...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Number}
23+
return invoke_runtime_allocator(:zeros, hardware, T, blocklength, args...)
24+
end
25+
26+
function PK.ones_kernelabstractions(::Type{T}, blocklength, args...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Number}
27+
return invoke_runtime_allocator(:ones, hardware, T, blocklength, args...)
28+
end
29+
30+
function PK.rand_kernelabstractions(::Type{T}, blocklength, args...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Union{Number,Enum}}
31+
return invoke_runtime_allocator(:rand, hardware, T, blocklength, args...)
32+
end
33+
34+
function PK.falses_kernelabstractions(::Type{T}, blocklength, args...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Bool}
35+
return invoke_runtime_allocator(:falses, hardware, T, blocklength, args...)
36+
end
37+
38+
function PK.trues_kernelabstractions(::Type{T}, blocklength, args...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Bool}
39+
return invoke_runtime_allocator(:trues, hardware, T, blocklength, args...)
40+
end
41+
42+
function PK.fill_kernelabstractions(::Type{T}, blocklength, args...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Union{Number,Enum}}
43+
return invoke_runtime_allocator(:fill, hardware, T, blocklength, args...)
44+
end
45+
46+
function PK.zeros_kernelabstractions(::Type{T}, blocklength, args...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Union{SArray,FieldArray}}
47+
return invoke_runtime_allocator(:zeros, hardware, T, blocklength, args...)
48+
end
49+
50+
function PK.ones_kernelabstractions(::Type{T}, blocklength, args...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Union{SArray,FieldArray}}
51+
return invoke_runtime_allocator(:ones, hardware, T, blocklength, args...)
52+
end
53+
54+
function PK.rand_kernelabstractions(::Type{T}, blocklength::Val{B}, dims; hardware::Union{Symbol,Nothing}=nothing) where {T<:Union{SArray,FieldArray},B}
55+
return invoke_runtime_allocator(:rand, hardware, T, blocklength, dims)
56+
end
57+
58+
function PK.rand_kernelabstractions(::Type{T}, blocklength, dims...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Union{SArray,FieldArray}}
59+
return PK.rand_kernelabstractions(T, blocklength, dims; hardware=hardware)
60+
end
61+
62+
function PK.falses_kernelabstractions(::Type{T}, blocklength, args...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Union{SArray,FieldArray}}
63+
return invoke_runtime_allocator(:falses, hardware, T, blocklength, args...)
64+
end
65+
66+
function PK.trues_kernelabstractions(::Type{T}, blocklength, args...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Union{SArray,FieldArray}}
67+
return invoke_runtime_allocator(:trues, hardware, T, blocklength, args...)
68+
end
69+
70+
function PK.fill_kernelabstractions(::Type{T}, blocklength::Val{B}, x, args...; hardware::Union{Symbol,Nothing}=nothing) where {T<:Union{SArray,FieldArray},B}
71+
return invoke_runtime_allocator(:fill, hardware, T, blocklength, x, args...)
72+
end
73+
74+
function PK.fill_kernelabstractions!(A, x; hardware::Union{Symbol,Nothing}=nothing)
75+
return invoke_runtime_allocator(:fill!, hardware, A, x)
76+
end
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
const ERRMSG_KERNELABSTRACTIONSEXT_NOT_LOADED = "the KernelAbstractions extension was not loaded. Make sure to import KernelAbstractions before ParallelStencil."
2+
3+
4+
# shared.jl
5+
6+
function get_priority_kastream end
7+
function get_kastream end
8+
function get_kernelabstractions_compute_capability end
9+
10+
11+
# allocators.jl
12+
13+
zeros_kernelabstractions(arg...) = @NotLoadedError(ERRMSG_KERNELABSTRACTIONSEXT_NOT_LOADED)
14+
ones_kernelabstractions(arg...) = @NotLoadedError(ERRMSG_KERNELABSTRACTIONSEXT_NOT_LOADED)
15+
rand_kernelabstractions(arg...) = @NotLoadedError(ERRMSG_KERNELABSTRACTIONSEXT_NOT_LOADED)
16+
falses_kernelabstractions(arg...) = @NotLoadedError(ERRMSG_KERNELABSTRACTIONSEXT_NOT_LOADED)
17+
trues_kernelabstractions(arg...) = @NotLoadedError(ERRMSG_KERNELABSTRACTIONSEXT_NOT_LOADED)
18+
fill_kernelabstractions(arg...) = @NotLoadedError(ERRMSG_KERNELABSTRACTIONSEXT_NOT_LOADED)
19+
fill_kernelabstractions!(arg...) = @NotLoadedError(ERRMSG_KERNELABSTRACTIONSEXT_NOT_LOADED)
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import ParallelStencil
2+
import KernelAbstractions
3+
using ParallelStencil.ParallelKernel: PKG_KERNELABSTRACTIONS, PKG_CUDA, PKG_AMDGPU, PKG_METAL, PKG_THREADS, PKG_POLYESTER,
4+
resolve_runtime_backend,
5+
get_priority_custream, get_custream, get_priority_rocstream, get_rocstream,
6+
get_priority_metalstream, get_metalstream,
7+
get_cuda_compute_capability, get_amdgpu_compute_capability, get_metal_compute_capability, get_cpu_compute_capability,
8+
ERRMSG_UNSUPPORTED_PACKAGE
9+
using ParallelStencil.ParallelKernel.Exceptions: @ArgumentError
10+
11+
const PK = ParallelStencil.ParallelKernel
12+
13+
14+
## FUNCTIONS TO CHECK EXTENSION SUPPORT
15+
16+
PK.is_loaded(::Val{:ParallelStencil_KernelAbstractionsExt}) = true
17+
18+
19+
## FUNCTIONS TO GET OR MANAGE KERNELABSTRACTIONS STREAMS
20+
21+
function dispatch_kastream(kind::Symbol, id::Integer, hardware::Union{Symbol,Nothing})
22+
target_package, symbol, _ = resolve_runtime_backend(PKG_KERNELABSTRACTIONS, hardware)
23+
if target_package == PKG_CUDA
24+
return (kind === :priority) ? get_priority_custream(id) : get_custream(id)
25+
elseif target_package == PKG_AMDGPU
26+
return (kind === :priority) ? get_priority_rocstream(id) : get_rocstream(id)
27+
elseif target_package == PKG_METAL
28+
return (kind === :priority) ? get_priority_metalstream(id) : get_metalstream(id)
29+
elseif target_package == PKG_THREADS || target_package == PKG_POLYESTER
30+
@ArgumentError("KernelAbstractions hardware $(symbol) does not expose GPU streams.")
31+
else
32+
@ArgumentError("$(ERRMSG_UNSUPPORTED_PACKAGE) (obtained: $target_package).")
33+
end
34+
end
35+
36+
function PK.get_priority_kastream(id::Integer; hardware::Union{Symbol,Nothing}=nothing)
37+
return dispatch_kastream(:priority, id, hardware)
38+
end
39+
40+
function PK.get_kastream(id::Integer; hardware::Union{Symbol,Nothing}=nothing)
41+
return dispatch_kastream(:regular, id, hardware)
42+
end
43+
44+
45+
## FUNCTIONS TO QUERY DEVICE PROPERTIES
46+
47+
function PK.get_kernelabstractions_compute_capability(default::VersionNumber; hardware::Union{Symbol,Nothing}=nothing)
48+
target_package, _, _ = resolve_runtime_backend(PKG_KERNELABSTRACTIONS, hardware)
49+
if target_package == PKG_CUDA
50+
return get_cuda_compute_capability(default)
51+
elseif target_package == PKG_AMDGPU
52+
return get_amdgpu_compute_capability(default)
53+
elseif target_package == PKG_METAL
54+
return get_metal_compute_capability(default)
55+
else
56+
return get_cpu_compute_capability(default)
57+
end
58+
end

0 commit comments

Comments
 (0)