Skip to content

Commit 13e4c3d

Browse files
committed
define shared things
1 parent 78fbe8e commit 13e4c3d

File tree

1 file changed

+34
-12
lines changed

1 file changed

+34
-12
lines changed

src/ParallelKernel/shared.jl

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,22 @@ izd(count) = @ModuleInternalError("function izd had not been evaluated at parse
1616
const MOD_METADATA_PK = gensym_world("__metadata_PK__", @__MODULE__) # # TODO: name mangling should be used here later, or if there is any sense to leave it like that then at check whether it's available must be done before creating it
1717
const PKG_CUDA = :CUDA
1818
const PKG_AMDGPU = :AMDGPU
19+
const PKG_KERNELABSTRACTIONS = :KernelAbstractions
1920
const PKG_METAL = :Metal
2021
const PKG_THREADS = :Threads
2122
const PKG_POLYESTER = :Polyester
2223
const PKG_NONE = :PKG_NONE
23-
const SUPPORTED_PACKAGES = [PKG_THREADS, PKG_POLYESTER, PKG_CUDA, PKG_AMDGPU, PKG_METAL]
24+
const SUPPORTED_PACKAGES = [PKG_THREADS, PKG_POLYESTER, PKG_CUDA, PKG_AMDGPU, PKG_KERNELABSTRACTIONS, PKG_METAL]
2425
const INT_CUDA = Int64 # NOTE: unsigned integers are not yet supported (proper negative offset and range is dealing missing)
2526
const INT_AMDGPU = Int64 # NOTE: ...
27+
const INT_KERNELABSTRACTIONS = Int64 # NOTE: KernelAbstractions dispatch defaults to CPU integers until a GPU-specific handle is selected at runtime.
2628
const INT_METAL = Int64 # NOTE: ...
2729
const INT_POLYESTER = Int64 # NOTE: ...
2830
const INT_THREADS = Int64 # NOTE: ...
2931
const COMPUTE_CAPABILITY_DEFAULT = v"" # having it infinity if it is not set allows to directly use statements like `if compute_capability < v"8"`, assuming a recent architecture if it is not set.
3032
const NTHREADS_X_MAX = 32
3133
const NTHREADS_X_MAX_AMDGPU = 64
34+
const NTHREADS_X_MAX_KERNELABSTRACTIONS = 32
3235
const NTHREADS_MAX = 256
3336
const INDICES = (gensym_world("ix", @__MODULE__), gensym_world("iy", @__MODULE__), gensym_world("iz", @__MODULE__))
3437
const INDICES_INN = (gensym_world("ixi", @__MODULE__), gensym_world("iyi", @__MODULE__), gensym_world("izi", @__MODULE__)) # ( :($(INDICES[1])+1), :($(INDICES[2])+1), :($(INDICES[3])+1) )
@@ -77,12 +80,13 @@ const ERRMSG_CHECK_INBOUNDS = "inbounds must be a evaluatable at parse ti
7780
const ERRMSG_CHECK_PADDING = "padding must be a evaluatable at parse time (e.g. literal or constant) and has to be of type Bool."
7881
const ERRMSG_CHECK_LITERALTYPES = "the type given to 'literaltype' must be one of the following: $(join(SUPPORTED_LITERALTYPES,", "))"
7982

80-
const CELLARRAY_BLOCKLENGTH = Dict(PKG_NONE => 0,
81-
PKG_CUDA => 0,
82-
PKG_AMDGPU => 0,
83-
PKG_METAL => 0,
84-
PKG_THREADS => 1,
85-
PKG_POLYESTER => 1)
83+
const CELLARRAY_BLOCKLENGTH = Dict(PKG_NONE => 0,
84+
PKG_CUDA => 0,
85+
PKG_AMDGPU => 0,
86+
PKG_KERNELABSTRACTIONS => 0,
87+
PKG_METAL => 0,
88+
PKG_THREADS => 1,
89+
PKG_POLYESTER => 1)
8690

8791
struct Dim3
8892
x::INT_THREADS
@@ -96,13 +100,30 @@ macro rangelengths() esc(:(($(RANGELENGTHS_VARNAMES...),))) end
96100
function kernel_int_type(package::Symbol)
97101
if (package == PKG_CUDA) int_type = INT_CUDA
98102
elseif (package == PKG_AMDGPU) int_type = INT_AMDGPU
103+
elseif (package == PKG_KERNELABSTRACTIONS) int_type = INT_KERNELABSTRACTIONS
99104
elseif (package == PKG_METAL) int_type = INT_METAL
100105
elseif (package == PKG_THREADS) int_type = INT_THREADS
101106
elseif (package == PKG_POLYESTER) int_type = INT_POLYESTER
102107
end
103108
return int_type
104109
end
105110

111+
function default_hardware_for(package::Symbol)
112+
if package == PKG_KERNELABSTRACTIONS
113+
return :cpu
114+
elseif package == PKG_CUDA
115+
return :gpu_cuda
116+
elseif package == PKG_AMDGPU
117+
return :gpu_amd
118+
elseif package == PKG_METAL
119+
return :gpu_metal
120+
elseif package == PKG_THREADS || package == PKG_POLYESTER
121+
return :cpu
122+
else
123+
@ArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package). Supported packages are: $(join(SUPPORTED_PACKAGES, ", ")).")
124+
end
125+
end
126+
106127

107128
## FUNCTIONS TO CHECK EXTENSIONS SUPPORT
108129

@@ -577,11 +598,12 @@ quote_expr(expr) = :($(Expr(:quote, expr)))
577598

578599
function get_compute_capability(package::Symbol)
579600
default = COMPUTE_CAPABILITY_DEFAULT
580-
if (package == PKG_CUDA) get_cuda_compute_capability(default)
581-
elseif (package == PKG_AMDGPU) get_amdgpu_compute_capability(default)
582-
elseif (package == PKG_METAL) get_metal_compute_capability(default)
583-
elseif (package == PKG_THREADS) get_cpu_compute_capability(default)
584-
elseif (package == PKG_POLYESTER) get_cpu_compute_capability(default)
601+
if (package == PKG_CUDA) get_cuda_compute_capability(default)
602+
elseif (package == PKG_AMDGPU) get_amdgpu_compute_capability(default)
603+
elseif (package == PKG_KERNELABSTRACTIONS) get_cpu_compute_capability(default)
604+
elseif (package == PKG_METAL) get_metal_compute_capability(default)
605+
elseif (package == PKG_THREADS) get_cpu_compute_capability(default)
606+
elseif (package == PKG_POLYESTER) get_cpu_compute_capability(default)
585607
else
586608
@ArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package). Supported packages are: $(join(SUPPORTED_PACKAGES, ", ")).")
587609
end

0 commit comments

Comments
 (0)