diff --git a/examples/diffusion2D_shmem_novis.jl b/examples/diffusion2D_shmem_novis.jl index 27bfba8d..3ce406f9 100644 --- a/examples/diffusion2D_shmem_novis.jl +++ b/examples/diffusion2D_shmem_novis.jl @@ -13,7 +13,7 @@ end ty = @threadIdx().y + 1 T_l = @sharedMem(eltype(T), (@blockDim().x+2, @blockDim().y+2)) T_l[tx,ty] = T[ix,iy] - if (ix>1 && ix1 && iy1 && ix1 && iy1 && iz Int + +Return the logical warp / wavefront / SIMD-group width in threads for the active backend. CUDA returns 32. AMD GPUs return the hardware wavefront size (typically 64 or 32). Metal returns the device `threadExecutionWidth`. CPU backend returns 1. Guaranteed constant for the lifetime of the kernel invocation. Use this value (not a hard‑coded constant) for portable intra-warp algorithms. +""" +@doc WARPSIZE_DOC +macro warpsize(args...) check_initialized(__module__); checknoargs(args...); esc(warpsize(__module__, args...)); end + + +## +const LANEID_DOC = """ + @laneid() -> Int + +Return the 1-based logical lane index in the current warp (range: 1:warpsize()). For CUDA this is `CUDA.laneid()+1` internally; for backends with 0-based hardware lane numbering the abstraction adds 1. CPU backend always returns 1. +""" +@doc LANEID_DOC +macro laneid(args...) check_initialized(__module__); checknoargs(args...); esc(laneid(__module__, args...)); end + + +## +const ACTIVE_MASK_DOC = """ + @active_mask() -> Unsigned + +Return a bit mask of currently active (non-exited, converged) lanes in the caller's warp. Bit (laneid()-1) corresponds to that logical lane. CUDA returns a 32-bit value; AMD returns a 64-bit value. Absent (throws) on Metal if not supported; CPU returns UInt64(0x1). +""" +@doc ACTIVE_MASK_DOC +macro active_mask(args...) check_initialized(__module__); checknoargs(args...); esc(active_mask(__module__, args...)); end + + +## +const SHFL_SYNC_DOC = """ + @shfl_sync(mask::Unsigned, val, lane::Integer) + @shfl_sync(mask::Unsigned, val, lane::Integer, width::Integer) + +Return the value of `val` from the source lane `lane` (1-based) among lanes named in `mask`. Optional `width` (power of two, 1 <= width <= warpsize()) logically partitions the warp into independent contiguous sub-groups each behaving as a mini-warp with lanes numbered 1:width. The source lane index is resolved modulo `width`. All participating lanes must supply identical `mask`, `lane`, and (if present) `width`. `val` may be any isbits type; larger composite isbits values are shuffled by decomposition into supported word sizes. CPU backend returns `val` unchanged. +""" +@doc SHFL_SYNC_DOC +macro shfl_sync(args...) check_initialized(__module__); checkargs_shfl_sync(args...); esc(shfl_sync(__module__, args...)); end + + +## +const SHFL_UP_SYNC_DOC = """ + @shfl_up_sync(mask::Unsigned, val, delta::Integer) + @shfl_up_sync(mask::Unsigned, val, delta::Integer, width::Integer) + +Shift `val` up by `delta` lanes within each logical partition (width semantics as in `shfl_sync`). Lanes with no valid upstream partner retain their original `val`. `delta >= 0`. CPU backend returns `val` unchanged. +""" +@doc SHFL_UP_SYNC_DOC +macro shfl_up_sync(args...) check_initialized(__module__); checkargs_shfl_up_down_xor(args...); esc(shfl_up_sync(__module__, args...)); end + + +## +const SHFL_DOWN_SYNC_DOC = """ + @shfl_down_sync(mask::Unsigned, val, delta::Integer) + @shfl_down_sync(mask::Unsigned, val, delta::Integer, width::Integer) + +Shift `val` down by `delta` lanes within each logical partition; lanes without a valid downstream partner retain their original `val`. `delta >= 0`. CPU backend returns `val` unchanged. +""" +@doc SHFL_DOWN_SYNC_DOC +macro shfl_down_sync(args...) check_initialized(__module__); checkargs_shfl_up_down_xor(args...); esc(shfl_down_sync(__module__, args...)); end + + +## +const SHFL_XOR_SYNC_DOC = """ + @shfl_xor_sync(mask::Unsigned, val, lane_mask::Integer) + @shfl_xor_sync(mask::Unsigned, val, lane_mask::Integer, width::Integer) + +Perform a butterfly (bitwise XOR) shuffle: each lane exchanges with the lane whose (laneid()-1) XOR `lane_mask` differs in the specified bits, constrained within each `width` partition if provided. If the computed partner is outside the partition the calling lane's own `val` is returned. CPU backend returns `val` unchanged. +""" +@doc SHFL_XOR_SYNC_DOC +macro shfl_xor_sync(args...) check_initialized(__module__); checkargs_shfl_up_down_xor(args...); esc(shfl_xor_sync(__module__, args...)); end + + +## +const VOTE_ANY_SYNC_DOC = """ + @vote_any_sync(mask::Unsigned, predicate::Bool) -> Bool + +Evaluate `predicate` across all active lanes named in `mask`; return true if any lane's predicate is true. Does not imply a memory fence. CPU backend returns `predicate`. +""" +@doc VOTE_ANY_SYNC_DOC +macro vote_any_sync(args...) check_initialized(__module__); checkargs_vote(args...); esc(vote_any_sync(__module__, args...)); end + + +## +const VOTE_ALL_SYNC_DOC = """ + @vote_all_sync(mask::Unsigned, predicate::Bool) -> Bool + +Evaluate `predicate` across all active lanes named in `mask`; return true only if every such lane's predicate is true. No memory ordering implied. CPU backend returns `predicate`. +""" +@doc VOTE_ALL_SYNC_DOC +macro vote_all_sync(args...) check_initialized(__module__); checkargs_vote(args...); esc(vote_all_sync(__module__, args...)); end + + +## +const VOTE_BALLOT_SYNC_DOC = """ + @vote_ballot_sync(mask::Unsigned, predicate::Bool) -> Unsigned + +Return a bit mask aggregating `predicate` values for lanes named in `mask`: bit (laneid()-1) set iff that lane's predicate is true. Width of result equals hardware warp mask width (32 for CUDA, 64 for AMD, CPU uses 64 with only bit 0 meaningful). Caller may safely promote to `UInt64` for uniform handling; upper bits beyond hardware width are zero. No memory ordering implied. +""" +@doc VOTE_BALLOT_SYNC_DOC +macro vote_ballot_sync(args...) check_initialized(__module__); checkargs_vote(args...); esc(vote_ballot_sync(__module__, args...)); end + + ## const FORALL_DOC = """ @∀ x ∈ X statement @@ -178,6 +282,18 @@ function checkargs_begin_end(args...) if !(2 <= length(args) <= 3) @ArgumentError("wrong number of arguments.") end end +function checkargs_shfl_sync(args...) + if !(3 <= length(args) <= 4) @ArgumentError("wrong number of arguments.") end +end + +function checkargs_shfl_up_down_xor(args...) + if !(3 <= length(args) <= 4) @ArgumentError("wrong number of arguments.") end +end + +function checkargs_vote(args...) + if !(length(args) == 2) @ArgumentError("wrong number of arguments.") end +end + ## FUNCTIONS FOR INDEXING AND DIMENSIONS @@ -270,6 +386,152 @@ function pk_println(caller::Module, args...; package::Symbol=get_package(caller) end +## FUNCTIONS FOR WARP-LEVEL PRIMITIVES (backend mapping) + +function warpsize(caller::Module, args...; package::Symbol=get_package(caller)) + if (package == PKG_CUDA) return :(CUDA.warpsize()) + elseif (package == PKG_AMDGPU) return :(AMDGPU.Device.wavefrontsize()) + elseif (package == PKG_METAL) return :(Metal.threads_per_simdgroup()) + elseif iscpu(package) return :(ParallelStencil.ParallelKernel.warpsize_cpu()) + else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).") + end +end + +function laneid(caller::Module, args...; package::Symbol=get_package(caller)) + if (package == PKG_CUDA) return :(CUDA.laneid() + 1) + elseif (package == PKG_AMDGPU) return :(unsafe_trunc(Cint, AMDGPU.Device.activelane()) + Cint(1)) + elseif (package == PKG_METAL) return :(unsafe_trunc(Cint, Metal.thread_index_in_simdgroup()) + Cint(1)) + elseif iscpu(package) return :(ParallelStencil.ParallelKernel.laneid_cpu()) + else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).") + end +end + +function active_mask(caller::Module, args...; package::Symbol=get_package(caller)) + if (package == PKG_CUDA) return :(CUDA.active_mask()) + elseif (package == PKG_AMDGPU) return :(AMDGPU.Device.activemask()) + elseif (package == PKG_METAL) @KeywordArgumentError("this functionality is not yet supported in Metal.jl.") + elseif iscpu(package) return :(ParallelStencil.ParallelKernel.active_mask_cpu()) + else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).") + end +end + +function shfl_sync(caller::Module, args...; package::Symbol=get_package(caller)) + if (package == PKG_CUDA) + return :(CUDA.shfl_sync($(args...))) + elseif (package == PKG_AMDGPU) + if length(args) == 3 + # (mask, val, lane) + return :(AMDGPU.Device.shfl_sync(UInt64($(args[1])), $(args[2]), unsafe_trunc(Cint, $(args[3])) - Cint(1))) + else + # (mask, val, lane, width) + return :(AMDGPU.Device.shfl_sync(UInt64($(args[1])), $(args[2]), unsafe_trunc(Cint, $(args[3])) - Cint(1), unsafe_trunc(Cuint, $(args[4])))) + end + elseif (package == PKG_METAL) + @KeywordArgumentError("this functionality is not yet supported in Metal.jl.") + elseif iscpu(package) + if length(args) == 3 + return :(ParallelStencil.ParallelKernel.shfl_sync_cpu($(args[1]), $(args[2]), Int64($(args[3])) - Int64(1))) + else + return :(ParallelStencil.ParallelKernel.shfl_sync_cpu($(args[1]), $(args[2]), Int64($(args[3])) - Int64(1), Int64($(args[4])))) + end + else + @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).") + end +end + +function shfl_up_sync(caller::Module, args...; package::Symbol=get_package(caller)) + if (package == PKG_CUDA) + return :(CUDA.shfl_up_sync($(args...))) + elseif (package == PKG_AMDGPU) + if length(args) == 3 + return :(AMDGPU.Device.shfl_up_sync(UInt64($(args[1])), $(args[2]), unsafe_trunc(Cint, $(args[3])))) + else + return :(AMDGPU.Device.shfl_up_sync(UInt64($(args[1])), $(args[2]), unsafe_trunc(Cint, $(args[3])), unsafe_trunc(Cuint, $(args[4])))) + end + elseif (package == PKG_METAL) + @KeywordArgumentError("this functionality is not yet supported in Metal.jl.") + elseif iscpu(package) + if length(args) == 3 + return :(ParallelStencil.ParallelKernel.shfl_up_sync_cpu($(args[1]), $(args[2]), Int64($(args[3])))) + else + return :(ParallelStencil.ParallelKernel.shfl_up_sync_cpu($(args[1]), $(args[2]), Int64($(args[3])), Int64($(args[4])))) + end + else + @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).") + end +end + +function shfl_down_sync(caller::Module, args...; package::Symbol=get_package(caller)) + if (package == PKG_CUDA) + return :(CUDA.shfl_down_sync($(args...))) + elseif (package == PKG_AMDGPU) + if length(args) == 3 + return :(AMDGPU.Device.shfl_down_sync(UInt64($(args[1])), $(args[2]), unsafe_trunc(Cint, $(args[3])))) + else + return :(AMDGPU.Device.shfl_down_sync(UInt64($(args[1])), $(args[2]), unsafe_trunc(Cint, $(args[3])), unsafe_trunc(Cuint, $(args[4])))) + end + elseif (package == PKG_METAL) + @KeywordArgumentError("this functionality is not yet supported in Metal.jl.") + elseif iscpu(package) + if length(args) == 3 + return :(ParallelStencil.ParallelKernel.shfl_down_sync_cpu($(args[1]), $(args[2]), Int64($(args[3])))) + else + return :(ParallelStencil.ParallelKernel.shfl_down_sync_cpu($(args[1]), $(args[2]), Int64($(args[3])), Int64($(args[4])))) + end + else + @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).") + end +end + +function shfl_xor_sync(caller::Module, args...; package::Symbol=get_package(caller)) + if (package == PKG_CUDA) + return :(CUDA.shfl_xor_sync($(args...))) + elseif (package == PKG_AMDGPU) + if length(args) == 3 + return :(AMDGPU.Device.shfl_xor_sync(UInt64($(args[1])), $(args[2]), unsafe_trunc(Cint, $(args[3])) - Cint(1))) + else + return :(AMDGPU.Device.shfl_xor_sync(UInt64($(args[1])), $(args[2]), unsafe_trunc(Cint, $(args[3])) - Cint(1), unsafe_trunc(Cuint, $(args[4])))) + end + elseif (package == PKG_METAL) + @KeywordArgumentError("this functionality is not yet supported in Metal.jl.") + elseif iscpu(package) + if length(args) == 3 + return :(ParallelStencil.ParallelKernel.shfl_xor_sync_cpu($(args[1]), $(args[2]), Int64($(args[3])) - Int64(1))) + else + return :(ParallelStencil.ParallelKernel.shfl_xor_sync_cpu($(args[1]), $(args[2]), Int64($(args[3])) - Int64(1), Int64($(args[4])))) + end + else + @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).") + end +end + +function vote_any_sync(caller::Module, args...; package::Symbol=get_package(caller)) + if (package == PKG_CUDA) return :(CUDA.vote_any_sync($(args...))) + elseif (package == PKG_AMDGPU) return :(AMDGPU.Device.any_sync(UInt64($(args[1])), $(args[2]))) + elseif (package == PKG_METAL) @KeywordArgumentError("this functionality is not yet supported in Metal.jl.") + elseif iscpu(package) return :(ParallelStencil.ParallelKernel.vote_any_sync_cpu($(args...))) + else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).") + end +end + +function vote_all_sync(caller::Module, args...; package::Symbol=get_package(caller)) + if (package == PKG_CUDA) return :(CUDA.vote_all_sync($(args...))) + elseif (package == PKG_AMDGPU) return :(AMDGPU.Device.all_sync(UInt64($(args[1])), $(args[2]))) + elseif (package == PKG_METAL) @KeywordArgumentError("this functionality is not yet supported in Metal.jl.") + elseif iscpu(package) return :(ParallelStencil.ParallelKernel.vote_all_sync_cpu($(args...))) + else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).") + end +end + +function vote_ballot_sync(caller::Module, args...; package::Symbol=get_package(caller)) + if (package == PKG_CUDA) return :(CUDA.vote_ballot_sync($(args...))) + elseif (package == PKG_AMDGPU) return :(AMDGPU.Device.ballot_sync(UInt64($(args[1])), $(args[2]))) + elseif (package == PKG_METAL) @KeywordArgumentError("this functionality is not yet supported in Metal.jl.") + elseif iscpu(package) return :(ParallelStencil.ParallelKernel.vote_ballot_sync_cpu($(args...))) + else @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).") + end +end + ## FUNCTIONS FOR MATH SYNTAX function ∀(caller::Module, member_expr::Expr, statement::Union{Expr, Symbol}) @@ -345,3 +607,45 @@ macro sync_threads_cpu() esc(:(begin end)) end macro sharedMem_cpu(T, dims) :(MArray{Tuple{$(esc(dims))...}, $(esc(T)), length($(esc(dims))), prod($(esc(dims)))}(undef)); end # Note: A macro is used instead of a function as a creating a type stable function is not really possible (dims can take any values and they become part of the MArray type...). MArray is not escaped in order not to have to import StaticArrays in the user code. macro sharedMem_cpu(T, dims, offset) esc(:(ParallelStencil.ParallelKernel.@sharedMem_cpu($T, $dims))) end + +## CPU BACKEND: WARP-LEVEL PRIMITIVES (zero-overhead pure functions) + +# The CPU backend follows a single-thread-per-block model. All warp-level +# operations therefore degenerate to constants or identity operations. +# These functions are intentionally small, @inline, allocation-free, and +# operate on isbits values only. They are called by the macro dispatchers +# for the CPU backend. + +@inline warpsize_cpu()::Int = 1 + +@inline laneid_cpu()::Int = 1 + +@inline active_mask_cpu()::UInt64 = UInt64(0x1) + +# Shuffle: direct, with optional width. Identity on CPU. +@inline shfl_sync_cpu(mask::Unsigned, val, lane0::Int64) = val + +@inline shfl_sync_cpu(mask::Unsigned, val, lane0::Int64, width::Int64) = val + +# Shuffle up +@inline shfl_up_sync_cpu(mask::Unsigned, val, delta::Int64) = val + +@inline shfl_up_sync_cpu(mask::Unsigned, val, delta::Int64, width::Int64) = val + +# Shuffle down +@inline shfl_down_sync_cpu(mask::Unsigned, val, delta::Int64) = val + +@inline shfl_down_sync_cpu(mask::Unsigned, val, delta::Int64, width::Int64) = val + +# Shuffle xor (butterfly) +@inline shfl_xor_sync_cpu(mask::Unsigned, val, lane_mask0::Int64) = val + +@inline shfl_xor_sync_cpu(mask::Unsigned, val, lane_mask0::Int64, width::Int64) = val + +# Vote operations +@inline vote_any_sync_cpu(mask::Unsigned, predicate::Bool)::Bool = predicate + +@inline vote_all_sync_cpu(mask::Unsigned, predicate::Bool)::Bool = predicate + +# Ballot returns a mask with bit 0 set iff predicate is true; CPU uses 64-bit mask. +@inline vote_ballot_sync_cpu(mask::Unsigned, predicate::Bool)::UInt64 = predicate ? UInt64(0x1) : UInt64(0x0) diff --git a/src/ParallelStencil.jl b/src/ParallelStencil.jl index a46433a6..a890ff77 100644 --- a/src/ParallelStencil.jl +++ b/src/ParallelStencil.jl @@ -32,6 +32,17 @@ https://github.com/omlins/ParallelStencil.jl - [`@threadIdx`](@ref) - [`@sync_threads`](@ref) - [`@sharedMem`](@ref) +!!! note "Warp-level primitives" + - [`@warpsize`](@ref) + - [`@laneid`](@ref) + - [`@active_mask`](@ref) + - [`@shfl_sync`](@ref) + - [`@shfl_up_sync`](@ref) + - [`@shfl_down_sync`](@ref) + - [`@shfl_xor_sync`](@ref) + - [`@vote_any_sync`](@ref) + - [`@vote_all_sync`](@ref) + - [`@vote_ballot_sync`](@ref) # Submodules - [`ParallelStencil.AD`](@ref) @@ -60,8 +71,11 @@ using .ParallelKernel.Exceptions include("shared.jl") ## Alphabetical include of function files +include("allocators.jl") +include("hide_communication.jl") include("init_parallel_stencil.jl") include("kernel_language.jl") +include("memopt.jl") include("parallel.jl") include("reset_parallel_stencil.jl") @@ -74,6 +88,7 @@ include("FiniteDifferences.jl") export @init_parallel_stencil, FiniteDifferences1D, FiniteDifferences2D, FiniteDifferences3D, AD export @parallel, @hide_communication, @parallel_indices, @parallel_async, @synchronize, @zeros, @ones, @rand, @falses, @trues, @fill, @fill!, @CellType export @gridDim, @blockIdx, @blockDim, @threadIdx, @sync_threads, @sharedMem, @ps_show, @ps_println, @∀ +export @warpsize, @laneid, @active_mask, @shfl_sync, @shfl_up_sync, @shfl_down_sync, @shfl_xor_sync, @vote_any_sync, @vote_all_sync, @vote_ballot_sync export PSNumber end # Module ParallelStencil diff --git a/src/allocators.jl b/src/allocators.jl new file mode 100644 index 00000000..c8ee4bc5 --- /dev/null +++ b/src/allocators.jl @@ -0,0 +1,8 @@ +@doc replace(ParallelKernel.ZEROS_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro zeros(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@zeros($(args...)))); end +@doc replace(ParallelKernel.ONES_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro ones(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@ones($(args...)))); end +@doc replace(ParallelKernel.RAND_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro rand(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@rand($(args...)))); end +@doc replace(ParallelKernel.FALSES_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro falses(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@falses($(args...)))); end +@doc replace(ParallelKernel.TRUES_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro trues(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@trues($(args...)))); end +@doc replace(ParallelKernel.FILL_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro fill(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@fill($(args...)))); end +@doc replace(ParallelKernel.FILL!_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro fill!(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@fill!($(args...)))); end +@doc replace(ParallelKernel.CELLTYPE_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro CellType(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@CellType($(args...)))); end diff --git a/src/hide_communication.jl b/src/hide_communication.jl new file mode 100644 index 00000000..868acd32 --- /dev/null +++ b/src/hide_communication.jl @@ -0,0 +1 @@ +@doc replace(ParallelKernel.HIDE_COMMUNICATION_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro hide_communication(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@hide_communication($(args...)))); end \ No newline at end of file diff --git a/src/init_parallel_stencil.jl b/src/init_parallel_stencil.jl index 454ee360..e27b62bd 100644 --- a/src/init_parallel_stencil.jl +++ b/src/init_parallel_stencil.jl @@ -1,25 +1,3 @@ -# NOTE: @parallel and @parallel_indices and @parallel_async do not appear in the following as they are extended and therefore defined in parallel.jl -@doc replace(ParallelKernel.HIDE_COMMUNICATION_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro hide_communication(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@hide_communication($(args...)))); end -@doc replace(ParallelKernel.ZEROS_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro zeros(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@zeros($(args...)))); end -@doc replace(ParallelKernel.ONES_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro ones(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@ones($(args...)))); end -@doc replace(ParallelKernel.RAND_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro rand(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@rand($(args...)))); end -@doc replace(ParallelKernel.FALSES_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro falses(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@falses($(args...)))); end -@doc replace(ParallelKernel.TRUES_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro trues(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@trues($(args...)))); end -@doc replace(ParallelKernel.FILL_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro fill(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@fill($(args...)))); end -@doc replace(ParallelKernel.FILL!_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro fill!(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@fill!($(args...)))); end -@doc replace(ParallelKernel.CELLTYPE_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro CellType(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@CellType($(args...)))); end -@doc replace(ParallelKernel.SYNCHRONIZE_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro synchronize(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@synchronize($(args...)))); end -@doc replace(ParallelKernel.GRIDDIM_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro gridDim(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@gridDim($(args...)))); end -@doc replace(ParallelKernel.BLOCKIDX_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro blockIdx(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@blockIdx($(args...)))); end -@doc replace(ParallelKernel.BLOCKDIM_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro blockDim(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@blockDim($(args...)))); end -@doc replace(ParallelKernel.THREADIDX_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro threadIdx(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@threadIdx($(args...)))); end -@doc replace(ParallelKernel.SYNCTHREADS_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro sync_threads(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@sync_threads($(args...)))); end -@doc replace(ParallelKernel.SHAREDMEM_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro sharedMem(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@sharedMem($(args...)))); end -@doc replace(ParallelKernel.FORALL_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro ∀(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@∀($(args...)))); end -@doc replace(replace(ParallelKernel.PKSHOW_DOC, "@init_parallel_kernel" => "@init_parallel_stencil"), "pk_show" => "ps_show") macro ps_show(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@pk_show($(args...)))); end -@doc replace(replace(ParallelKernel.PKPRINTLN_DOC, "@init_parallel_kernel" => "@init_parallel_stencil"), "pk_println" => "ps_println") macro ps_println(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@pk_println($(args...)))); end - - """ @init_parallel_stencil(package, numbertype, ndims) @init_parallel_stencil(package, numbertype, ndims, inbounds=...) diff --git a/src/kernel_language.jl b/src/kernel_language.jl index 43175650..b161a3a5 100644 --- a/src/kernel_language.jl +++ b/src/kernel_language.jl @@ -1,1074 +1,19 @@ -#TODO: add ParallelStencil.ParallelKernel. in front of all kernel lang in macros! Later: generalize more for z? - -## -macro loop(args...) check_initialized(__module__); checkargs_loop(args...); esc(loop(__module__, args...)); end - - -## -macro memopt(args...) check_initialized(__module__); checkargs_memopt(args...); esc(memopt(args[1], __module__, args[2:end]...)); end - - -## -macro shortif(args...) check_initialized(__module__); checktwoargs(args...); esc(shortif(__module__, args...)); end - - -## ARGUMENT CHECKS - -function checknoargs(args...) - if (length(args) != 0) @ArgumentError("no arguments allowed.") end -end - -function checksinglearg(args...) - if (length(args) != 1) @ArgumentError("wrong number of arguments.") end -end - -function checktwoargs(args...) - if (length(args) != 2) @ArgumentError("wrong number of arguments.") end -end - -function checkargs_loop(args...) - if (length(args) != 4) @ArgumentError("wrong number of arguments.") end -end - -function checkargs_memopt(args...) - if (length(args) != 8 && length(args) != 7 && length(args) != 4) @ArgumentError("wrong number of arguments.") end -end - - -## FUNCTIONS FOR PERFORMANCE OPTIMSATIONS - -function loop(caller::Module, index::Symbol, loopdim::Integer, loopsize, body; package::Symbol=get_package(caller)) - if (package ∉ SUPPORTED_PACKAGES) @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).") end - dimvar = (:x,:y,:z)[loopdim] - loopoffset = gensym_world("loopoffset", @__MODULE__) - i = gensym_world("i", @__MODULE__) - return quote - $loopoffset = (@blockIdx().$dimvar-1)*$loopsize - for $i = 1:$loopsize - $index = $i + $loopoffset - $body - end - end -end - -#TODO: add input check and errors -# TODO: create a run time check for requirement: -# In order to be able to read the data into shared memory in only two statements, the number of threads must be at least half of the size of the shared memory block plus halo; thus, the total number of threads in each dimension must equal the range length, as else there would be smaller thread blocks at the boundaries (threads overlapping the range are sent home). These smaller blocks would be likely not to match the criteria for a correct reading of the data to shared memory. In summary the following requirements must be matched: @gridDim().x*@blockDim().x - $rangelength_x == 0; @gridDim().y*@blockDim().y - $rangelength_y > 0 -function memopt(metadata_module::Module, is_parallel_kernel::Bool, caller::Module, indices::Union{Symbol,Expr}, optvars::Union{Expr,Symbol}, loopdim::Integer, loopsize::Integer, optranges::Union{Nothing, NamedTuple{t, <:NTuple{N,NTuple{3,UnitRange}} where N} where t}, use_shmemhalos::Union{Nothing, NamedTuple{t, <:NTuple{N,Bool} where N} where t}, optimize_halo_read::Bool, body::Expr; package::Symbol=get_package(caller)) - optvars = Tuple(extract_tuple(optvars)) #TODO: make this function actually return directly a tuple rather than an array - indices = Tuple(extract_tuple(indices)) - use_shmemhalos = isnothing(use_shmemhalos) ? use_shmemhalos : eval_arg(caller, use_shmemhalos) - optranges = isnothing(optranges) ? optranges : eval_arg(caller, optranges) - readonlyvars = find_vars(body, indices; readonly=true) - if length(indices) != 3 @IncoherentArgumentError("incoherent arguments memopt in @parallel[_indices] : optimization can only be applied in 3-D @parallel kernels and @parallel_indices kernels with three indices.") end - if optvars == (Symbol(""),) - optvars = Tuple(keys(readonlyvars)) - else - for A in optvars - if !haskey(readonlyvars, A) @IncoherentArgumentError("incoherent argument memopt in @parallel[_indices] : optimization can only be applied to arrays that are only read within the kernel (not applicable to: $A).") end - end - end - if (package ∉ SUPPORTED_PACKAGES) @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).") end - if (package == PKG_CUDA) int_type = INT_CUDA - elseif (package == PKG_AMDGPU) int_type = INT_AMDGPU - elseif (package == PKG_METAL) int_type = INT_METAL - elseif (package == PKG_THREADS) int_type = INT_THREADS - end - body = eval_offsets(caller, body, indices, int_type) - offsets, offsets_by_z = extract_offsets(caller, body, indices, int_type, optvars, loopdim) - optvars = remove_single_point_optvars(optvars, optranges, offsets, offsets_by_z) - if (length(optvars)==0) @IncoherentArgumentError("incoherent argument memopt in @parallel[_indices] : optimization can only be applied if there is at least one array that is read-only within the kernel (and accessed with a multi-point stencil). Set memopt=false for this kernel.") end - optranges = define_optranges(optranges, optvars, offsets, int_type, package) - regqueue_heads, regqueue_tails, offset_mins, offset_maxs, nb_regs_heads, nb_regs_tails = define_regqueues(offsets, optranges, optvars, indices, int_type, loopdim) - - if loopdim == 3 - oz_maxs, hx1s, hy1s, hx2s, hy2s, use_shmems, use_shmem_xs, use_shmem_ys, use_shmemhalos, use_shmemindices, offset_spans, oz_spans, loopentrys = define_helper_variables(offset_mins, offset_maxs, optvars, use_shmemhalos, loopdim) - oz_span_max = maximum(values(oz_spans)) - # TODO: this only leads to correct result after row two executions in a row, probably due to the same compiler bug has below. # loopsize = (oz_span_max<=0) ? 1 : loopsize # NOTE: if the stencilrange in z is only one point, no loop is needed. - loopstart = minimum(values(loopentrys)) - loopend = loopsize - use_any_shmem = any(values(use_shmems)) - shmem_index_groups = define_shmem_index_groups(hx1s, hy1s, hx2s, hy2s, optvars, use_shmems, loopdim) - shmem_vars = define_shmem_vars(oz_maxs, hx1s, hy1s, hx2s, hy2s, optvars, indices, use_shmems, use_shmem_xs, use_shmem_ys, shmem_index_groups, use_shmemhalos, use_shmemindices, loopdim) - shmem_exprs = define_shmem_exprs(shmem_vars, loopdim) - shmem_z_ranges = define_shmem_z_ranges(offsets_by_z, use_shmems, loopdim) - shmem_loopentrys = define_shmem_loopentrys(loopentrys, shmem_z_ranges, offset_mins, loopdim) - shmem_loopexits = define_shmem_loopexits(loopend, shmem_z_ranges, offset_maxs, loopdim) - mainloopstart = (optimize_halo_read && !isempty(shmem_loopentrys)) ? minimum(values(shmem_loopentrys)) : loopstart - mainloopend = loopend # TODO: the second loop split leads to wrong results, probably due to a compiler bug. # mainloopend = (optimize_halo_read && !isempty(shmem_loopexits) ) ? maximum(values(shmem_loopexits) ) : loopend - ix, iy, iz = indices - tz_g = THREADIDS_VARNAMES[3] - rangelength_z = RANGELENGTHS_VARNAMES[3] - ranges = RANGES_VARNAME - range_z = :(($ranges[3])[$tz_g]) - range_z_start = :(($ranges[3])[1]) - range_z_end = :(($ranges[3])[end]) - i = gensym_world("i", @__MODULE__) - loopoffset = gensym_world("loopoffset", @__MODULE__) - - for A in optvars - regqueue_tail = regqueue_tails[A] - regqueue_head = regqueue_heads[A] - for oxy in keys(regqueue_tail) - for oz in keys(regqueue_tail[oxy]) - body = substitute(body, regtarget(A, (oxy..., oz), indices), regqueue_tail[oxy][oz]) - end - end - for oxy in keys(regqueue_head) - for oz in keys(regqueue_head[oxy]) - body = substitute(body, regtarget(A, (oxy..., oz), indices), regqueue_head[oxy][oz]) - end - end - end - - nb_indexing_vars = 1 + 14*length(keys(shmem_index_groups)) # TODO: a group must not be counted if none of the variables uses the shmem indices symbols. - nb_cell_vars = sum(values(nb_regs_heads)) + sum(values(nb_regs_tails)) - - #TODO: replace wrap_if where possible with in-line if - compare performance when doing it - body = quote - $loopoffset = (@blockIdx().z-1)*$loopsize + $range_z_start-1 #TODO: MOVE UP - see no perf change! interchange other lines! -$((quote - $tx = @threadIdx().x + $hx1 - $ty = @threadIdx().y + $hy1 - $nx_l = @blockDim().x + UInt32($(hx1+hx2)) # NOTE: cast to UInt32 is necessary to avoid promotion, which can lead to a tuple with different integers, resulting in an error. - $ny_l = @blockDim().y + UInt32($(hy1+hy2)) # ... - $t_h = (@threadIdx().y-1)*@blockDim().x + @threadIdx().x # NOTE: here it must be bx, not @blockDim().x - $t_h2 = $t_h + $nx_l*$ny_l - @blockDim().x*@blockDim().y - $ty_h = ($t_h-1) ÷ $nx_l + 1 - $tx_h = ($t_h-1) % $nx_l + 1 # NOTE: equivalent to (worse performance has uses registers probably differently): ($t_h-1) - $nx_l*($ty_h-1) + 1 - $ty_h2 = ($t_h2-1) ÷ $nx_l + 1 - $tx_h2 = ($t_h2-1) % $nx_l + 1 # NOTE: equivalent to (worse performance has uses registers probably differently): ($t_h2-1) - $nx_l*($ty_h2-1) + 1 - $ix_h = $ix - @threadIdx().x + $tx_h - $hx1 # NOTE: here it must be @blockDim().x, not bx - $ix_h2 = $ix - @threadIdx().x + $tx_h2 - $hx1 # ... - $iy_h = $iy - @threadIdx().y + $ty_h - $hy1 # ... - $iy_h2 = $iy - @threadIdx().y + $ty_h2 - $hy1 # ... - end - for vars in values(shmem_index_groups) for A in (vars[1],) if use_shmemindices[A] for s in (shmem_vars[A],) for (shmem_offset, hx1, hx2, hy1, hy2, tx, ty, nx_l, ny_l, t_h, t_h2, tx_h, tx_h2, ty_h, ty_h2, ix_h, ix_h2, iy_h, iy_h2, A_head) = ((shmem_exprs[A][:offset], hx1s[A], hx2s[A], hy1s[A], hy2s[A], s[:tx], s[:ty], s[:nx_l], s[:ny_l], s[:t_h], s[:t_h2], s[:tx_h], s[:tx_h2], s[:ty_h], s[:ty_h2], s[:ix_h], s[:ix_h2], s[:iy_h], s[:iy_h2], s[:A_head]),) - )... -) -$((:( $A_head = @sharedMem(eltype($A), (Int64($nx_l), Int64($ny_l)), $shmem_offset) # e.g. A_izp3 = @sharedMem(eltype(A), (nx_l, ny_l), +(nx_l_A * ny_l_A)*eltype(A)) - ) - for (A, s) in shmem_vars for (shmem_offset, nx_l, ny_l, A_head) = ((shmem_exprs[A][:offset], s[:nx_l], s[:ny_l], s[:A_head]),) - )... -) -$((:( $reg = 0.0 # e.g. A_ixm1_iyp2_izp2 = 0.0 - ) - for A in optvars for regs in values(regqueue_tails[A]) for reg in values(regs) - )... -) -$((:( $reg = 0.0 # e.g. A_ixm1_iyp2_izp3 = 0.0 - ) - for A in optvars for regs in values(regqueue_heads[A]) for reg in values(regs) - )... -) -# Pre-loop - # for $i = $loopstart:$(mainloopstart-1) -$(wrap_loop(i, loopstart:mainloopstart-1, - quote - $iz = $i + $loopoffset - if ($iz > $range_z_end) ParallelStencil.@return_nothing; end - # NOTE: the following is now fully included in the loopoffset (0.25% performance gain measured on H100) but is still of interest if we implement step ranges: - # $tz_g = $i + $loopoffset - # if ($tz_g > $rangelength_z) ParallelStencil.@return_nothing; end - # $iz = ($tz_g < 1) ? $range_z_start-(1-$tz_g) : $range_z # TODO: this will probably always be formulated with range_z_start -$((wrap_if(:($i > $(loopentry-1)), - :( $reg = (0<$ix+$(oxy[1])<=size($A,1) && 0<$iy+$(oxy[2])<=size($A,2) && 0<$iz+$oz<=size($A,3)) ? $(regtarget(A, (oxy...,oz), indices)) : $reg - ) - ;unless=(loopentry==loopstart) - ) - for A in keys(shmem_vars) for (oxy, regs) in regqueue_heads[A] for (oz, reg) in regs for loopentry = (loopentrys[A],) - )... -) -$((wrap_if(:($i > $(loopentry-1)), - :( $reg = (0<$ix+$(oxy[1])<=size($A,1) && 0<$iy+$(oxy[2])<=size($A,2) && 0<$iz+$oz<=size($A,3)) ? $(regtarget(A, (oxy...,oz), indices)) : $reg - ) - ;unless=(loopentry==loopstart) - ) - for A in optvars for (oxy, regs) in regqueue_heads[A] for (oz, reg) in regs for loopentry = (loopentrys[A],) if !use_shmems[A] - )... -) -$(( # NOTE: the if statement is not needed here as we only deal with registers - # wrap_if(:($i > $(loopentry-1)), - :( - $(regs[oz]) = $(regs[oz+1]) # e.g. A_ixm1_iyp2_iz = A_ixm1_iyp2_izp1 - ) - # ;unless=(loopentry==loopstart) - # ) - for A in optvars for regs in values(regqueue_tails[A]) for oz in sort(keys(regs)) for (loopentry, oz_max) = ((loopentrys[A], oz_maxs[A]),) if oz<=oz_max-2 - )... -) -$(( # NOTE: the if statement is not needed here as we only deal with registers - # wrap_if(:($i > $(loopentry-1)), - :( - $reg = $(regqueue_heads[A][oxy][oz_max]) # e.g. A_ixm1_iyp2_izp2 = A_ixm1_iyp2_izp3 - ) - # ;unless=(loopentry==loopstart) - # ) - for A in optvars for (oxy, regs) in regqueue_tails[A] for (oz, reg) in regs for (loopentry, oz_max) = ((loopentrys[A], oz_maxs[A]),) if oz==oz_max-1 && haskey(regqueue_heads[A], oxy) && haskey(regqueue_heads[A][oxy], oz_max) - )... -) - end - # ;unroll=true - ) # wrap_loop end -) # end - -# Main loop - # for $i = $mainloopstart:$mainloopend # ParallelStencil.@unroll -$(wrap_loop(i, mainloopstart:mainloopend, - quote - $iz = $i + $loopoffset - if ($iz > $range_z_end) ParallelStencil.@return_nothing; end - # NOTE: the following is now fully included in the loopoffset (0.25% performance gain measured on H100) but is still of interest if we implement step ranges: - # $tz_g = $i + $loopoffset - # if ($tz_g > $rangelength_z) ParallelStencil.@return_nothing; end - # $iz = ($tz_g < 1) ? $range_z_start-(1-$tz_g) : $range_z # TODO: this will probably always be formulated with range_z_start -$(use_any_shmem ? - :( @sync_threads() - ) : NOEXPR -) -$((wrap_if(:($i > $(loopentry-1)), - quote - if (2*$t_h <= $n_l && $ix_h>0 && $ix_h<=size($A,1) && $iy_h>0 && $iy_h<=size($A,2) && 0<$iz+$oz_max<=size($A,3)) - $A_head[$tx_h,$ty_h] = $A[$ix_h,$iy_h,$iz+$oz_max] - end - if (2*$t_h2 > $n_l && $ix_h2>0 && $ix_h2<=size($A,1) && $iy_h2>0 && $iy_h2<=size($A,2) && 0<$iz+$oz_max<=size($A,3)) - $A_head[$tx_h2,$ty_h2] = $A[$ix_h2,$iy_h2,$iz+$oz_max] - end - end - ;unless=(loopentry<=mainloopstart) - ) - for (A, s) in shmem_vars if use_shmemhalos[A] for (loopentry, oz_max, tx, ty, nx_l, ny_l, n_l, t_h, t_h2, tx_h, tx_h2, ty_h, ty_h2, ix_h, ix_h2, iy_h, iy_h2, A_head) = ((loopentrys[A], oz_maxs[A], s[:tx], s[:ty], s[:nx_l], s[:ny_l], s[:n_l], s[:t_h], s[:t_h2], s[:tx_h], s[:tx_h2], s[:ty_h], s[:ty_h2], s[:ix_h], s[:ix_h2], s[:iy_h], s[:iy_h2], s[:A_head]),) - )... -) -# $((wrap_if(:($i > $(loopentry-1)), -# quote -# if (2*$tx_h <= $nx_l && $ix_h>0 && $ix_h<=size($A,1) && $iy>0 && $iy<=size($A,2) && 0<$iz+$oz_max<=size($A,3)) -# $A_head[$tx_h,$ty] = $A[$ix_h,$iy,$iz+$oz_max] -# end -# if (2*$tx_h2 > $nx_l && $ix_h2>0 && $ix_h2<=size($A,1) && $iy>0 && $iy<=size($A,2) && 0<$iz+$oz_max<=size($A,3)) -# $A_head[$tx_h2,$ty] = $A[$ix_h2,$iy,$iz+$oz_max] -# end -# end -# ;unless=(loopentry<=mainloopstart) -# ) -# for (A, s) in shmem_vars if (use_shmemhalos[A] && use_shmem_xs[A] && !use_shmem_ys[A]) for (loopentry, oz_max, tx, ty, nx_l, ny_l, tx_h, tx_h2, ix_h, ix_h2, A_head) = ((loopentrys[A], oz_maxs[A], s[:tx], s[:ty], s[:nx_l], s[:ny_l], s[:tx_h], s[:tx_h2], s[:ix_h], s[:ix_h2], s[:A_head]),) -# )... -# ) -# $((wrap_if(:($i > $(loopentry-1)), -# quote -# if (2*$ty_h <= $ny_l && $ix>0 && $ix<=size($A,1) && $iy_h>0 && $iy_h<=size($A,2) && 0<$iz+$oz_max<=size($A,3)) -# $A_head[$tx,$ty_h] = $A[$ix,$iy_h,$iz+$oz_max] -# end -# if (2*$ty_h2 > $ny_l && $ix>0 && $ix<=size($A,1) && $iy_h2>0 && $iy_h2<=size($A,2) && 0<$iz+$oz_max<=size($A,3)) -# $A_head[$tx,$ty_h2] = $A[$ix,$iy_h2,$iz+$oz_max] -# end -# end -# ;unless=(loopentry<=mainloopstart) -# ) -# for (A, s) in shmem_vars if (use_shmemhalos[A] && !use_shmem_xs[A] && use_shmem_ys[A]) for (loopentry, oz_max, tx, ty, nx_l, ny_l, ty_h, ty_h2, iy_h, iy_h2, A_head) = ((loopentrys[A], oz_maxs[A], s[:tx], s[:ty], s[:nx_l], s[:ny_l], s[:ty_h], s[:ty_h2], s[:iy_h], s[:iy_h2], s[:A_head]),) -# )... -# ) -$((wrap_if(:($i > $(loopentry-1)), - quote - if ($ix>0 && $ix<=size($A,1) && $iy>0 && $iy<=size($A,2) && 0<$iz+$oz_max<=size($A,3)) - $A_head[$tx,$ty] = $A[$ix,$iy,$iz+$oz_max] - end - end - ;unless=(loopentry<=mainloopstart) - ) - for (A, s) in shmem_vars if !use_shmemhalos[A] for (loopentry, oz_max, tx, ty, nx_l, ny_l, A_head) = ((loopentrys[A], oz_maxs[A], s[:tx], s[:ty], s[:nx_l], s[:ny_l], s[:A_head]),) - )... -) -$(use_any_shmem ? - :( @sync_threads() - ) : NOEXPR -) -$((wrap_if(:($i > $(loopentry-1)), - :( $reg = (0<$ix+$(oxy[1])<=size($A,1) && 0<$iy+$(oxy[2])<=size($A,2) && 0<$iz+$oz<=size($A,3)) ? $(regtarget(A, (oxy...,oz), indices)) : $reg - ) - ;unless=(loopentry<=mainloopstart) - ) - for A in optvars for (oxy, regs) in regqueue_heads[A] for (oz, reg) in regs for loopentry = (loopentrys[A],) if !use_shmems[A] - )... -) -$((wrap_if(:($i > $(loopentry-1)), - use_shmemhalo ? - :( $reg = $(regsource(A_head, oxy, (tx, ty))) # e.g. A_ixm1_iyp2_izp3 = A_izp3[tx - 1, ty + 2] - ) - : - :( $reg = (0<$tx+$(oxy[1])<=$nx_l && 0<$ty+$(oxy[2])<=$ny_l) ? $(regsource(A_head, oxy, (tx, ty))) : (0<$ix+$(oxy[1])<=size($A,1) && 0<$iy+$(oxy[2])<=size($A,2) && 0<$iz+$oz<=size($A,3)) ? $(regtarget(A, (oxy...,oz), indices)) : $reg - ) - ;unless=(loopentry<=mainloopstart) - ) - for (A, s) in shmem_vars for (oxy, regs) in regqueue_heads[A] for (oz, reg) in regs for (use_shmemhalo, loopentry, tx, ty, nx_l, ny_l, A_head) = ((use_shmemhalos[A], loopentrys[A], s[:tx], s[:ty], s[:nx_l], s[:ny_l], s[:A_head]),) - )... -) -$((wrap_if(:($i > 0), - quote - $body - end; - unless=(mainloopstart>=1) - ) -)) -$(( # NOTE: the if statement is not needed here as we only deal with registers - # wrap_if(:($i > $(loopentry-1)), - :( - $(regs[oz]) = $(regs[oz+1]) # e.g. A_ixm1_iyp2_iz = A_ixm1_iyp2_izp1 - ) - # ;unless=(loopentry<=mainloopstart) - # ) - for A in optvars for regs in values(regqueue_tails[A]) for oz in sort(keys(regs)) for (loopentry, oz_max) = ((loopentrys[A], oz_maxs[A]),) if oz<=oz_max-2 - )... -) -$((wrap_if(:($i > $(loopentry-1)), - use_shmemhalo ? - :( $reg = $(regsource(A_head, oxy, (tx, ty))) # e.g. A_ixm3_iyp2_izp2 = A_izp3[tx - 3, ty + 2] - ) - : - :( $reg = (0<$tx+$(oxy[1])<=$nx_l && 0<$ty+$(oxy[2])<=$ny_l) ? $(regsource(A_head, oxy, (tx, ty))) : (0<$ix+$(oxy[1])<=size($A,1) && 0<$iy+$(oxy[2])<=size($A,2) && 0<$iz+$oz<=size($A,3)) ? $(regtarget(A, (oxy...,oz), indices)) : $reg - ) - ;unless=(loopentry<=mainloopstart) - ) - for (A, s) in shmem_vars for (oxy, regs) in regqueue_tails[A] for (oz, reg) in regs for (use_shmemhalo, loopentry, oz_max, tx, ty, nx_l, ny_l, A_head) = ((use_shmemhalos[A], loopentrys[A], oz_maxs[A], s[:tx], s[:ty], s[:nx_l], s[:ny_l], s[:A_head]),) if oz==oz_max-1 && !(haskey(regqueue_heads[A], oxy) && haskey(regqueue_heads[A][oxy], oz_max)) - )... -) -# TODO: remove these as soon as the above is tested: -# $((wrap_if(:($i > $(loopentry-1)), -# :( $reg = $(regsource(A_head, oxy, (tx, ty))) # e.g. A_ixm3_iyp2_izp2 = A_izp3[tx - 3, ty + 2] -# ) -# ;unless=(loopentry<=mainloopstart) -# ) -# for (A, s) in shmem_vars for (oxy, regs) in regqueue_tails[A] for (oz, reg) in regs for (loopentry, oz_max, tx, ty, A_head) = ((loopentrys[A], oz_maxs[A], s[:tx], s[:ty], s[:A_head]),) if oz==oz_max-1 && !(haskey(regqueue_heads[A], oxy) && haskey(regqueue_heads[A][oxy], oz_max)) -# )... -# ) -$(( # NOTE: the if statement is not needed here as we only deal with registers - # wrap_if(:($i > $(loopentry-1)), - :( - $reg = $(regqueue_heads[A][oxy][oz_max]) # e.g. A_ixm1_iyp2_izp2 = A_ixm1_iyp2_izp3 - ) - # ;unless=(loopentry<=mainloopstart) - # ) - for A in optvars for (oxy, regs) in regqueue_tails[A] for (oz, reg) in regs for (loopentry, oz_max) = ((loopentrys[A], oz_maxs[A]),) if oz==oz_max-1 && haskey(regqueue_heads[A], oxy) && haskey(regqueue_heads[A][oxy], oz_max) - )... -) - end - # ;unroll=true - ) # wrap_loop end -) # end - -# Wrap-up-loop -# ParallelStencil.@unroll for $i = $(mainloopend+1):$loopend -# $tz_g = $i + $loopoffset -# if ($tz_g > $rangelength_z) ParallelStencil.@return_nothing; end -# $iz = ($tz_g < 1) ? $range_z_start-(1-$tz_g) : $range_z # TODO: this will probably always be formulated with range_z_start -# $((wrap_if(:($i > $(loopentry-1)), -# quote -# @sync_threads() -# if (2*$t_h <= $nx_l*$ny_l && $ix_h>0 && $ix_h<=size($A,1) && $iy_h>0 && $iy_h<=size($A,2) && 0<$iz+$oz_max<=size($A,3)) -# $A_head[$tx_h,$ty_h] = $A[$ix_h,$iy_h,$iz+$oz_max] -# end -# if (2*$t_h2 <= $nx_l*$ny_l && $ix_h2>0 && $ix_h2<=size($A,1) && $iy_h2>0 && $iy_h2<=size($A,2) && 0<$iz+$oz_max<=size($A,3)) -# $A_head[$tx_h2,$ty_h2] = $A[$ix_h2,$iy_h2,$iz+$oz_max] -# end -# @sync_threads() -# end -# ;unless=(loopentry<=mainloopstart) -# ) -# for (A, s) in shmem_vars for (loopentry, oz_max, tx, ty, nx_l, ny_l, t_h, t_h2, tx_h, tx_h2, ty_h, ty_h2, ix_h, ix_h2, iy_h, iy_h2, A_head) = ((loopentrys[A], oz_maxs[A], s[:tx], s[:ty], s[:nx_l], s[:ny_l], s[:t_h], s[:t_h2], s[:tx_h], s[:tx_h2], s[:ty_h], s[:ty_h2], s[:ix_h], s[:ix_h2], s[:iy_h], s[:iy_h2], s[:A_head]),) -# )... -# ) -# $((wrap_if(:($i > $(loopentry-1)), -# :( $reg = (0<$ix+$(oxy[1])<=size($A,1) && 0<$iy+$(oxy[2])<=size($A,2) && 0<$iz+$oz<=size($A,3)) ? $(regtarget(A, (oxy...,oz), indices)) : $reg -# ) -# ;unless=(loopentry<=mainloopstart) -# ) -# for A in optvars for (oxy, regs) in regqueue_heads[A] for (oz, reg) in regs for loopentry = (loopentrys[A],) if !use_shmems[A] -# )... -# ) -# $((wrap_if(:($i > $(loopentry-1)), -# :( $reg = $(regsource(A_head, oxy, (tx, ty))) # e.g. A_ixm1_iyp2_izp3 = A_izp3[tx - 1, ty + 2] -# ) -# ;unless=(loopentry<=mainloopstart) -# ) -# for (A, s) in shmem_vars for (oxy, regs) in regqueue_heads[A] for reg in values(regs) for (loopentry, tx, ty, A_head) = ((loopentrys[A], s[:tx], s[:ty], s[:A_head]),) -# )... -# ) -# $((wrap_if(:($i > 0), -# quote -# $body -# end; -# unless=(mainloopstart>=1) -# ) -# )) -# $((wrap_if(:($i > $(loopentry-1)), -# :( -# $(regs[oz]) = $(regs[oz+1]) # e.g. A_ixm1_iyp2_iz = A_ixm1_iyp2_izp1 -# ) -# ;unless=(loopentry<=mainloopstart) -# ) -# for A in optvars for regs in values(regqueue_tails[A]) for oz in sort(keys(regs)) for (loopentry, oz_max) = ((loopentrys[A], oz_maxs[A]),) if oz<=oz_max-2 -# )... -# ) -# $((wrap_if(:($i > $(loopentry-1)), -# :( $reg = $(regsource(A_head, oxy, (tx, ty))) # e.g. A_ixm3_iyp2_izp2 = A_izp3[tx - 3, ty + 2] -# ) -# ;unless=(loopentry<=mainloopstart) -# ) -# for (A, s) in shmem_vars for (oxy, regs) in regqueue_tails[A] for (oz, reg) in regs for (loopentry, oz_max, tx, ty, A_head) = ((loopentrys[A], oz_maxs[A], s[:tx], s[:ty], s[:A_head]),) if oz==oz_max-1 && !(haskey(regqueue_heads[A], oxy) && haskey(regqueue_heads[A][oxy], oz_max)) -# )... -# ) -# $((wrap_if(:($i > $(loopentry-1)), -# :( -# $reg = $(regqueue_heads[A][oxy][oz_max]) # e.g. A_ixm1_iyp2_izp2 = A_ixm1_iyp2_izp3 -# ) -# ;unless=(loopentry<=mainloopstart) -# ) -# for A in optvars for (oxy, regs) in regqueue_tails[A] for (oz, reg) in regs for (loopentry, oz_max) = ((loopentrys[A], oz_maxs[A]),) if oz==oz_max-1 && haskey(regqueue_heads[A], oxy) && haskey(regqueue_heads[A][oxy], oz_max) -# )... -# ) - -# $tz_g = $i + $loopoffset -# if ($tz_g > $rangelength_z) ParallelStencil.@return_nothing; end -# $iz = ($tz_g < 1) ? $range_z_start-(1-$tz_g) : $range_z # TODO: this will probably always be formulated with range_z_start -# $(( -# # wrap_if(:(($(loopentry-1) < $i < $(shmem_loopentry)) || ($(shmem_loopexit) < $i)), -# :( $reg = (0<$ix+$(oxy[1])<=size($A,1) && 0<$iy+$(oxy[2])<=size($A,2) && 0<$iz+$oz<=size($A,3)) ? $(regtarget(A, (oxy...,oz), indices)) : $reg -# ) -# for A in keys(shmem_vars) for (oxy, regs) in regqueue_heads[A] for (oz, reg) in regs for loopentry = (loopentrys[A],) -# )... -# ) -# $(( -# :( $reg = (0<$ix+$(oxy[1])<=size($A,1) && 0<$iy+$(oxy[2])<=size($A,2) && 0<$iz+$oz<=size($A,3)) ? $(regtarget(A, (oxy...,oz), indices)) : $reg -# ) -# for A in optvars for (oxy, regs) in regqueue_heads[A] for (oz, reg) in regs for loopentry = (loopentrys[A],) if !use_shmems[A] -# )... -# ) -# $(( -# quote -# $body -# end -# )) -# $(( -# :( -# $(regs[oz]) = $(regs[oz+1]) # e.g. A_ixm1_iyp2_iz = A_ixm1_iyp2_izp1 -# ) -# for A in optvars for regs in values(regqueue_tails[A]) for oz in sort(keys(regs)) for (loopentry, oz_max) = ((loopentrys[A], oz_maxs[A]),) if oz<=oz_max-2 -# )... -# ) -# $(( -# :( -# $reg = $(regqueue_heads[A][oxy][oz_max]) # e.g. A_ixm1_iyp2_izp2 = A_ixm1_iyp2_izp3 -# ) -# for A in optvars for (oxy, regs) in regqueue_tails[A] for (oz, reg) in regs for (loopentry, oz_max) = ((loopentrys[A], oz_maxs[A]),) if oz==oz_max-1 && haskey(regqueue_heads[A], oxy) && haskey(regqueue_heads[A][oxy], oz_max) -# )... -# ) - # end - end - else - @ArgumentError("memopt: only loopdim=3 is currently supported.") - end - store_metadata(metadata_module, is_parallel_kernel, caller, offset_mins, offset_maxs, offsets, optvars, loopdim, loopsize, optranges, use_shmemhalos) - # @show QuoteNode(ParallelKernel.simplify_varnames!(ParallelKernel.remove_linenumbernodes!(deepcopy(body)))) - return body -end - - -function memopt(metadata_module::Module, is_parallel_kernel::Bool, caller::Module, indices::Union{Symbol,Expr}, optvars::Union{Expr,Symbol}, body::Expr; package::Symbol=get_package(caller)) - loopdim = isa(indices,Expr) ? length(indices.args) : 1 - loopsize = compute_loopsize(package) - optranges = nothing - use_shmemhalos = nothing - optimize_halo_read = true - return memopt(metadata_module, is_parallel_kernel, caller, indices, optvars, loopdim, loopsize, optranges, use_shmemhalos, optimize_halo_read, body; package=package) -end - - -function shortif(caller::Module, else_val, if_expr; package::Symbol=get_package(caller)) - if (package ∉ SUPPORTED_PACKAGES) @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).") end - @capture(if_expr, if condition_ body_ end) || @ArgumentError("@shortif: the second argument must be an if statement.") - @capture(body, lhs_ = rhs_) || @ArgumentError("@shortif: the if statement body must contain a assignement.") - return :($lhs = $condition ? $rhs : $else_val) -end - - -## FUNCTIONS FOR SHARED MEMORY ALLOCATION - - -## HELPER FUNCTIONS - -function eval_offsets(caller::Module, body::Expr, indices::NTuple{N,<:Union{Symbol,Expr}} where N, int_type::Type{<:Integer}) - return postwalk(body) do ex - if !is_stencil_access(ex, indices...) return ex; end - @capture(ex, A_[indices_expr__]) || @ModuleInternalError("a stencil access could not be pattern matched.") - for i = 1:length(indices) - offset_expr = substitute(indices_expr[i], indices[i], 0) - offset = eval_arg(caller, offset_expr) - if (offset > 0) indices_expr[i] = :($(indices[i]) + $(int_type(offset)) ) - elseif (offset < 0) indices_expr[i] = :($(indices[i]) - $(int_type(abs(offset)))) - else indices_expr[i] = indices[i] - end - end - return :($A[$(indices_expr...)]) - end -end - -function extract_offsets(caller::Module, body::Expr, indices::NTuple{N,<:Union{Symbol,Expr}} where N, int_type::Type{<:Integer}, optvars::NTuple{N,Symbol} where N, loopdim::Integer) - offsets_by_xy = Dict(A => Dict() for A in optvars) - offsets_by_z = Dict(A => Dict() for A in optvars) - postwalk(body) do ex - if is_stencil_access(ex, indices...) - @capture(ex, A_[indices_expr__]) || @ModuleInternalError("a stencil access could not be pattern matched.") - if A in optvars - offsets = () - for i = 1:length(indices) - offset_expr = substitute(indices_expr[i], indices[i], 0) - offset = int_type(eval_arg(caller, offset_expr)) # TODO: do this and cast later to enable unsigned integer (also dealing with negative rangers is required elsewhere): offset = eval_arg(caller, offset_expr) - offsets = (offsets..., offset) - end - if loopdim == 3 - k1 = offsets[1:2] - k2 = offsets[end] - if haskey(offsets_by_xy[A], k1) && haskey(offsets_by_xy[A][k1], k2) offsets_by_xy[A][k1][k2] += 1 - elseif haskey(offsets_by_xy[A], k1) offsets_by_xy[A][k1][k2] = 1 - else offsets_by_xy[A][k1] = Dict(k2 => 1) - end - k1 = offsets[end] - k2 = offsets[1:2] - if haskey(offsets_by_z[A], k1) && haskey(offsets_by_z[A][k1], k2) offsets_by_z[A][k1][k2] += 1 - elseif haskey(offsets_by_z[A], k1) offsets_by_z[A][k1][k2] = 1 - else offsets_by_z[A][k1] = Dict(k2 => 1) - end - else - @ArgumentError("memopt: only loopdim=3 is currently supported.") - end - end - end - return ex - end - return offsets_by_xy, offsets_by_z -end - -function remove_single_point_optvars(optvars, optranges_arg, offsets, offsets_by_z) - return tuple((A for A in optvars if !(length(keys(offsets[A]))==1 && length(keys(offsets_by_z[A]))==1) || (!isnothing(optranges_arg) && A ∈ keys(optranges_arg)))...) -end - -function define_optranges(optranges_arg, optvars, offsets, int_type, package) - compute_capability = get_compute_capability(package) - optranges = Dict() - for A in optvars - zspan_max = 0 - oxy_zspan_max = () - for oxy in keys(offsets[A]) - zspan = length(keys(offsets[A][oxy])) - if zspan > zspan_max - zspan_max = zspan - oxy_zspan_max = oxy - end - end - fullrange = typemin(int_type):typemax(int_type) - pointrange_x = oxy_zspan_max[1]: oxy_zspan_max[1] - pointrange_y = oxy_zspan_max[2]: oxy_zspan_max[2] - if (!isnothing(optranges_arg) && A ∈ keys(optranges_arg)) optranges[A] = getproperty(optranges_arg, A) - elseif (compute_capability < v"8" && (length(optvars) <= FULLRANGE_THRESHOLD)) optranges[A] = (fullrange, fullrange, fullrange) - elseif (USE_FULLRANGE_DEFAULT == (true, true, true)) optranges[A] = (fullrange, fullrange, fullrange) - elseif (USE_FULLRANGE_DEFAULT == (false, true, true)) optranges[A] = (pointrange_x, fullrange, fullrange) - elseif (USE_FULLRANGE_DEFAULT == (true, false, true)) optranges[A] = (fullrange, pointrange_y, fullrange) - elseif (USE_FULLRANGE_DEFAULT == (false, false, true)) optranges[A] = (pointrange_x, pointrange_y, fullrange) - end - end - return optranges -end - -function define_regqueues(offsets::Dict{Symbol, Dict{Any, Any}}, optranges::Dict{Any, Any}, optvars::NTuple{N,Symbol} where N, indices::NTuple{N,<:Union{Symbol,Expr}} where N, int_type::Type{<:Integer}, loopdim::Integer) - regqueue_heads = Dict(A => Dict() for A in optvars) - regqueue_tails = Dict(A => Dict() for A in optvars) - offset_mins = Dict{Symbol, NTuple{3,Integer}}() - offset_maxs = Dict{Symbol, NTuple{3,Integer}}() - nb_regs_heads = Dict{Symbol, Integer}() - nb_regs_tails = Dict{Symbol, Integer}() - for A in optvars - regqueue_heads[A], regqueue_tails[A], offset_mins[A], offset_maxs[A], nb_regs_heads[A], nb_regs_tails[A] = define_regqueue(offsets[A], optranges[A], A, indices, int_type, loopdim) - end - return regqueue_heads, regqueue_tails, offset_mins, offset_maxs, nb_regs_heads, nb_regs_tails -end - -function define_regqueue(offsets::Dict{Any, Any}, optranges::NTuple{3,UnitRange}, A::Symbol, indices::NTuple{N,<:Union{Symbol,Expr}} where N, int_type::Type{<:Integer}, loopdim::Integer) - regqueue_head = Dict() - regqueue_tail = Dict() - nb_regs_head = 0 - nb_regs_tail = 0 - if loopdim == 3 - optranges_xy = optranges[1:2] - optranges_z = optranges[3] - offsets_xy = filter(oxy -> all(oxy .∈ optranges_xy), keys(offsets)) - if isempty(offsets_xy) @IncoherentArgumentError("incoherent argument in memopt: optranges in x-y dimension do not include any array access.") end - offset_min = (typemax(int_type), typemax(int_type), typemax(int_type)) - offset_max = (typemin(int_type), typemin(int_type), typemin(int_type)) - for oxy in offsets_xy - offsets_z = filter(x -> x ∈ optranges_z, keys(offsets[oxy])) - if isempty(offsets_z) @IncoherentArgumentError("incoherent argument in memopt: optranges in z dimension do not include any array access.") end - offset_min = (min(offset_min[1], oxy[1]), - min(offset_min[2], oxy[2]), - min(offset_min[3], minimum(offsets_z))) - offset_max = (max(offset_max[1], oxy[1]), - max(offset_max[2], oxy[2]), - max(offset_max[3], maximum(offsets_z))) - end - oz_max = offset_max[3] - for oxy in offsets_xy - offsets_z = sort(filter(x -> x ∈ optranges_z, keys(offsets[oxy]))) - k1 = oxy - for oz = offsets_z[1]:oz_max-1 - k2 = oz - if haskey(regqueue_tail, k1) && haskey(regqueue_tail[k1], k2) @ModuleInternalError("regqueue_tail entry exists already.") end - reg = gensym_world(varname(A, (oxy..., oz)), @__MODULE__); nb_regs_tail += 1 - if haskey(regqueue_tail, k1) regqueue_tail[k1][k2] = reg - else regqueue_tail[k1] = Dict(k2 => reg) - end - end - oz = offsets_z[end] - if oz == oz_max - k2 = oz - if haskey(regqueue_head, k1) && haskey(regqueue_head[k1], k2) @ModuleInternalError("regqueue_head entry exists already.") end - reg = gensym_world(varname(A, (oxy..., oz)), @__MODULE__); nb_regs_head += 1 - if haskey(regqueue_head, k1) regqueue_head[k1][k2] = reg - else regqueue_head[k1] = Dict(k2 => reg) - end - end - end - else - @ArgumentError("memopt: only loopdim=3 is currently supported.") - end - return regqueue_head, regqueue_tail, offset_min, offset_max, nb_regs_head, nb_regs_tail -end - -function define_helper_variables(offset_mins::Dict{Symbol, <:NTuple{3,Integer}}, offset_maxs::Dict{Symbol, <:NTuple{3,Integer}}, optvars::NTuple{N,Symbol} where N, use_shmemhalos_arg, loopdim::Integer) - oz_maxs, hx1s, hy1s, hx2s, hy2s, use_shmems, use_shmem_xs, use_shmem_ys, use_shmemhalos, use_shmemindices, offset_spans, oz_spans, loopentrys = Dict(), Dict(), Dict(), Dict(), Dict(), Dict(), Dict(), Dict(), Dict(), Dict(), Dict(), Dict(), Dict() - if loopdim == 3 - for A in optvars - offset_min, offset_max = offset_mins[A], offset_maxs[A] - oz_max = offset_max[3] - hx1, hy1 = -1 .* offset_min[1:2] - hx2, hy2 = offset_max[1:2] - use_shmem_x = (hx1 + hx2 > 0) - use_shmem_y = (hy1 + hy2 > 0) - use_shmem = use_shmem_x || use_shmem_y - use_shmemhalo = if (!isnothing(use_shmemhalos_arg) && (A ∈ keys(use_shmemhalos_arg))) getproperty(use_shmemhalos_arg, A) - elseif !(use_shmem_x && use_shmem_y) USE_SHMEMHALO_1D_DEFAULT - else USE_SHMEMHALO_DEFAULT - end - use_shmemindex = use_shmem && use_shmemhalo && (use_shmem_x && use_shmem_y) - offset_span = offset_max .- offset_min - oz_span = offset_span[3] - loopentry = 1 - oz_span #TODO: make possibility to do first and last read in z dimension directly into registers without halo - oz_maxs[A], hx1s[A], hy1s[A], hx2s[A], hy2s[A], use_shmems[A], use_shmem_xs[A], use_shmem_ys[A], use_shmemhalos[A], use_shmemindices[A], offset_spans[A], oz_spans[A], loopentrys[A] = oz_max, hx1, hy1, hx2, hy2, use_shmem, use_shmem_x, use_shmem_y, use_shmemhalo, use_shmemindex, offset_span, oz_span, loopentry - end - else - @ArgumentError("memopt: only loopdim=3 is currently supported.") - end - return oz_maxs, hx1s, hy1s, hx2s, hy2s, use_shmems, use_shmem_xs, use_shmem_ys, use_shmemhalos, use_shmemindices, offset_spans, oz_spans, loopentrys -end - -function define_shmem_index_groups(hx1s, hy1s, hx2s, hy2s, optvars::NTuple{N,Symbol} where N, use_shmems::Dict{Any, Any}, loopdim::Integer) - shmem_index_groups = Dict() - if loopdim == 3 - for A in optvars - if use_shmems[A] - k = (hx1s[A], hy1s[A], hx2s[A], hy2s[A]) - if !haskey(shmem_index_groups, k) shmem_index_groups[k] = (A,) - else shmem_index_groups[k] = (shmem_index_groups[k]..., A) - end - end - end - end - return shmem_index_groups -end - -function define_shmem_vars(oz_maxs::Dict{Any, Any}, hx1s, hy1s, hx2s, hy2s, optvars::NTuple{N,Symbol} where N, indices, use_shmems::Dict{Any, Any}, use_shmem_xs, use_shmem_ys, shmem_index_groups, use_shmemhalos, use_shmemindices, loopdim::Integer) - ix, iy, iz = indices - shmem_vars = Dict(A => Dict() for A in optvars if use_shmems[A]) - if loopdim == 3 - for vars in values(shmem_index_groups) - suffix = join(string.(vars), "_") - sym_tx = gensym_world("tx_$suffix", @__MODULE__) - sym_ty = gensym_world("ty_$suffix", @__MODULE__) - sym_nx_l = gensym_world("nx_l_$suffix", @__MODULE__) - sym_ny_l = gensym_world("ny_l_$suffix", @__MODULE__) - sym_t_h = gensym_world("t_h_$suffix", @__MODULE__) - sym_t_h2 = gensym_world("t_h2_$suffix", @__MODULE__) - sym_tx_h = gensym_world("tx_h_$suffix", @__MODULE__) - sym_tx_h2 = gensym_world("tx_h2_$suffix", @__MODULE__) - sym_ty_h = gensym_world("ty_h_$suffix", @__MODULE__) - sym_ty_h2 = gensym_world("ty_h2_$suffix", @__MODULE__) - sym_ix_h = gensym_world("ix_h_$suffix", @__MODULE__) - sym_ix_h2 = gensym_world("ix_h2_$suffix", @__MODULE__) - sym_iy_h = gensym_world("iy_h_$suffix", @__MODULE__) - sym_iy_h2 = gensym_world("iy_h2_$suffix", @__MODULE__) - for A in vars - if use_shmemindices[A] - n_l = quote $sym_nx_l*$sym_ny_l end - shmem_vars[A][:tx] = sym_tx - shmem_vars[A][:ty] = sym_ty - shmem_vars[A][:nx_l] = sym_nx_l - shmem_vars[A][:ny_l] = sym_ny_l - shmem_vars[A][:n_l] = n_l - shmem_vars[A][:t_h] = sym_t_h - shmem_vars[A][:t_h2] = sym_t_h2 - shmem_vars[A][:tx_h] = sym_tx_h - shmem_vars[A][:tx_h2] = sym_tx_h2 - shmem_vars[A][:ty_h] = sym_ty_h - shmem_vars[A][:ty_h2] = sym_ty_h2 - shmem_vars[A][:ix_h] = sym_ix_h - shmem_vars[A][:ix_h2] = sym_ix_h2 - shmem_vars[A][:iy_h] = sym_iy_h - shmem_vars[A][:iy_h2] = sym_iy_h2 - else - if use_shmemhalos[A] - use_shmem_x, use_shmem_y = use_shmem_xs[A], use_shmem_ys[A] - hx1, hy1, hx2, hy2 = hx1s[A], hy1s[A], hx2s[A], hy2s[A] - if use_shmem_x && use_shmem_y # NOTE: if the following expressions are noted with ":()" then it will cause a segmentation fault and run time. - tx = quote @threadIdx().x + $hx1 end - ty = quote @threadIdx().y + $hy1 end - nx_l = quote @blockDim().x + UInt32($(hx1+hx2)) end # NOTE: cast to UInt32 is necessary to avoid promotion, which can lead to a tuple with different integers, resulting in an error. - ny_l = quote @blockDim().y + UInt32($(hy1+hy2)) end # ... - n_l = quote $nx_l*$ny_l end - t_h = quote (@threadIdx().y-1)*@blockDim().x + @threadIdx().x end # NOTE: here it must be bx, not @blockDim().x - t_h2 = quote $t_h + $nx_l*$ny_l - @blockDim().x*@blockDim().y end - ty_h = quote ($t_h-1) ÷ $nx_l + 1 end - tx_h = quote ($t_h-1) % $nx_l + 1 end # NOTE: equivalent to (worse performance has uses registers probably differently): ($t_h-1) - $nx_l*($ty_h-1) + 1 - ty_h2 = quote ($t_h2-1) ÷ $nx_l + 1 end - tx_h2 = quote ($t_h2-1) % $nx_l + 1 end # NOTE: equivalent to (worse performance has uses registers probably differently): ($t_h2-1) - $nx_l*($ty_h2-1) + 1 - ix_h = quote $ix - @threadIdx().x + $tx_h - $hx1 end # NOTE: here it must be @blockDim().x, not bx - ix_h2 = quote $ix - @threadIdx().x + $tx_h2 - $hx1 end # ... - iy_h = quote $iy - @threadIdx().y + $ty_h - $hy1 end # ... - iy_h2 = quote $iy - @threadIdx().y + $ty_h2 - $hy1 end # ... - elseif use_shmem_x - tx = quote @threadIdx().x + $hx1 end - ty = quote @threadIdx().y + $hy1 end - nx_l = quote @blockDim().x + UInt32($(hx1+hx2)) end # NOTE: cast to UInt32 is necessary to avoid promotion, which can lead to a tuple with different integers, resulting in an error. - ny_l = quote @blockDim().y end - tx_h = quote @threadIdx().x end - ty_h = quote @threadIdx().y end - tx_h2 = quote @threadIdx().x + $(hx1+hx2) end # NOTE: alternative: shmem_vars[A][:tx_h2] = :(@threadIdx().x + @blockDim().x) - ty_h2 = ty_h - ix_h = quote $ix - @threadIdx().x + $tx_h - $hx1 end - ix_h2 = quote $ix - @threadIdx().x + $tx_h2 - $hx1 end - iy_h = quote $iy - @threadIdx().y + $ty_h - $hy1 end - iy_h2 = quote $iy - @threadIdx().y + $ty_h2 - $hy1 end - n_l = nx_l - t_h = tx_h - t_h2 = tx_h2 - elseif use_shmem_y - tx = quote @threadIdx().x + $hx1 end - ty = quote @threadIdx().y + $hy1 end - nx_l = quote @blockDim().x end - ny_l = quote @blockDim().y + UInt32($(hy1+hy2)) end # NOTE: cast to UInt32 is necessary to avoid promotion, which can lead to a tuple with different integers, resulting in an error. - tx_h = quote @threadIdx().x end - ty_h = quote @threadIdx().y end - tx_h2 = tx_h - ty_h2 = quote @threadIdx().y + $(hy1+hy2) end # NOTE: alternative: # shmem_vars[A][:ty_h2] = :(@threadIdx().y + @blockDim().y) - ix_h = quote $ix - @threadIdx().x + $tx_h - $hx1 end - ix_h2 = quote $ix - @threadIdx().x + $tx_h2 - $hx1 end - iy_h = quote $iy - @threadIdx().y + $ty_h - $hy1 end - iy_h2 = quote $iy - @threadIdx().y + $ty_h2 - $hy1 end - n_l = ny_l - t_h = ty_h - t_h2 = ty_h2 - end - shmem_vars[A][:tx] = tx - shmem_vars[A][:ty] = ty - shmem_vars[A][:nx_l] = nx_l - shmem_vars[A][:ny_l] = ny_l - shmem_vars[A][:n_l] = n_l - shmem_vars[A][:t_h] = t_h - shmem_vars[A][:t_h2] = t_h2 - shmem_vars[A][:tx_h] = tx_h - shmem_vars[A][:tx_h2] = tx_h2 - shmem_vars[A][:ty_h] = ty_h - shmem_vars[A][:ty_h2] = ty_h2 - shmem_vars[A][:ix_h] = ix_h - shmem_vars[A][:ix_h2] = ix_h2 - shmem_vars[A][:iy_h] = iy_h - shmem_vars[A][:iy_h2] = iy_h2 - else - shmem_vars[A][:tx] = :(@threadIdx().x) - shmem_vars[A][:ty] = :(@threadIdx().y) - shmem_vars[A][:nx_l] = :(@blockDim().x) - shmem_vars[A][:ny_l] = :(@blockDim().y) - end - end - shmem_vars[A][:A_head] = gensym_world(varname(A, (oz_maxs[A],); i="iz"), @__MODULE__) - end - end - else - @ArgumentError("memopt: only loopdim=3 is currently supported.") - end - return shmem_vars -end - -function define_shmem_exprs(shmem_vars::Dict{Symbol, Dict{Any, Any}}, loopdim::Integer) - exprs = Dict(A => Dict() for A in keys(shmem_vars)) - offset = () - if loopdim == 3 - for A in keys(shmem_vars) - exprs[A][:offset] = (length(offset) > 0) ? Expr(:call, :+, offset...) : 0 - offset = (offset..., :($(shmem_vars[A][:nx_l]) * $(shmem_vars[A][:ny_l]) * sizeof(eltype($A)))) - end - else - @ArgumentError("memopt: only loopdim=3 is currently supported.") - end - return exprs -end - -function define_shmem_z_ranges(offsets_by_z::Dict{Symbol, Dict{Any, Any}}, use_shmems::Dict{Any, Any}, loopdim::Integer) - shmem_z_ranges = Dict() - shmem_As = (A for (A, use_shmem) in use_shmems if use_shmem) - for A in shmem_As - shmem_z_ranges[A] = define_shmem_z_range(offsets_by_z[A], loopdim) - end - return shmem_z_ranges -end - -function define_shmem_z_range(offsets_by_z::Dict{Any, Any}, loopdim::Integer) - start, start_offsets_xy = find_rangelimit(offsets_by_z, loopdim; upper=false) - stop, stop_offsets_xy = find_rangelimit(offsets_by_z, loopdim; upper=true) - if (length(start_offsets_xy) != 1 || length(stop_offsets_xy) != 1 || start_offsets_xy[1] != stop_offsets_xy[1]) # NOTE: shared memory range is not reduced in asymmetric case - return minimum(keys(offsets_by_z)):maximum(keys(offsets_by_z)) - end - return start:stop -end - -function find_rangelimit(offsets_by_z::Dict{Any, Any}, loopdim::Integer; upper=false) - if loopdim == 3 - offsets_z = sort(keys(offsets_by_z); rev=upper) - oz1 = offsets_z[1] - rangelimit = oz1 - offsets_xy1 = (keys(offsets_by_z[oz1])...,) - if length(offsets_xy1) == 1 - rangelimit = offsets_z[2] - oxy1 = offsets_xy1[1] - for oz in offsets_z[2:end] - offsets_xy = (keys(offsets_by_z[oz])...,) - if (length(offsets_xy) == 1) && (offsets_xy[1] == oxy1) - rangelimit = offsets_z[oz+1] - else - break - end - end - end - else - @ArgumentError("memopt: only loopdim=3 is currently supported.") - end - return rangelimit, offsets_xy1 -end - -function define_shmem_loopentrys(loopentrys, shmem_z_ranges, offset_mins, loopdim::Integer) - shmem_loopentrys = Dict() - shmem_As = (A for A in keys(shmem_z_ranges)) - for A in shmem_As - shmem_loopentrys[A] = define_shmem_loopentry(loopentrys[A], shmem_z_ranges[A], offset_mins[A], loopdim) - end - return shmem_loopentrys -end - -function define_shmem_loopentry(loopentry, shmem_z_range, offset_min, loopdim::Integer) - if loopdim == 3 - shmem_loopentry = loopentry + (shmem_z_range.start - offset_min[3]) - else - @ArgumentError("memopt: only loopdim=3 is currently supported.") - end - return shmem_loopentry -end - -function define_shmem_loopexits(loopexit, shmem_z_ranges, offset_maxs, loopdim::Integer) - shmem_loopexits = Dict() - shmem_As = (A for A in keys(shmem_z_ranges)) - for A in shmem_As - shmem_loopexits[A] = define_shmem_loopexit(loopexit, shmem_z_ranges[A], offset_maxs[A], loopdim) - end - return shmem_loopexits -end - -function define_shmem_loopexit(loopexit, shmem_z_range, offset_max, loopdim::Integer) - if loopdim == 3 - shmem_loopexit = loopexit - (offset_max[3] - shmem_z_range.stop) - else - @ArgumentError("memopt: only loopdim=3 is currently supported.") - end - return shmem_loopexit -end - -function varname(A::Symbol, offsets::NTuple{N,Integer} where N; i::String="ix", j::String="iy", k::String="iz") - ndims = length(offsets) - ox = offsets[1] - x = if (ox > 0) i * "p" * string(ox) - elseif (ox < 0) i * "m" * string(abs(ox)) - else i - end - if ndims > 1 - oy = offsets[2] - y = if (oy > 0) j * "p" * string(oy) - elseif (oy < 0) j * "m" * string(abs(oy)) - else j - end - end - if ndims > 2 - oz = offsets[3] - z = if (oz > 0) k * "p" * string(oz) - elseif (oz < 0) k * "m" * string(abs(oz)) - else k - end - end - if (ndims == 1) return string(A, "_$(x)") - elseif (ndims == 2) return string(A, "_$(x)_$(y)") - elseif (ndims == 3) return string(A, "_$(x)_$(y)_$(z)") - end -end - -function regtarget(A::Symbol, offsets::NTuple{N,Integer} where N, indices::NTuple{N,<:Union{Symbol,Expr}} where N) - ndims = length(offsets) - ox = offsets[1] - ix = indices[1] - if (ox > 0) x = :($ix + $ox) - elseif (ox < 0) x = :($ix - $(abs(ox))) - else x = ix - end - if ndims > 1 - oy = offsets[2] - iy = indices[2] - if (oy > 0) y = :($iy + $oy) - elseif (oy < 0) y = :($iy - $(abs(oy))) - else y = iy - end - end - if ndims > 2 - oz = offsets[3] - iz = indices[3] - if (oz > 0) z = :($iz + $oz) - elseif (oz < 0) z = :($iz - $(abs(oz))) - else z = iz - end - end - if (ndims == 1) return :($A[$x]) - elseif (ndims == 2) return :($A[$x,$y]) - elseif (ndims == 3) return :($A[$x,$y,$z]) - end -end - -function regsource(A_head::Symbol, offsets::NTuple{N,Integer} where N, local_indices::NTuple{N,<:Union{Symbol,Expr}} where N) - ndims = length(offsets) - ox = offsets[1] - tx = local_indices[1] - if (ox > 0) x = :($tx + $ox) - elseif (ox < 0) x = :($tx - $(abs(ox))) - else x = tx - end - if ndims > 1 - oy = offsets[2] - ty = local_indices[2] - if (oy > 0) y = :($ty + $oy) - elseif (oy < 0) y = :($ty - $(abs(oy))) - else y = ty - end - end - if (ndims == 1) return :($A_head[$x]) - elseif (ndims == 2) return :($A_head[$x,$y]) # e.g. :($A_head[$tx,$ty-1]) - end -end - -function wrap_if(condition::Expr, block::Expr; unless::Bool=false) - if unless - return block - else - return quote - if $condition - $block - end - end - end -end - -function wrap_loop(index::Symbol, range::UnitRange, block::Expr; unroll=false) - if length(range) == 0 - return NOEXPR - elseif length(range) == 1 - return quote - $index = $(range.start) - $block - end - else - if unroll - return quote - $(( quote - $index = $i - $block - end - for i in range - )... - ) - end - else - return quote - for $index = $(range.start):$(range.stop) - $block - end - end - end - end -end - -function store_metadata(metadata_module::Module, is_parallel_kernel::Bool, caller::Module, offset_mins::Dict{Symbol, <:NTuple{3,Integer}}, offset_maxs::Dict{Symbol, <:NTuple{3,Integer}}, offsets::Dict{Symbol, Dict{Any, Any}}, optvars::NTuple{N,Symbol} where N, loopdim::Integer, loopsize::Integer, optranges::Dict{Any, Any}, use_shmemhalos) - memopt = true - nonconst_metadata = get_nonconst_metadata(caller) - stencilranges = NamedTuple(A => (offset_mins[A][1]:offset_maxs[A][1], offset_mins[A][2]:offset_maxs[A][2], offset_mins[A][3]:offset_maxs[A][3]) for A in optvars) - if nonconst_metadata - storeexpr = quote - is_parallel_kernel = $is_parallel_kernel - memopt = $memopt - nonconst_metadata = $nonconst_metadata - stencilranges = $stencilranges - offsets = $offsets - optvars = $optvars - loopdim = $loopdim - loopsize = $loopsize - optranges = $optranges - use_shmemhalos = $use_shmemhalos - end - else - storeexpr = quote - const is_parallel_kernel = $is_parallel_kernel - const memopt = $memopt - const nonconst_metadata = $nonconst_metadata - const stencilranges = $stencilranges - const offsets = $offsets - const optvars = $optvars - const loopdim = $loopdim - const loopsize = $loopsize - const optranges = $optranges - const use_shmemhalos = $use_shmemhalos - end - end - @eval(metadata_module, $storeexpr) -end - -Base.sort(keys::T; kwargs...) where T<:Base.AbstractSet = sort([keys...]; kwargs...) - - -# macro unroll(args...) check_initialized(__module__); checkargs_unroll(args...); esc(unroll(args...)); end - -# function checkargs_unroll(args...) -# if (length(args) != 1) @ArgumentError("wrong number of arguments.") end -# end - -# function unroll(expr) -# if @capture(expr, for i_ = range_ body__ end) #TODO: enable in instead of equal -# return quote -# for $i = $range -# $(body...) -# $(Expr(:loopinfo, nodes...)) -# end -# end -# else -# error("Syntax error: loopinfo needs a for loop") -# end -# end +@doc replace(ParallelKernel.GRIDDIM_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro gridDim(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@gridDim($(args...)))); end +@doc replace(ParallelKernel.BLOCKIDX_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro blockIdx(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@blockIdx($(args...)))); end +@doc replace(ParallelKernel.BLOCKDIM_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro blockDim(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@blockDim($(args...)))); end +@doc replace(ParallelKernel.THREADIDX_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro threadIdx(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@threadIdx($(args...)))); end +@doc replace(ParallelKernel.SYNCTHREADS_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro sync_threads(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@sync_threads($(args...)))); end +@doc replace(ParallelKernel.SHAREDMEM_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro sharedMem(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@sharedMem($(args...)))); end +@doc replace(ParallelKernel.FORALL_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro ∀(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@∀($(args...)))); end +@doc replace(replace(ParallelKernel.PKSHOW_DOC, "@init_parallel_kernel" => "@init_parallel_stencil"), "pk_show" => "ps_show") macro ps_show(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@pk_show($(args...)))); end +@doc replace(replace(ParallelKernel.PKPRINTLN_DOC, "@init_parallel_kernel" => "@init_parallel_stencil"), "pk_println" => "ps_println") macro ps_println(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@pk_println($(args...)))); end +@doc replace(ParallelKernel.WARPSIZE_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro warpsize(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@warpsize($(args...)))); end +@doc replace(ParallelKernel.LANEID_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro laneid(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@laneid($(args...)))); end +@doc replace(ParallelKernel.ACTIVE_MASK_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro active_mask(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@active_mask($(args...)))); end +@doc replace(ParallelKernel.SHFL_SYNC_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro shfl_sync(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@shfl_sync($(args...)))); end +@doc replace(ParallelKernel.SHFL_UP_SYNC_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro shfl_up_sync(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@shfl_up_sync($(args...)))); end +@doc replace(ParallelKernel.SHFL_DOWN_SYNC_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro shfl_down_sync(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@shfl_down_sync($(args...)))); end +@doc replace(ParallelKernel.SHFL_XOR_SYNC_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro shfl_xor_sync(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@shfl_xor_sync($(args...)))); end +@doc replace(ParallelKernel.VOTE_ANY_SYNC_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro vote_any_sync(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@vote_any_sync($(args...)))); end +@doc replace(ParallelKernel.VOTE_ALL_SYNC_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro vote_all_sync(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@vote_all_sync($(args...)))); end +@doc replace(ParallelKernel.VOTE_BALLOT_SYNC_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro vote_ballot_sync(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@vote_ballot_sync($(args...)))); end diff --git a/src/memopt.jl b/src/memopt.jl new file mode 100644 index 00000000..43175650 --- /dev/null +++ b/src/memopt.jl @@ -0,0 +1,1074 @@ +#TODO: add ParallelStencil.ParallelKernel. in front of all kernel lang in macros! Later: generalize more for z? + +## +macro loop(args...) check_initialized(__module__); checkargs_loop(args...); esc(loop(__module__, args...)); end + + +## +macro memopt(args...) check_initialized(__module__); checkargs_memopt(args...); esc(memopt(args[1], __module__, args[2:end]...)); end + + +## +macro shortif(args...) check_initialized(__module__); checktwoargs(args...); esc(shortif(__module__, args...)); end + + +## ARGUMENT CHECKS + +function checknoargs(args...) + if (length(args) != 0) @ArgumentError("no arguments allowed.") end +end + +function checksinglearg(args...) + if (length(args) != 1) @ArgumentError("wrong number of arguments.") end +end + +function checktwoargs(args...) + if (length(args) != 2) @ArgumentError("wrong number of arguments.") end +end + +function checkargs_loop(args...) + if (length(args) != 4) @ArgumentError("wrong number of arguments.") end +end + +function checkargs_memopt(args...) + if (length(args) != 8 && length(args) != 7 && length(args) != 4) @ArgumentError("wrong number of arguments.") end +end + + +## FUNCTIONS FOR PERFORMANCE OPTIMSATIONS + +function loop(caller::Module, index::Symbol, loopdim::Integer, loopsize, body; package::Symbol=get_package(caller)) + if (package ∉ SUPPORTED_PACKAGES) @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).") end + dimvar = (:x,:y,:z)[loopdim] + loopoffset = gensym_world("loopoffset", @__MODULE__) + i = gensym_world("i", @__MODULE__) + return quote + $loopoffset = (@blockIdx().$dimvar-1)*$loopsize + for $i = 1:$loopsize + $index = $i + $loopoffset + $body + end + end +end + +#TODO: add input check and errors +# TODO: create a run time check for requirement: +# In order to be able to read the data into shared memory in only two statements, the number of threads must be at least half of the size of the shared memory block plus halo; thus, the total number of threads in each dimension must equal the range length, as else there would be smaller thread blocks at the boundaries (threads overlapping the range are sent home). These smaller blocks would be likely not to match the criteria for a correct reading of the data to shared memory. In summary the following requirements must be matched: @gridDim().x*@blockDim().x - $rangelength_x == 0; @gridDim().y*@blockDim().y - $rangelength_y > 0 +function memopt(metadata_module::Module, is_parallel_kernel::Bool, caller::Module, indices::Union{Symbol,Expr}, optvars::Union{Expr,Symbol}, loopdim::Integer, loopsize::Integer, optranges::Union{Nothing, NamedTuple{t, <:NTuple{N,NTuple{3,UnitRange}} where N} where t}, use_shmemhalos::Union{Nothing, NamedTuple{t, <:NTuple{N,Bool} where N} where t}, optimize_halo_read::Bool, body::Expr; package::Symbol=get_package(caller)) + optvars = Tuple(extract_tuple(optvars)) #TODO: make this function actually return directly a tuple rather than an array + indices = Tuple(extract_tuple(indices)) + use_shmemhalos = isnothing(use_shmemhalos) ? use_shmemhalos : eval_arg(caller, use_shmemhalos) + optranges = isnothing(optranges) ? optranges : eval_arg(caller, optranges) + readonlyvars = find_vars(body, indices; readonly=true) + if length(indices) != 3 @IncoherentArgumentError("incoherent arguments memopt in @parallel[_indices] : optimization can only be applied in 3-D @parallel kernels and @parallel_indices kernels with three indices.") end + if optvars == (Symbol(""),) + optvars = Tuple(keys(readonlyvars)) + else + for A in optvars + if !haskey(readonlyvars, A) @IncoherentArgumentError("incoherent argument memopt in @parallel[_indices] : optimization can only be applied to arrays that are only read within the kernel (not applicable to: $A).") end + end + end + if (package ∉ SUPPORTED_PACKAGES) @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).") end + if (package == PKG_CUDA) int_type = INT_CUDA + elseif (package == PKG_AMDGPU) int_type = INT_AMDGPU + elseif (package == PKG_METAL) int_type = INT_METAL + elseif (package == PKG_THREADS) int_type = INT_THREADS + end + body = eval_offsets(caller, body, indices, int_type) + offsets, offsets_by_z = extract_offsets(caller, body, indices, int_type, optvars, loopdim) + optvars = remove_single_point_optvars(optvars, optranges, offsets, offsets_by_z) + if (length(optvars)==0) @IncoherentArgumentError("incoherent argument memopt in @parallel[_indices] : optimization can only be applied if there is at least one array that is read-only within the kernel (and accessed with a multi-point stencil). Set memopt=false for this kernel.") end + optranges = define_optranges(optranges, optvars, offsets, int_type, package) + regqueue_heads, regqueue_tails, offset_mins, offset_maxs, nb_regs_heads, nb_regs_tails = define_regqueues(offsets, optranges, optvars, indices, int_type, loopdim) + + if loopdim == 3 + oz_maxs, hx1s, hy1s, hx2s, hy2s, use_shmems, use_shmem_xs, use_shmem_ys, use_shmemhalos, use_shmemindices, offset_spans, oz_spans, loopentrys = define_helper_variables(offset_mins, offset_maxs, optvars, use_shmemhalos, loopdim) + oz_span_max = maximum(values(oz_spans)) + # TODO: this only leads to correct result after row two executions in a row, probably due to the same compiler bug has below. # loopsize = (oz_span_max<=0) ? 1 : loopsize # NOTE: if the stencilrange in z is only one point, no loop is needed. + loopstart = minimum(values(loopentrys)) + loopend = loopsize + use_any_shmem = any(values(use_shmems)) + shmem_index_groups = define_shmem_index_groups(hx1s, hy1s, hx2s, hy2s, optvars, use_shmems, loopdim) + shmem_vars = define_shmem_vars(oz_maxs, hx1s, hy1s, hx2s, hy2s, optvars, indices, use_shmems, use_shmem_xs, use_shmem_ys, shmem_index_groups, use_shmemhalos, use_shmemindices, loopdim) + shmem_exprs = define_shmem_exprs(shmem_vars, loopdim) + shmem_z_ranges = define_shmem_z_ranges(offsets_by_z, use_shmems, loopdim) + shmem_loopentrys = define_shmem_loopentrys(loopentrys, shmem_z_ranges, offset_mins, loopdim) + shmem_loopexits = define_shmem_loopexits(loopend, shmem_z_ranges, offset_maxs, loopdim) + mainloopstart = (optimize_halo_read && !isempty(shmem_loopentrys)) ? minimum(values(shmem_loopentrys)) : loopstart + mainloopend = loopend # TODO: the second loop split leads to wrong results, probably due to a compiler bug. # mainloopend = (optimize_halo_read && !isempty(shmem_loopexits) ) ? maximum(values(shmem_loopexits) ) : loopend + ix, iy, iz = indices + tz_g = THREADIDS_VARNAMES[3] + rangelength_z = RANGELENGTHS_VARNAMES[3] + ranges = RANGES_VARNAME + range_z = :(($ranges[3])[$tz_g]) + range_z_start = :(($ranges[3])[1]) + range_z_end = :(($ranges[3])[end]) + i = gensym_world("i", @__MODULE__) + loopoffset = gensym_world("loopoffset", @__MODULE__) + + for A in optvars + regqueue_tail = regqueue_tails[A] + regqueue_head = regqueue_heads[A] + for oxy in keys(regqueue_tail) + for oz in keys(regqueue_tail[oxy]) + body = substitute(body, regtarget(A, (oxy..., oz), indices), regqueue_tail[oxy][oz]) + end + end + for oxy in keys(regqueue_head) + for oz in keys(regqueue_head[oxy]) + body = substitute(body, regtarget(A, (oxy..., oz), indices), regqueue_head[oxy][oz]) + end + end + end + + nb_indexing_vars = 1 + 14*length(keys(shmem_index_groups)) # TODO: a group must not be counted if none of the variables uses the shmem indices symbols. + nb_cell_vars = sum(values(nb_regs_heads)) + sum(values(nb_regs_tails)) + + #TODO: replace wrap_if where possible with in-line if - compare performance when doing it + body = quote + $loopoffset = (@blockIdx().z-1)*$loopsize + $range_z_start-1 #TODO: MOVE UP - see no perf change! interchange other lines! +$((quote + $tx = @threadIdx().x + $hx1 + $ty = @threadIdx().y + $hy1 + $nx_l = @blockDim().x + UInt32($(hx1+hx2)) # NOTE: cast to UInt32 is necessary to avoid promotion, which can lead to a tuple with different integers, resulting in an error. + $ny_l = @blockDim().y + UInt32($(hy1+hy2)) # ... + $t_h = (@threadIdx().y-1)*@blockDim().x + @threadIdx().x # NOTE: here it must be bx, not @blockDim().x + $t_h2 = $t_h + $nx_l*$ny_l - @blockDim().x*@blockDim().y + $ty_h = ($t_h-1) ÷ $nx_l + 1 + $tx_h = ($t_h-1) % $nx_l + 1 # NOTE: equivalent to (worse performance has uses registers probably differently): ($t_h-1) - $nx_l*($ty_h-1) + 1 + $ty_h2 = ($t_h2-1) ÷ $nx_l + 1 + $tx_h2 = ($t_h2-1) % $nx_l + 1 # NOTE: equivalent to (worse performance has uses registers probably differently): ($t_h2-1) - $nx_l*($ty_h2-1) + 1 + $ix_h = $ix - @threadIdx().x + $tx_h - $hx1 # NOTE: here it must be @blockDim().x, not bx + $ix_h2 = $ix - @threadIdx().x + $tx_h2 - $hx1 # ... + $iy_h = $iy - @threadIdx().y + $ty_h - $hy1 # ... + $iy_h2 = $iy - @threadIdx().y + $ty_h2 - $hy1 # ... + end + for vars in values(shmem_index_groups) for A in (vars[1],) if use_shmemindices[A] for s in (shmem_vars[A],) for (shmem_offset, hx1, hx2, hy1, hy2, tx, ty, nx_l, ny_l, t_h, t_h2, tx_h, tx_h2, ty_h, ty_h2, ix_h, ix_h2, iy_h, iy_h2, A_head) = ((shmem_exprs[A][:offset], hx1s[A], hx2s[A], hy1s[A], hy2s[A], s[:tx], s[:ty], s[:nx_l], s[:ny_l], s[:t_h], s[:t_h2], s[:tx_h], s[:tx_h2], s[:ty_h], s[:ty_h2], s[:ix_h], s[:ix_h2], s[:iy_h], s[:iy_h2], s[:A_head]),) + )... +) +$((:( $A_head = @sharedMem(eltype($A), (Int64($nx_l), Int64($ny_l)), $shmem_offset) # e.g. A_izp3 = @sharedMem(eltype(A), (nx_l, ny_l), +(nx_l_A * ny_l_A)*eltype(A)) + ) + for (A, s) in shmem_vars for (shmem_offset, nx_l, ny_l, A_head) = ((shmem_exprs[A][:offset], s[:nx_l], s[:ny_l], s[:A_head]),) + )... +) +$((:( $reg = 0.0 # e.g. A_ixm1_iyp2_izp2 = 0.0 + ) + for A in optvars for regs in values(regqueue_tails[A]) for reg in values(regs) + )... +) +$((:( $reg = 0.0 # e.g. A_ixm1_iyp2_izp3 = 0.0 + ) + for A in optvars for regs in values(regqueue_heads[A]) for reg in values(regs) + )... +) +# Pre-loop + # for $i = $loopstart:$(mainloopstart-1) +$(wrap_loop(i, loopstart:mainloopstart-1, + quote + $iz = $i + $loopoffset + if ($iz > $range_z_end) ParallelStencil.@return_nothing; end + # NOTE: the following is now fully included in the loopoffset (0.25% performance gain measured on H100) but is still of interest if we implement step ranges: + # $tz_g = $i + $loopoffset + # if ($tz_g > $rangelength_z) ParallelStencil.@return_nothing; end + # $iz = ($tz_g < 1) ? $range_z_start-(1-$tz_g) : $range_z # TODO: this will probably always be formulated with range_z_start +$((wrap_if(:($i > $(loopentry-1)), + :( $reg = (0<$ix+$(oxy[1])<=size($A,1) && 0<$iy+$(oxy[2])<=size($A,2) && 0<$iz+$oz<=size($A,3)) ? $(regtarget(A, (oxy...,oz), indices)) : $reg + ) + ;unless=(loopentry==loopstart) + ) + for A in keys(shmem_vars) for (oxy, regs) in regqueue_heads[A] for (oz, reg) in regs for loopentry = (loopentrys[A],) + )... +) +$((wrap_if(:($i > $(loopentry-1)), + :( $reg = (0<$ix+$(oxy[1])<=size($A,1) && 0<$iy+$(oxy[2])<=size($A,2) && 0<$iz+$oz<=size($A,3)) ? $(regtarget(A, (oxy...,oz), indices)) : $reg + ) + ;unless=(loopentry==loopstart) + ) + for A in optvars for (oxy, regs) in regqueue_heads[A] for (oz, reg) in regs for loopentry = (loopentrys[A],) if !use_shmems[A] + )... +) +$(( # NOTE: the if statement is not needed here as we only deal with registers + # wrap_if(:($i > $(loopentry-1)), + :( + $(regs[oz]) = $(regs[oz+1]) # e.g. A_ixm1_iyp2_iz = A_ixm1_iyp2_izp1 + ) + # ;unless=(loopentry==loopstart) + # ) + for A in optvars for regs in values(regqueue_tails[A]) for oz in sort(keys(regs)) for (loopentry, oz_max) = ((loopentrys[A], oz_maxs[A]),) if oz<=oz_max-2 + )... +) +$(( # NOTE: the if statement is not needed here as we only deal with registers + # wrap_if(:($i > $(loopentry-1)), + :( + $reg = $(regqueue_heads[A][oxy][oz_max]) # e.g. A_ixm1_iyp2_izp2 = A_ixm1_iyp2_izp3 + ) + # ;unless=(loopentry==loopstart) + # ) + for A in optvars for (oxy, regs) in regqueue_tails[A] for (oz, reg) in regs for (loopentry, oz_max) = ((loopentrys[A], oz_maxs[A]),) if oz==oz_max-1 && haskey(regqueue_heads[A], oxy) && haskey(regqueue_heads[A][oxy], oz_max) + )... +) + end + # ;unroll=true + ) # wrap_loop end +) # end + +# Main loop + # for $i = $mainloopstart:$mainloopend # ParallelStencil.@unroll +$(wrap_loop(i, mainloopstart:mainloopend, + quote + $iz = $i + $loopoffset + if ($iz > $range_z_end) ParallelStencil.@return_nothing; end + # NOTE: the following is now fully included in the loopoffset (0.25% performance gain measured on H100) but is still of interest if we implement step ranges: + # $tz_g = $i + $loopoffset + # if ($tz_g > $rangelength_z) ParallelStencil.@return_nothing; end + # $iz = ($tz_g < 1) ? $range_z_start-(1-$tz_g) : $range_z # TODO: this will probably always be formulated with range_z_start +$(use_any_shmem ? + :( @sync_threads() + ) : NOEXPR +) +$((wrap_if(:($i > $(loopentry-1)), + quote + if (2*$t_h <= $n_l && $ix_h>0 && $ix_h<=size($A,1) && $iy_h>0 && $iy_h<=size($A,2) && 0<$iz+$oz_max<=size($A,3)) + $A_head[$tx_h,$ty_h] = $A[$ix_h,$iy_h,$iz+$oz_max] + end + if (2*$t_h2 > $n_l && $ix_h2>0 && $ix_h2<=size($A,1) && $iy_h2>0 && $iy_h2<=size($A,2) && 0<$iz+$oz_max<=size($A,3)) + $A_head[$tx_h2,$ty_h2] = $A[$ix_h2,$iy_h2,$iz+$oz_max] + end + end + ;unless=(loopentry<=mainloopstart) + ) + for (A, s) in shmem_vars if use_shmemhalos[A] for (loopentry, oz_max, tx, ty, nx_l, ny_l, n_l, t_h, t_h2, tx_h, tx_h2, ty_h, ty_h2, ix_h, ix_h2, iy_h, iy_h2, A_head) = ((loopentrys[A], oz_maxs[A], s[:tx], s[:ty], s[:nx_l], s[:ny_l], s[:n_l], s[:t_h], s[:t_h2], s[:tx_h], s[:tx_h2], s[:ty_h], s[:ty_h2], s[:ix_h], s[:ix_h2], s[:iy_h], s[:iy_h2], s[:A_head]),) + )... +) +# $((wrap_if(:($i > $(loopentry-1)), +# quote +# if (2*$tx_h <= $nx_l && $ix_h>0 && $ix_h<=size($A,1) && $iy>0 && $iy<=size($A,2) && 0<$iz+$oz_max<=size($A,3)) +# $A_head[$tx_h,$ty] = $A[$ix_h,$iy,$iz+$oz_max] +# end +# if (2*$tx_h2 > $nx_l && $ix_h2>0 && $ix_h2<=size($A,1) && $iy>0 && $iy<=size($A,2) && 0<$iz+$oz_max<=size($A,3)) +# $A_head[$tx_h2,$ty] = $A[$ix_h2,$iy,$iz+$oz_max] +# end +# end +# ;unless=(loopentry<=mainloopstart) +# ) +# for (A, s) in shmem_vars if (use_shmemhalos[A] && use_shmem_xs[A] && !use_shmem_ys[A]) for (loopentry, oz_max, tx, ty, nx_l, ny_l, tx_h, tx_h2, ix_h, ix_h2, A_head) = ((loopentrys[A], oz_maxs[A], s[:tx], s[:ty], s[:nx_l], s[:ny_l], s[:tx_h], s[:tx_h2], s[:ix_h], s[:ix_h2], s[:A_head]),) +# )... +# ) +# $((wrap_if(:($i > $(loopentry-1)), +# quote +# if (2*$ty_h <= $ny_l && $ix>0 && $ix<=size($A,1) && $iy_h>0 && $iy_h<=size($A,2) && 0<$iz+$oz_max<=size($A,3)) +# $A_head[$tx,$ty_h] = $A[$ix,$iy_h,$iz+$oz_max] +# end +# if (2*$ty_h2 > $ny_l && $ix>0 && $ix<=size($A,1) && $iy_h2>0 && $iy_h2<=size($A,2) && 0<$iz+$oz_max<=size($A,3)) +# $A_head[$tx,$ty_h2] = $A[$ix,$iy_h2,$iz+$oz_max] +# end +# end +# ;unless=(loopentry<=mainloopstart) +# ) +# for (A, s) in shmem_vars if (use_shmemhalos[A] && !use_shmem_xs[A] && use_shmem_ys[A]) for (loopentry, oz_max, tx, ty, nx_l, ny_l, ty_h, ty_h2, iy_h, iy_h2, A_head) = ((loopentrys[A], oz_maxs[A], s[:tx], s[:ty], s[:nx_l], s[:ny_l], s[:ty_h], s[:ty_h2], s[:iy_h], s[:iy_h2], s[:A_head]),) +# )... +# ) +$((wrap_if(:($i > $(loopentry-1)), + quote + if ($ix>0 && $ix<=size($A,1) && $iy>0 && $iy<=size($A,2) && 0<$iz+$oz_max<=size($A,3)) + $A_head[$tx,$ty] = $A[$ix,$iy,$iz+$oz_max] + end + end + ;unless=(loopentry<=mainloopstart) + ) + for (A, s) in shmem_vars if !use_shmemhalos[A] for (loopentry, oz_max, tx, ty, nx_l, ny_l, A_head) = ((loopentrys[A], oz_maxs[A], s[:tx], s[:ty], s[:nx_l], s[:ny_l], s[:A_head]),) + )... +) +$(use_any_shmem ? + :( @sync_threads() + ) : NOEXPR +) +$((wrap_if(:($i > $(loopentry-1)), + :( $reg = (0<$ix+$(oxy[1])<=size($A,1) && 0<$iy+$(oxy[2])<=size($A,2) && 0<$iz+$oz<=size($A,3)) ? $(regtarget(A, (oxy...,oz), indices)) : $reg + ) + ;unless=(loopentry<=mainloopstart) + ) + for A in optvars for (oxy, regs) in regqueue_heads[A] for (oz, reg) in regs for loopentry = (loopentrys[A],) if !use_shmems[A] + )... +) +$((wrap_if(:($i > $(loopentry-1)), + use_shmemhalo ? + :( $reg = $(regsource(A_head, oxy, (tx, ty))) # e.g. A_ixm1_iyp2_izp3 = A_izp3[tx - 1, ty + 2] + ) + : + :( $reg = (0<$tx+$(oxy[1])<=$nx_l && 0<$ty+$(oxy[2])<=$ny_l) ? $(regsource(A_head, oxy, (tx, ty))) : (0<$ix+$(oxy[1])<=size($A,1) && 0<$iy+$(oxy[2])<=size($A,2) && 0<$iz+$oz<=size($A,3)) ? $(regtarget(A, (oxy...,oz), indices)) : $reg + ) + ;unless=(loopentry<=mainloopstart) + ) + for (A, s) in shmem_vars for (oxy, regs) in regqueue_heads[A] for (oz, reg) in regs for (use_shmemhalo, loopentry, tx, ty, nx_l, ny_l, A_head) = ((use_shmemhalos[A], loopentrys[A], s[:tx], s[:ty], s[:nx_l], s[:ny_l], s[:A_head]),) + )... +) +$((wrap_if(:($i > 0), + quote + $body + end; + unless=(mainloopstart>=1) + ) +)) +$(( # NOTE: the if statement is not needed here as we only deal with registers + # wrap_if(:($i > $(loopentry-1)), + :( + $(regs[oz]) = $(regs[oz+1]) # e.g. A_ixm1_iyp2_iz = A_ixm1_iyp2_izp1 + ) + # ;unless=(loopentry<=mainloopstart) + # ) + for A in optvars for regs in values(regqueue_tails[A]) for oz in sort(keys(regs)) for (loopentry, oz_max) = ((loopentrys[A], oz_maxs[A]),) if oz<=oz_max-2 + )... +) +$((wrap_if(:($i > $(loopentry-1)), + use_shmemhalo ? + :( $reg = $(regsource(A_head, oxy, (tx, ty))) # e.g. A_ixm3_iyp2_izp2 = A_izp3[tx - 3, ty + 2] + ) + : + :( $reg = (0<$tx+$(oxy[1])<=$nx_l && 0<$ty+$(oxy[2])<=$ny_l) ? $(regsource(A_head, oxy, (tx, ty))) : (0<$ix+$(oxy[1])<=size($A,1) && 0<$iy+$(oxy[2])<=size($A,2) && 0<$iz+$oz<=size($A,3)) ? $(regtarget(A, (oxy...,oz), indices)) : $reg + ) + ;unless=(loopentry<=mainloopstart) + ) + for (A, s) in shmem_vars for (oxy, regs) in regqueue_tails[A] for (oz, reg) in regs for (use_shmemhalo, loopentry, oz_max, tx, ty, nx_l, ny_l, A_head) = ((use_shmemhalos[A], loopentrys[A], oz_maxs[A], s[:tx], s[:ty], s[:nx_l], s[:ny_l], s[:A_head]),) if oz==oz_max-1 && !(haskey(regqueue_heads[A], oxy) && haskey(regqueue_heads[A][oxy], oz_max)) + )... +) +# TODO: remove these as soon as the above is tested: +# $((wrap_if(:($i > $(loopentry-1)), +# :( $reg = $(regsource(A_head, oxy, (tx, ty))) # e.g. A_ixm3_iyp2_izp2 = A_izp3[tx - 3, ty + 2] +# ) +# ;unless=(loopentry<=mainloopstart) +# ) +# for (A, s) in shmem_vars for (oxy, regs) in regqueue_tails[A] for (oz, reg) in regs for (loopentry, oz_max, tx, ty, A_head) = ((loopentrys[A], oz_maxs[A], s[:tx], s[:ty], s[:A_head]),) if oz==oz_max-1 && !(haskey(regqueue_heads[A], oxy) && haskey(regqueue_heads[A][oxy], oz_max)) +# )... +# ) +$(( # NOTE: the if statement is not needed here as we only deal with registers + # wrap_if(:($i > $(loopentry-1)), + :( + $reg = $(regqueue_heads[A][oxy][oz_max]) # e.g. A_ixm1_iyp2_izp2 = A_ixm1_iyp2_izp3 + ) + # ;unless=(loopentry<=mainloopstart) + # ) + for A in optvars for (oxy, regs) in regqueue_tails[A] for (oz, reg) in regs for (loopentry, oz_max) = ((loopentrys[A], oz_maxs[A]),) if oz==oz_max-1 && haskey(regqueue_heads[A], oxy) && haskey(regqueue_heads[A][oxy], oz_max) + )... +) + end + # ;unroll=true + ) # wrap_loop end +) # end + +# Wrap-up-loop +# ParallelStencil.@unroll for $i = $(mainloopend+1):$loopend +# $tz_g = $i + $loopoffset +# if ($tz_g > $rangelength_z) ParallelStencil.@return_nothing; end +# $iz = ($tz_g < 1) ? $range_z_start-(1-$tz_g) : $range_z # TODO: this will probably always be formulated with range_z_start +# $((wrap_if(:($i > $(loopentry-1)), +# quote +# @sync_threads() +# if (2*$t_h <= $nx_l*$ny_l && $ix_h>0 && $ix_h<=size($A,1) && $iy_h>0 && $iy_h<=size($A,2) && 0<$iz+$oz_max<=size($A,3)) +# $A_head[$tx_h,$ty_h] = $A[$ix_h,$iy_h,$iz+$oz_max] +# end +# if (2*$t_h2 <= $nx_l*$ny_l && $ix_h2>0 && $ix_h2<=size($A,1) && $iy_h2>0 && $iy_h2<=size($A,2) && 0<$iz+$oz_max<=size($A,3)) +# $A_head[$tx_h2,$ty_h2] = $A[$ix_h2,$iy_h2,$iz+$oz_max] +# end +# @sync_threads() +# end +# ;unless=(loopentry<=mainloopstart) +# ) +# for (A, s) in shmem_vars for (loopentry, oz_max, tx, ty, nx_l, ny_l, t_h, t_h2, tx_h, tx_h2, ty_h, ty_h2, ix_h, ix_h2, iy_h, iy_h2, A_head) = ((loopentrys[A], oz_maxs[A], s[:tx], s[:ty], s[:nx_l], s[:ny_l], s[:t_h], s[:t_h2], s[:tx_h], s[:tx_h2], s[:ty_h], s[:ty_h2], s[:ix_h], s[:ix_h2], s[:iy_h], s[:iy_h2], s[:A_head]),) +# )... +# ) +# $((wrap_if(:($i > $(loopentry-1)), +# :( $reg = (0<$ix+$(oxy[1])<=size($A,1) && 0<$iy+$(oxy[2])<=size($A,2) && 0<$iz+$oz<=size($A,3)) ? $(regtarget(A, (oxy...,oz), indices)) : $reg +# ) +# ;unless=(loopentry<=mainloopstart) +# ) +# for A in optvars for (oxy, regs) in regqueue_heads[A] for (oz, reg) in regs for loopentry = (loopentrys[A],) if !use_shmems[A] +# )... +# ) +# $((wrap_if(:($i > $(loopentry-1)), +# :( $reg = $(regsource(A_head, oxy, (tx, ty))) # e.g. A_ixm1_iyp2_izp3 = A_izp3[tx - 1, ty + 2] +# ) +# ;unless=(loopentry<=mainloopstart) +# ) +# for (A, s) in shmem_vars for (oxy, regs) in regqueue_heads[A] for reg in values(regs) for (loopentry, tx, ty, A_head) = ((loopentrys[A], s[:tx], s[:ty], s[:A_head]),) +# )... +# ) +# $((wrap_if(:($i > 0), +# quote +# $body +# end; +# unless=(mainloopstart>=1) +# ) +# )) +# $((wrap_if(:($i > $(loopentry-1)), +# :( +# $(regs[oz]) = $(regs[oz+1]) # e.g. A_ixm1_iyp2_iz = A_ixm1_iyp2_izp1 +# ) +# ;unless=(loopentry<=mainloopstart) +# ) +# for A in optvars for regs in values(regqueue_tails[A]) for oz in sort(keys(regs)) for (loopentry, oz_max) = ((loopentrys[A], oz_maxs[A]),) if oz<=oz_max-2 +# )... +# ) +# $((wrap_if(:($i > $(loopentry-1)), +# :( $reg = $(regsource(A_head, oxy, (tx, ty))) # e.g. A_ixm3_iyp2_izp2 = A_izp3[tx - 3, ty + 2] +# ) +# ;unless=(loopentry<=mainloopstart) +# ) +# for (A, s) in shmem_vars for (oxy, regs) in regqueue_tails[A] for (oz, reg) in regs for (loopentry, oz_max, tx, ty, A_head) = ((loopentrys[A], oz_maxs[A], s[:tx], s[:ty], s[:A_head]),) if oz==oz_max-1 && !(haskey(regqueue_heads[A], oxy) && haskey(regqueue_heads[A][oxy], oz_max)) +# )... +# ) +# $((wrap_if(:($i > $(loopentry-1)), +# :( +# $reg = $(regqueue_heads[A][oxy][oz_max]) # e.g. A_ixm1_iyp2_izp2 = A_ixm1_iyp2_izp3 +# ) +# ;unless=(loopentry<=mainloopstart) +# ) +# for A in optvars for (oxy, regs) in regqueue_tails[A] for (oz, reg) in regs for (loopentry, oz_max) = ((loopentrys[A], oz_maxs[A]),) if oz==oz_max-1 && haskey(regqueue_heads[A], oxy) && haskey(regqueue_heads[A][oxy], oz_max) +# )... +# ) + +# $tz_g = $i + $loopoffset +# if ($tz_g > $rangelength_z) ParallelStencil.@return_nothing; end +# $iz = ($tz_g < 1) ? $range_z_start-(1-$tz_g) : $range_z # TODO: this will probably always be formulated with range_z_start +# $(( +# # wrap_if(:(($(loopentry-1) < $i < $(shmem_loopentry)) || ($(shmem_loopexit) < $i)), +# :( $reg = (0<$ix+$(oxy[1])<=size($A,1) && 0<$iy+$(oxy[2])<=size($A,2) && 0<$iz+$oz<=size($A,3)) ? $(regtarget(A, (oxy...,oz), indices)) : $reg +# ) +# for A in keys(shmem_vars) for (oxy, regs) in regqueue_heads[A] for (oz, reg) in regs for loopentry = (loopentrys[A],) +# )... +# ) +# $(( +# :( $reg = (0<$ix+$(oxy[1])<=size($A,1) && 0<$iy+$(oxy[2])<=size($A,2) && 0<$iz+$oz<=size($A,3)) ? $(regtarget(A, (oxy...,oz), indices)) : $reg +# ) +# for A in optvars for (oxy, regs) in regqueue_heads[A] for (oz, reg) in regs for loopentry = (loopentrys[A],) if !use_shmems[A] +# )... +# ) +# $(( +# quote +# $body +# end +# )) +# $(( +# :( +# $(regs[oz]) = $(regs[oz+1]) # e.g. A_ixm1_iyp2_iz = A_ixm1_iyp2_izp1 +# ) +# for A in optvars for regs in values(regqueue_tails[A]) for oz in sort(keys(regs)) for (loopentry, oz_max) = ((loopentrys[A], oz_maxs[A]),) if oz<=oz_max-2 +# )... +# ) +# $(( +# :( +# $reg = $(regqueue_heads[A][oxy][oz_max]) # e.g. A_ixm1_iyp2_izp2 = A_ixm1_iyp2_izp3 +# ) +# for A in optvars for (oxy, regs) in regqueue_tails[A] for (oz, reg) in regs for (loopentry, oz_max) = ((loopentrys[A], oz_maxs[A]),) if oz==oz_max-1 && haskey(regqueue_heads[A], oxy) && haskey(regqueue_heads[A][oxy], oz_max) +# )... +# ) + # end + end + else + @ArgumentError("memopt: only loopdim=3 is currently supported.") + end + store_metadata(metadata_module, is_parallel_kernel, caller, offset_mins, offset_maxs, offsets, optvars, loopdim, loopsize, optranges, use_shmemhalos) + # @show QuoteNode(ParallelKernel.simplify_varnames!(ParallelKernel.remove_linenumbernodes!(deepcopy(body)))) + return body +end + + +function memopt(metadata_module::Module, is_parallel_kernel::Bool, caller::Module, indices::Union{Symbol,Expr}, optvars::Union{Expr,Symbol}, body::Expr; package::Symbol=get_package(caller)) + loopdim = isa(indices,Expr) ? length(indices.args) : 1 + loopsize = compute_loopsize(package) + optranges = nothing + use_shmemhalos = nothing + optimize_halo_read = true + return memopt(metadata_module, is_parallel_kernel, caller, indices, optvars, loopdim, loopsize, optranges, use_shmemhalos, optimize_halo_read, body; package=package) +end + + +function shortif(caller::Module, else_val, if_expr; package::Symbol=get_package(caller)) + if (package ∉ SUPPORTED_PACKAGES) @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).") end + @capture(if_expr, if condition_ body_ end) || @ArgumentError("@shortif: the second argument must be an if statement.") + @capture(body, lhs_ = rhs_) || @ArgumentError("@shortif: the if statement body must contain a assignement.") + return :($lhs = $condition ? $rhs : $else_val) +end + + +## FUNCTIONS FOR SHARED MEMORY ALLOCATION + + +## HELPER FUNCTIONS + +function eval_offsets(caller::Module, body::Expr, indices::NTuple{N,<:Union{Symbol,Expr}} where N, int_type::Type{<:Integer}) + return postwalk(body) do ex + if !is_stencil_access(ex, indices...) return ex; end + @capture(ex, A_[indices_expr__]) || @ModuleInternalError("a stencil access could not be pattern matched.") + for i = 1:length(indices) + offset_expr = substitute(indices_expr[i], indices[i], 0) + offset = eval_arg(caller, offset_expr) + if (offset > 0) indices_expr[i] = :($(indices[i]) + $(int_type(offset)) ) + elseif (offset < 0) indices_expr[i] = :($(indices[i]) - $(int_type(abs(offset)))) + else indices_expr[i] = indices[i] + end + end + return :($A[$(indices_expr...)]) + end +end + +function extract_offsets(caller::Module, body::Expr, indices::NTuple{N,<:Union{Symbol,Expr}} where N, int_type::Type{<:Integer}, optvars::NTuple{N,Symbol} where N, loopdim::Integer) + offsets_by_xy = Dict(A => Dict() for A in optvars) + offsets_by_z = Dict(A => Dict() for A in optvars) + postwalk(body) do ex + if is_stencil_access(ex, indices...) + @capture(ex, A_[indices_expr__]) || @ModuleInternalError("a stencil access could not be pattern matched.") + if A in optvars + offsets = () + for i = 1:length(indices) + offset_expr = substitute(indices_expr[i], indices[i], 0) + offset = int_type(eval_arg(caller, offset_expr)) # TODO: do this and cast later to enable unsigned integer (also dealing with negative rangers is required elsewhere): offset = eval_arg(caller, offset_expr) + offsets = (offsets..., offset) + end + if loopdim == 3 + k1 = offsets[1:2] + k2 = offsets[end] + if haskey(offsets_by_xy[A], k1) && haskey(offsets_by_xy[A][k1], k2) offsets_by_xy[A][k1][k2] += 1 + elseif haskey(offsets_by_xy[A], k1) offsets_by_xy[A][k1][k2] = 1 + else offsets_by_xy[A][k1] = Dict(k2 => 1) + end + k1 = offsets[end] + k2 = offsets[1:2] + if haskey(offsets_by_z[A], k1) && haskey(offsets_by_z[A][k1], k2) offsets_by_z[A][k1][k2] += 1 + elseif haskey(offsets_by_z[A], k1) offsets_by_z[A][k1][k2] = 1 + else offsets_by_z[A][k1] = Dict(k2 => 1) + end + else + @ArgumentError("memopt: only loopdim=3 is currently supported.") + end + end + end + return ex + end + return offsets_by_xy, offsets_by_z +end + +function remove_single_point_optvars(optvars, optranges_arg, offsets, offsets_by_z) + return tuple((A for A in optvars if !(length(keys(offsets[A]))==1 && length(keys(offsets_by_z[A]))==1) || (!isnothing(optranges_arg) && A ∈ keys(optranges_arg)))...) +end + +function define_optranges(optranges_arg, optvars, offsets, int_type, package) + compute_capability = get_compute_capability(package) + optranges = Dict() + for A in optvars + zspan_max = 0 + oxy_zspan_max = () + for oxy in keys(offsets[A]) + zspan = length(keys(offsets[A][oxy])) + if zspan > zspan_max + zspan_max = zspan + oxy_zspan_max = oxy + end + end + fullrange = typemin(int_type):typemax(int_type) + pointrange_x = oxy_zspan_max[1]: oxy_zspan_max[1] + pointrange_y = oxy_zspan_max[2]: oxy_zspan_max[2] + if (!isnothing(optranges_arg) && A ∈ keys(optranges_arg)) optranges[A] = getproperty(optranges_arg, A) + elseif (compute_capability < v"8" && (length(optvars) <= FULLRANGE_THRESHOLD)) optranges[A] = (fullrange, fullrange, fullrange) + elseif (USE_FULLRANGE_DEFAULT == (true, true, true)) optranges[A] = (fullrange, fullrange, fullrange) + elseif (USE_FULLRANGE_DEFAULT == (false, true, true)) optranges[A] = (pointrange_x, fullrange, fullrange) + elseif (USE_FULLRANGE_DEFAULT == (true, false, true)) optranges[A] = (fullrange, pointrange_y, fullrange) + elseif (USE_FULLRANGE_DEFAULT == (false, false, true)) optranges[A] = (pointrange_x, pointrange_y, fullrange) + end + end + return optranges +end + +function define_regqueues(offsets::Dict{Symbol, Dict{Any, Any}}, optranges::Dict{Any, Any}, optvars::NTuple{N,Symbol} where N, indices::NTuple{N,<:Union{Symbol,Expr}} where N, int_type::Type{<:Integer}, loopdim::Integer) + regqueue_heads = Dict(A => Dict() for A in optvars) + regqueue_tails = Dict(A => Dict() for A in optvars) + offset_mins = Dict{Symbol, NTuple{3,Integer}}() + offset_maxs = Dict{Symbol, NTuple{3,Integer}}() + nb_regs_heads = Dict{Symbol, Integer}() + nb_regs_tails = Dict{Symbol, Integer}() + for A in optvars + regqueue_heads[A], regqueue_tails[A], offset_mins[A], offset_maxs[A], nb_regs_heads[A], nb_regs_tails[A] = define_regqueue(offsets[A], optranges[A], A, indices, int_type, loopdim) + end + return regqueue_heads, regqueue_tails, offset_mins, offset_maxs, nb_regs_heads, nb_regs_tails +end + +function define_regqueue(offsets::Dict{Any, Any}, optranges::NTuple{3,UnitRange}, A::Symbol, indices::NTuple{N,<:Union{Symbol,Expr}} where N, int_type::Type{<:Integer}, loopdim::Integer) + regqueue_head = Dict() + regqueue_tail = Dict() + nb_regs_head = 0 + nb_regs_tail = 0 + if loopdim == 3 + optranges_xy = optranges[1:2] + optranges_z = optranges[3] + offsets_xy = filter(oxy -> all(oxy .∈ optranges_xy), keys(offsets)) + if isempty(offsets_xy) @IncoherentArgumentError("incoherent argument in memopt: optranges in x-y dimension do not include any array access.") end + offset_min = (typemax(int_type), typemax(int_type), typemax(int_type)) + offset_max = (typemin(int_type), typemin(int_type), typemin(int_type)) + for oxy in offsets_xy + offsets_z = filter(x -> x ∈ optranges_z, keys(offsets[oxy])) + if isempty(offsets_z) @IncoherentArgumentError("incoherent argument in memopt: optranges in z dimension do not include any array access.") end + offset_min = (min(offset_min[1], oxy[1]), + min(offset_min[2], oxy[2]), + min(offset_min[3], minimum(offsets_z))) + offset_max = (max(offset_max[1], oxy[1]), + max(offset_max[2], oxy[2]), + max(offset_max[3], maximum(offsets_z))) + end + oz_max = offset_max[3] + for oxy in offsets_xy + offsets_z = sort(filter(x -> x ∈ optranges_z, keys(offsets[oxy]))) + k1 = oxy + for oz = offsets_z[1]:oz_max-1 + k2 = oz + if haskey(regqueue_tail, k1) && haskey(regqueue_tail[k1], k2) @ModuleInternalError("regqueue_tail entry exists already.") end + reg = gensym_world(varname(A, (oxy..., oz)), @__MODULE__); nb_regs_tail += 1 + if haskey(regqueue_tail, k1) regqueue_tail[k1][k2] = reg + else regqueue_tail[k1] = Dict(k2 => reg) + end + end + oz = offsets_z[end] + if oz == oz_max + k2 = oz + if haskey(regqueue_head, k1) && haskey(regqueue_head[k1], k2) @ModuleInternalError("regqueue_head entry exists already.") end + reg = gensym_world(varname(A, (oxy..., oz)), @__MODULE__); nb_regs_head += 1 + if haskey(regqueue_head, k1) regqueue_head[k1][k2] = reg + else regqueue_head[k1] = Dict(k2 => reg) + end + end + end + else + @ArgumentError("memopt: only loopdim=3 is currently supported.") + end + return regqueue_head, regqueue_tail, offset_min, offset_max, nb_regs_head, nb_regs_tail +end + +function define_helper_variables(offset_mins::Dict{Symbol, <:NTuple{3,Integer}}, offset_maxs::Dict{Symbol, <:NTuple{3,Integer}}, optvars::NTuple{N,Symbol} where N, use_shmemhalos_arg, loopdim::Integer) + oz_maxs, hx1s, hy1s, hx2s, hy2s, use_shmems, use_shmem_xs, use_shmem_ys, use_shmemhalos, use_shmemindices, offset_spans, oz_spans, loopentrys = Dict(), Dict(), Dict(), Dict(), Dict(), Dict(), Dict(), Dict(), Dict(), Dict(), Dict(), Dict(), Dict() + if loopdim == 3 + for A in optvars + offset_min, offset_max = offset_mins[A], offset_maxs[A] + oz_max = offset_max[3] + hx1, hy1 = -1 .* offset_min[1:2] + hx2, hy2 = offset_max[1:2] + use_shmem_x = (hx1 + hx2 > 0) + use_shmem_y = (hy1 + hy2 > 0) + use_shmem = use_shmem_x || use_shmem_y + use_shmemhalo = if (!isnothing(use_shmemhalos_arg) && (A ∈ keys(use_shmemhalos_arg))) getproperty(use_shmemhalos_arg, A) + elseif !(use_shmem_x && use_shmem_y) USE_SHMEMHALO_1D_DEFAULT + else USE_SHMEMHALO_DEFAULT + end + use_shmemindex = use_shmem && use_shmemhalo && (use_shmem_x && use_shmem_y) + offset_span = offset_max .- offset_min + oz_span = offset_span[3] + loopentry = 1 - oz_span #TODO: make possibility to do first and last read in z dimension directly into registers without halo + oz_maxs[A], hx1s[A], hy1s[A], hx2s[A], hy2s[A], use_shmems[A], use_shmem_xs[A], use_shmem_ys[A], use_shmemhalos[A], use_shmemindices[A], offset_spans[A], oz_spans[A], loopentrys[A] = oz_max, hx1, hy1, hx2, hy2, use_shmem, use_shmem_x, use_shmem_y, use_shmemhalo, use_shmemindex, offset_span, oz_span, loopentry + end + else + @ArgumentError("memopt: only loopdim=3 is currently supported.") + end + return oz_maxs, hx1s, hy1s, hx2s, hy2s, use_shmems, use_shmem_xs, use_shmem_ys, use_shmemhalos, use_shmemindices, offset_spans, oz_spans, loopentrys +end + +function define_shmem_index_groups(hx1s, hy1s, hx2s, hy2s, optvars::NTuple{N,Symbol} where N, use_shmems::Dict{Any, Any}, loopdim::Integer) + shmem_index_groups = Dict() + if loopdim == 3 + for A in optvars + if use_shmems[A] + k = (hx1s[A], hy1s[A], hx2s[A], hy2s[A]) + if !haskey(shmem_index_groups, k) shmem_index_groups[k] = (A,) + else shmem_index_groups[k] = (shmem_index_groups[k]..., A) + end + end + end + end + return shmem_index_groups +end + +function define_shmem_vars(oz_maxs::Dict{Any, Any}, hx1s, hy1s, hx2s, hy2s, optvars::NTuple{N,Symbol} where N, indices, use_shmems::Dict{Any, Any}, use_shmem_xs, use_shmem_ys, shmem_index_groups, use_shmemhalos, use_shmemindices, loopdim::Integer) + ix, iy, iz = indices + shmem_vars = Dict(A => Dict() for A in optvars if use_shmems[A]) + if loopdim == 3 + for vars in values(shmem_index_groups) + suffix = join(string.(vars), "_") + sym_tx = gensym_world("tx_$suffix", @__MODULE__) + sym_ty = gensym_world("ty_$suffix", @__MODULE__) + sym_nx_l = gensym_world("nx_l_$suffix", @__MODULE__) + sym_ny_l = gensym_world("ny_l_$suffix", @__MODULE__) + sym_t_h = gensym_world("t_h_$suffix", @__MODULE__) + sym_t_h2 = gensym_world("t_h2_$suffix", @__MODULE__) + sym_tx_h = gensym_world("tx_h_$suffix", @__MODULE__) + sym_tx_h2 = gensym_world("tx_h2_$suffix", @__MODULE__) + sym_ty_h = gensym_world("ty_h_$suffix", @__MODULE__) + sym_ty_h2 = gensym_world("ty_h2_$suffix", @__MODULE__) + sym_ix_h = gensym_world("ix_h_$suffix", @__MODULE__) + sym_ix_h2 = gensym_world("ix_h2_$suffix", @__MODULE__) + sym_iy_h = gensym_world("iy_h_$suffix", @__MODULE__) + sym_iy_h2 = gensym_world("iy_h2_$suffix", @__MODULE__) + for A in vars + if use_shmemindices[A] + n_l = quote $sym_nx_l*$sym_ny_l end + shmem_vars[A][:tx] = sym_tx + shmem_vars[A][:ty] = sym_ty + shmem_vars[A][:nx_l] = sym_nx_l + shmem_vars[A][:ny_l] = sym_ny_l + shmem_vars[A][:n_l] = n_l + shmem_vars[A][:t_h] = sym_t_h + shmem_vars[A][:t_h2] = sym_t_h2 + shmem_vars[A][:tx_h] = sym_tx_h + shmem_vars[A][:tx_h2] = sym_tx_h2 + shmem_vars[A][:ty_h] = sym_ty_h + shmem_vars[A][:ty_h2] = sym_ty_h2 + shmem_vars[A][:ix_h] = sym_ix_h + shmem_vars[A][:ix_h2] = sym_ix_h2 + shmem_vars[A][:iy_h] = sym_iy_h + shmem_vars[A][:iy_h2] = sym_iy_h2 + else + if use_shmemhalos[A] + use_shmem_x, use_shmem_y = use_shmem_xs[A], use_shmem_ys[A] + hx1, hy1, hx2, hy2 = hx1s[A], hy1s[A], hx2s[A], hy2s[A] + if use_shmem_x && use_shmem_y # NOTE: if the following expressions are noted with ":()" then it will cause a segmentation fault and run time. + tx = quote @threadIdx().x + $hx1 end + ty = quote @threadIdx().y + $hy1 end + nx_l = quote @blockDim().x + UInt32($(hx1+hx2)) end # NOTE: cast to UInt32 is necessary to avoid promotion, which can lead to a tuple with different integers, resulting in an error. + ny_l = quote @blockDim().y + UInt32($(hy1+hy2)) end # ... + n_l = quote $nx_l*$ny_l end + t_h = quote (@threadIdx().y-1)*@blockDim().x + @threadIdx().x end # NOTE: here it must be bx, not @blockDim().x + t_h2 = quote $t_h + $nx_l*$ny_l - @blockDim().x*@blockDim().y end + ty_h = quote ($t_h-1) ÷ $nx_l + 1 end + tx_h = quote ($t_h-1) % $nx_l + 1 end # NOTE: equivalent to (worse performance has uses registers probably differently): ($t_h-1) - $nx_l*($ty_h-1) + 1 + ty_h2 = quote ($t_h2-1) ÷ $nx_l + 1 end + tx_h2 = quote ($t_h2-1) % $nx_l + 1 end # NOTE: equivalent to (worse performance has uses registers probably differently): ($t_h2-1) - $nx_l*($ty_h2-1) + 1 + ix_h = quote $ix - @threadIdx().x + $tx_h - $hx1 end # NOTE: here it must be @blockDim().x, not bx + ix_h2 = quote $ix - @threadIdx().x + $tx_h2 - $hx1 end # ... + iy_h = quote $iy - @threadIdx().y + $ty_h - $hy1 end # ... + iy_h2 = quote $iy - @threadIdx().y + $ty_h2 - $hy1 end # ... + elseif use_shmem_x + tx = quote @threadIdx().x + $hx1 end + ty = quote @threadIdx().y + $hy1 end + nx_l = quote @blockDim().x + UInt32($(hx1+hx2)) end # NOTE: cast to UInt32 is necessary to avoid promotion, which can lead to a tuple with different integers, resulting in an error. + ny_l = quote @blockDim().y end + tx_h = quote @threadIdx().x end + ty_h = quote @threadIdx().y end + tx_h2 = quote @threadIdx().x + $(hx1+hx2) end # NOTE: alternative: shmem_vars[A][:tx_h2] = :(@threadIdx().x + @blockDim().x) + ty_h2 = ty_h + ix_h = quote $ix - @threadIdx().x + $tx_h - $hx1 end + ix_h2 = quote $ix - @threadIdx().x + $tx_h2 - $hx1 end + iy_h = quote $iy - @threadIdx().y + $ty_h - $hy1 end + iy_h2 = quote $iy - @threadIdx().y + $ty_h2 - $hy1 end + n_l = nx_l + t_h = tx_h + t_h2 = tx_h2 + elseif use_shmem_y + tx = quote @threadIdx().x + $hx1 end + ty = quote @threadIdx().y + $hy1 end + nx_l = quote @blockDim().x end + ny_l = quote @blockDim().y + UInt32($(hy1+hy2)) end # NOTE: cast to UInt32 is necessary to avoid promotion, which can lead to a tuple with different integers, resulting in an error. + tx_h = quote @threadIdx().x end + ty_h = quote @threadIdx().y end + tx_h2 = tx_h + ty_h2 = quote @threadIdx().y + $(hy1+hy2) end # NOTE: alternative: # shmem_vars[A][:ty_h2] = :(@threadIdx().y + @blockDim().y) + ix_h = quote $ix - @threadIdx().x + $tx_h - $hx1 end + ix_h2 = quote $ix - @threadIdx().x + $tx_h2 - $hx1 end + iy_h = quote $iy - @threadIdx().y + $ty_h - $hy1 end + iy_h2 = quote $iy - @threadIdx().y + $ty_h2 - $hy1 end + n_l = ny_l + t_h = ty_h + t_h2 = ty_h2 + end + shmem_vars[A][:tx] = tx + shmem_vars[A][:ty] = ty + shmem_vars[A][:nx_l] = nx_l + shmem_vars[A][:ny_l] = ny_l + shmem_vars[A][:n_l] = n_l + shmem_vars[A][:t_h] = t_h + shmem_vars[A][:t_h2] = t_h2 + shmem_vars[A][:tx_h] = tx_h + shmem_vars[A][:tx_h2] = tx_h2 + shmem_vars[A][:ty_h] = ty_h + shmem_vars[A][:ty_h2] = ty_h2 + shmem_vars[A][:ix_h] = ix_h + shmem_vars[A][:ix_h2] = ix_h2 + shmem_vars[A][:iy_h] = iy_h + shmem_vars[A][:iy_h2] = iy_h2 + else + shmem_vars[A][:tx] = :(@threadIdx().x) + shmem_vars[A][:ty] = :(@threadIdx().y) + shmem_vars[A][:nx_l] = :(@blockDim().x) + shmem_vars[A][:ny_l] = :(@blockDim().y) + end + end + shmem_vars[A][:A_head] = gensym_world(varname(A, (oz_maxs[A],); i="iz"), @__MODULE__) + end + end + else + @ArgumentError("memopt: only loopdim=3 is currently supported.") + end + return shmem_vars +end + +function define_shmem_exprs(shmem_vars::Dict{Symbol, Dict{Any, Any}}, loopdim::Integer) + exprs = Dict(A => Dict() for A in keys(shmem_vars)) + offset = () + if loopdim == 3 + for A in keys(shmem_vars) + exprs[A][:offset] = (length(offset) > 0) ? Expr(:call, :+, offset...) : 0 + offset = (offset..., :($(shmem_vars[A][:nx_l]) * $(shmem_vars[A][:ny_l]) * sizeof(eltype($A)))) + end + else + @ArgumentError("memopt: only loopdim=3 is currently supported.") + end + return exprs +end + +function define_shmem_z_ranges(offsets_by_z::Dict{Symbol, Dict{Any, Any}}, use_shmems::Dict{Any, Any}, loopdim::Integer) + shmem_z_ranges = Dict() + shmem_As = (A for (A, use_shmem) in use_shmems if use_shmem) + for A in shmem_As + shmem_z_ranges[A] = define_shmem_z_range(offsets_by_z[A], loopdim) + end + return shmem_z_ranges +end + +function define_shmem_z_range(offsets_by_z::Dict{Any, Any}, loopdim::Integer) + start, start_offsets_xy = find_rangelimit(offsets_by_z, loopdim; upper=false) + stop, stop_offsets_xy = find_rangelimit(offsets_by_z, loopdim; upper=true) + if (length(start_offsets_xy) != 1 || length(stop_offsets_xy) != 1 || start_offsets_xy[1] != stop_offsets_xy[1]) # NOTE: shared memory range is not reduced in asymmetric case + return minimum(keys(offsets_by_z)):maximum(keys(offsets_by_z)) + end + return start:stop +end + +function find_rangelimit(offsets_by_z::Dict{Any, Any}, loopdim::Integer; upper=false) + if loopdim == 3 + offsets_z = sort(keys(offsets_by_z); rev=upper) + oz1 = offsets_z[1] + rangelimit = oz1 + offsets_xy1 = (keys(offsets_by_z[oz1])...,) + if length(offsets_xy1) == 1 + rangelimit = offsets_z[2] + oxy1 = offsets_xy1[1] + for oz in offsets_z[2:end] + offsets_xy = (keys(offsets_by_z[oz])...,) + if (length(offsets_xy) == 1) && (offsets_xy[1] == oxy1) + rangelimit = offsets_z[oz+1] + else + break + end + end + end + else + @ArgumentError("memopt: only loopdim=3 is currently supported.") + end + return rangelimit, offsets_xy1 +end + +function define_shmem_loopentrys(loopentrys, shmem_z_ranges, offset_mins, loopdim::Integer) + shmem_loopentrys = Dict() + shmem_As = (A for A in keys(shmem_z_ranges)) + for A in shmem_As + shmem_loopentrys[A] = define_shmem_loopentry(loopentrys[A], shmem_z_ranges[A], offset_mins[A], loopdim) + end + return shmem_loopentrys +end + +function define_shmem_loopentry(loopentry, shmem_z_range, offset_min, loopdim::Integer) + if loopdim == 3 + shmem_loopentry = loopentry + (shmem_z_range.start - offset_min[3]) + else + @ArgumentError("memopt: only loopdim=3 is currently supported.") + end + return shmem_loopentry +end + +function define_shmem_loopexits(loopexit, shmem_z_ranges, offset_maxs, loopdim::Integer) + shmem_loopexits = Dict() + shmem_As = (A for A in keys(shmem_z_ranges)) + for A in shmem_As + shmem_loopexits[A] = define_shmem_loopexit(loopexit, shmem_z_ranges[A], offset_maxs[A], loopdim) + end + return shmem_loopexits +end + +function define_shmem_loopexit(loopexit, shmem_z_range, offset_max, loopdim::Integer) + if loopdim == 3 + shmem_loopexit = loopexit - (offset_max[3] - shmem_z_range.stop) + else + @ArgumentError("memopt: only loopdim=3 is currently supported.") + end + return shmem_loopexit +end + +function varname(A::Symbol, offsets::NTuple{N,Integer} where N; i::String="ix", j::String="iy", k::String="iz") + ndims = length(offsets) + ox = offsets[1] + x = if (ox > 0) i * "p" * string(ox) + elseif (ox < 0) i * "m" * string(abs(ox)) + else i + end + if ndims > 1 + oy = offsets[2] + y = if (oy > 0) j * "p" * string(oy) + elseif (oy < 0) j * "m" * string(abs(oy)) + else j + end + end + if ndims > 2 + oz = offsets[3] + z = if (oz > 0) k * "p" * string(oz) + elseif (oz < 0) k * "m" * string(abs(oz)) + else k + end + end + if (ndims == 1) return string(A, "_$(x)") + elseif (ndims == 2) return string(A, "_$(x)_$(y)") + elseif (ndims == 3) return string(A, "_$(x)_$(y)_$(z)") + end +end + +function regtarget(A::Symbol, offsets::NTuple{N,Integer} where N, indices::NTuple{N,<:Union{Symbol,Expr}} where N) + ndims = length(offsets) + ox = offsets[1] + ix = indices[1] + if (ox > 0) x = :($ix + $ox) + elseif (ox < 0) x = :($ix - $(abs(ox))) + else x = ix + end + if ndims > 1 + oy = offsets[2] + iy = indices[2] + if (oy > 0) y = :($iy + $oy) + elseif (oy < 0) y = :($iy - $(abs(oy))) + else y = iy + end + end + if ndims > 2 + oz = offsets[3] + iz = indices[3] + if (oz > 0) z = :($iz + $oz) + elseif (oz < 0) z = :($iz - $(abs(oz))) + else z = iz + end + end + if (ndims == 1) return :($A[$x]) + elseif (ndims == 2) return :($A[$x,$y]) + elseif (ndims == 3) return :($A[$x,$y,$z]) + end +end + +function regsource(A_head::Symbol, offsets::NTuple{N,Integer} where N, local_indices::NTuple{N,<:Union{Symbol,Expr}} where N) + ndims = length(offsets) + ox = offsets[1] + tx = local_indices[1] + if (ox > 0) x = :($tx + $ox) + elseif (ox < 0) x = :($tx - $(abs(ox))) + else x = tx + end + if ndims > 1 + oy = offsets[2] + ty = local_indices[2] + if (oy > 0) y = :($ty + $oy) + elseif (oy < 0) y = :($ty - $(abs(oy))) + else y = ty + end + end + if (ndims == 1) return :($A_head[$x]) + elseif (ndims == 2) return :($A_head[$x,$y]) # e.g. :($A_head[$tx,$ty-1]) + end +end + +function wrap_if(condition::Expr, block::Expr; unless::Bool=false) + if unless + return block + else + return quote + if $condition + $block + end + end + end +end + +function wrap_loop(index::Symbol, range::UnitRange, block::Expr; unroll=false) + if length(range) == 0 + return NOEXPR + elseif length(range) == 1 + return quote + $index = $(range.start) + $block + end + else + if unroll + return quote + $(( quote + $index = $i + $block + end + for i in range + )... + ) + end + else + return quote + for $index = $(range.start):$(range.stop) + $block + end + end + end + end +end + +function store_metadata(metadata_module::Module, is_parallel_kernel::Bool, caller::Module, offset_mins::Dict{Symbol, <:NTuple{3,Integer}}, offset_maxs::Dict{Symbol, <:NTuple{3,Integer}}, offsets::Dict{Symbol, Dict{Any, Any}}, optvars::NTuple{N,Symbol} where N, loopdim::Integer, loopsize::Integer, optranges::Dict{Any, Any}, use_shmemhalos) + memopt = true + nonconst_metadata = get_nonconst_metadata(caller) + stencilranges = NamedTuple(A => (offset_mins[A][1]:offset_maxs[A][1], offset_mins[A][2]:offset_maxs[A][2], offset_mins[A][3]:offset_maxs[A][3]) for A in optvars) + if nonconst_metadata + storeexpr = quote + is_parallel_kernel = $is_parallel_kernel + memopt = $memopt + nonconst_metadata = $nonconst_metadata + stencilranges = $stencilranges + offsets = $offsets + optvars = $optvars + loopdim = $loopdim + loopsize = $loopsize + optranges = $optranges + use_shmemhalos = $use_shmemhalos + end + else + storeexpr = quote + const is_parallel_kernel = $is_parallel_kernel + const memopt = $memopt + const nonconst_metadata = $nonconst_metadata + const stencilranges = $stencilranges + const offsets = $offsets + const optvars = $optvars + const loopdim = $loopdim + const loopsize = $loopsize + const optranges = $optranges + const use_shmemhalos = $use_shmemhalos + end + end + @eval(metadata_module, $storeexpr) +end + +Base.sort(keys::T; kwargs...) where T<:Base.AbstractSet = sort([keys...]; kwargs...) + + +# macro unroll(args...) check_initialized(__module__); checkargs_unroll(args...); esc(unroll(args...)); end + +# function checkargs_unroll(args...) +# if (length(args) != 1) @ArgumentError("wrong number of arguments.") end +# end + +# function unroll(expr) +# if @capture(expr, for i_ = range_ body__ end) #TODO: enable in instead of equal +# return quote +# for $i = $range +# $(body...) +# $(Expr(:loopinfo, nodes...)) +# end +# end +# else +# error("Syntax error: loopinfo needs a for loop") +# end +# end diff --git a/src/parallel.jl b/src/parallel.jl index c3ca3055..574741f1 100644 --- a/src/parallel.jl +++ b/src/parallel.jl @@ -1,5 +1,9 @@ import .ParallelKernel: get_name, set_name, get_body, set_body!, add_return, remove_return, extract_kwargs, split_parallel_args, extract_tuple, substitute, literaltypes, push_to_signature!, add_loop, add_threadids, promote_maxsize +# NOTE: @parallel and @parallel_indices and @parallel_async do not appear in the following as they are extended and therefore re-defined here in parallel.jl +@doc replace(ParallelKernel.SYNCHRONIZE_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro synchronize(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.@synchronize($(args...)))); end + + const PARALLEL_DOC = """ @parallel kernel @parallel inbounds=... memopt=... ndims=... kernel diff --git a/test/ParallelKernel/test_kernel_language.jl b/test/ParallelKernel/test_kernel_language.jl index 87bc6473..1517677f 100644 --- a/test/ParallelKernel/test_kernel_language.jl +++ b/test/ParallelKernel/test_kernel_language.jl @@ -24,6 +24,20 @@ end Base.retry_load_extensions() # Potentially needed to load the extensions after the packages have been filtered. +macro expr_allocated(ex) + expanded = Base.macroexpand(__module__, ex; recursive=true) + quote + # Warm-up evaluation to exclude first-call setup allocations + let + $(esc(expanded)) + end + @allocated begin + $(esc(expanded)) + end + end +end + + @static for package in TEST_PACKAGES FloatDefault = (package == PKG_METAL) ? Float32 : Float64 # Metal does not support Float64 @@ -79,6 +93,180 @@ eval(:( @test @prettystring(1, ParallelStencil.ParallelKernel.@threads()) == "Polyester.@batch" end; end; + @testset "Warp level primitives" begin + @testset "Parse-time direct call mapping" begin + # Common test variables used in macro expansions + mask = UInt64(0xffff_ffff_ffff_ffff) + mask32 = UInt32(0xffff_ffff) + val = one($FloatDefault) + lane = 1 + width = 32 + delta = 1 + lane_mask = 1 + predicate = true + + if $package == $PKG_CUDA + @test @prettystring(1, @warpsize()) == "CUDA.warpsize()" + @test @prettystring(1, @laneid()) == "CUDA.laneid() + 1" + @test @prettystring(1, @active_mask()) == "CUDA.active_mask()" + + @test @prettystring(1, @shfl_sync(mask32, val, lane)) == "CUDA.shfl_sync(mask32, val, lane)" + @test @prettystring(1, @shfl_sync(mask32, val, lane, width)) == "CUDA.shfl_sync(mask32, val, lane, width)" + @test @prettystring(1, @shfl_up_sync(mask32, val, delta)) == "CUDA.shfl_up_sync(mask32, val, delta)" + @test @prettystring(1, @shfl_up_sync(mask32, val, delta, width)) == "CUDA.shfl_up_sync(mask32, val, delta, width)" + @test @prettystring(1, @shfl_down_sync(mask32, val, delta)) == "CUDA.shfl_down_sync(mask32, val, delta)" + @test @prettystring(1, @shfl_down_sync(mask32, val, delta, width)) == "CUDA.shfl_down_sync(mask32, val, delta, width)" + @test @prettystring(1, @shfl_xor_sync(mask32, val, lane_mask)) == "CUDA.shfl_xor_sync(mask32, val, lane_mask)" + @test @prettystring(1, @shfl_xor_sync(mask32, val, lane_mask, width)) == "CUDA.shfl_xor_sync(mask32, val, lane_mask, width)" + + @test @prettystring(1, @vote_any_sync(mask32, predicate)) == "CUDA.vote_any_sync(mask32, predicate)" + @test @prettystring(1, @vote_all_sync(mask32, predicate)) == "CUDA.vote_all_sync(mask32, predicate)" + @test @prettystring(1, @vote_ballot_sync(mask32, predicate)) == "CUDA.vote_ballot_sync(mask32, predicate)" + + elseif $package == $PKG_AMDGPU + @test @prettystring(1, @warpsize()) == "AMDGPU.Device.wavefrontsize()" + @test @prettystring(1, @laneid()) == "unsafe_trunc(Cint, AMDGPU.Device.activelane()) + Cint(1)" + @test @prettystring(1, @active_mask()) == "AMDGPU.Device.activemask()" + + @test @prettystring(1, @shfl_sync(mask, val, lane)) == "AMDGPU.Device.shfl_sync(UInt64(mask), val, unsafe_trunc(Cint, lane) - Cint(1))" + @test @prettystring(1, @shfl_sync(mask, val, lane, width)) == "AMDGPU.Device.shfl_sync(UInt64(mask), val, unsafe_trunc(Cint, lane) - Cint(1), unsafe_trunc(Cuint, width))" + @test @prettystring(1, @shfl_up_sync(mask, val, delta)) == "AMDGPU.Device.shfl_up_sync(UInt64(mask), val, unsafe_trunc(Cint, delta))" + @test @prettystring(1, @shfl_up_sync(mask, val, delta, width)) == "AMDGPU.Device.shfl_up_sync(UInt64(mask), val, unsafe_trunc(Cint, delta), unsafe_trunc(Cuint, width))" + @test @prettystring(1, @shfl_down_sync(mask, val, delta)) == "AMDGPU.Device.shfl_down_sync(UInt64(mask), val, unsafe_trunc(Cint, delta))" + @test @prettystring(1, @shfl_down_sync(mask, val, delta, width)) == "AMDGPU.Device.shfl_down_sync(UInt64(mask), val, unsafe_trunc(Cint, delta), unsafe_trunc(Cuint, width))" + @test @prettystring(1, @shfl_xor_sync(mask, val, lane_mask)) == "AMDGPU.Device.shfl_xor_sync(UInt64(mask), val, unsafe_trunc(Cint, lane_mask) - Cint(1))" + @test @prettystring(1, @shfl_xor_sync(mask, val, lane_mask, width)) == "AMDGPU.Device.shfl_xor_sync(UInt64(mask), val, unsafe_trunc(Cint, lane_mask) - Cint(1), unsafe_trunc(Cuint, width))" + + @test @prettystring(1, @vote_any_sync(mask, predicate)) == "AMDGPU.Device.any_sync(UInt64(mask), predicate)" + @test @prettystring(1, @vote_all_sync(mask, predicate)) == "AMDGPU.Device.all_sync(UInt64(mask), predicate)" + @test @prettystring(1, @vote_ballot_sync(mask, predicate)) == "AMDGPU.Device.ballot_sync(UInt64(mask), predicate)" + + elseif $package == $PKG_METAL + @test @prettystring(1, @warpsize()) == "Metal.threads_per_simdgroup()" + @test @prettystring(1, @laneid()) == "unsafe_trunc(Cint, Metal.thread_index_in_simdgroup()) + Cint(1)" + @test_throws Exception @prettystring(1, @active_mask()) + @test_throws Exception @prettystring(1, @shfl_sync(mask, val, lane)) + @test_throws Exception @prettystring(1, @vote_ballot_sync(mask, predicate)) + + elseif @iscpu($package) + @test @prettystring(1, @warpsize()) == "ParallelStencil.ParallelKernel.warpsize_cpu()" + @test @prettystring(1, @laneid()) == "ParallelStencil.ParallelKernel.laneid_cpu()" + @test @prettystring(1, @active_mask()) == "ParallelStencil.ParallelKernel.active_mask_cpu()" + + @test @prettystring(1, @shfl_sync(mask, val, lane)) == "ParallelStencil.ParallelKernel.shfl_sync_cpu(mask, val, Int64(lane) - Int64(1))" + @test @prettystring(1, @shfl_sync(mask, val, lane, width)) == "ParallelStencil.ParallelKernel.shfl_sync_cpu(mask, val, Int64(lane) - Int64(1), Int64(width))" + @test @prettystring(1, @shfl_up_sync(mask, val, delta)) == "ParallelStencil.ParallelKernel.shfl_up_sync_cpu(mask, val, Int64(delta))" + @test @prettystring(1, @shfl_up_sync(mask, val, delta, width)) == "ParallelStencil.ParallelKernel.shfl_up_sync_cpu(mask, val, Int64(delta), Int64(width))" + @test @prettystring(1, @shfl_down_sync(mask, val, delta)) == "ParallelStencil.ParallelKernel.shfl_down_sync_cpu(mask, val, Int64(delta))" + @test @prettystring(1, @shfl_down_sync(mask, val, delta, width)) == "ParallelStencil.ParallelKernel.shfl_down_sync_cpu(mask, val, Int64(delta), Int64(width))" + @test @prettystring(1, @shfl_xor_sync(mask, val, lane_mask)) == "ParallelStencil.ParallelKernel.shfl_xor_sync_cpu(mask, val, Int64(lane_mask) - Int64(1))" + @test @prettystring(1, @shfl_xor_sync(mask, val, lane_mask, width)) == "ParallelStencil.ParallelKernel.shfl_xor_sync_cpu(mask, val, Int64(lane_mask) - Int64(1), Int64(width))" + + @test @prettystring(1, @vote_any_sync(mask, predicate)) == "ParallelStencil.ParallelKernel.vote_any_sync_cpu(mask, predicate)" + @test @prettystring(1, @vote_all_sync(mask, predicate)) == "ParallelStencil.ParallelKernel.vote_all_sync_cpu(mask, predicate)" + @test @prettystring(1, @vote_ballot_sync(mask, predicate)) == "ParallelStencil.ParallelKernel.vote_ballot_sync_cpu(mask, predicate)" + end + end; + @testset "CPU zero overhead" begin + @static if @iscpu($package) + # Use stable literal arguments to exercise CPU code paths + mask = UInt64(0x1) + valf = one($FloatDefault) + lane = 1 + width = 1 + delta = 1 + lanemask = 1 + predicate = true + + @test @expr_allocated(@warpsize()) == 0 + @test @expr_allocated(@laneid()) == 0 + @test @expr_allocated(@active_mask()) == 0 + + @test @expr_allocated(@shfl_sync(mask, valf, lane)) == 0 + @test @expr_allocated(@shfl_sync(mask, valf, lane, width)) == 0 + @test @expr_allocated(@shfl_up_sync(mask, valf, delta)) == 0 + @test @expr_allocated(@shfl_up_sync(mask, valf, delta, width)) == 0 + @test @expr_allocated(@shfl_down_sync(mask, valf, delta)) == 0 + @test @expr_allocated(@shfl_down_sync(mask, valf, delta, width)) == 0 + @test @expr_allocated(@shfl_xor_sync(mask, valf, lanemask)) == 0 + @test @expr_allocated(@shfl_xor_sync(mask, valf, lanemask, width)) == 0 + + @test @expr_allocated(@vote_any_sync(mask, predicate)) == 0 + @test @expr_allocated(@vote_all_sync(mask, predicate)) == 0 + @test @expr_allocated(@vote_ballot_sync(mask, predicate)) == 0 + end + end; + @testset "Semantic smoke tests" begin + @static if @iscpu($package) + N = 8 + A = @rand(N) + P = [isfinite(A[i]) && (A[i] > zero($FloatDefault)) for i in 1:N] # simple predicate + Bout_any = Vector{Bool}(undef, N) + Bout_all = Vector{Bool}(undef, N) + Bout_ballot = Vector{UInt64}(undef, N) + Bshfl = similar(A) + Bshfl_up = similar(A) + Bshfl_down = similar(A) + Bshfl_xor = similar(A) + + @parallel_indices (ix) function kernel_semantics!(Bout_any, Bout_all, Bout_ballot, Bshfl, Bshfl_up, Bshfl_down, Bshfl_xor, A, P) + m = @active_mask() + w = @warpsize() + l = @laneid() + # basic invariants under CPU model + @test w == 1 + @test l == 1 + # shuffle identities + Bshfl[ix] = @shfl_sync(m, A[ix], l) + Bshfl_up[ix] = @shfl_up_sync(m, A[ix], 1) + Bshfl_down[ix] = @shfl_down_sync(m, A[ix], 1) + Bshfl_xor[ix] = @shfl_xor_sync(m, A[ix], 1) + # votes + pa = P[ix] + Bout_any[ix] = @vote_any_sync(m, pa) + Bout_all[ix] = @vote_all_sync(m, pa) + Bout_ballot[ix] = @vote_ballot_sync(m, pa) + return + end + @parallel (1:N) kernel_semantics!(Bout_any, Bout_all, Bout_ballot, Bshfl, Bshfl_up, Bshfl_down, Bshfl_xor, A, P) + + @test all(Bshfl .== A) + @test all(Bshfl_up .== A) + @test all(Bshfl_down .== A) + @test all(Bshfl_xor .== A) + @test Bout_any == P + @test Bout_all == P + @test Bout_ballot == map(p -> p ? UInt64(0x1) : UInt64(0x0), P) + end + end; + @testset "Unsupported primitives" begin + @static if $package == $PKG_METAL + mask = UInt64(0x1) + mask32 = UInt32(0x1) + valf = one($FloatDefault) + lane = 1 + width = 1 + delta = 1 + lanemask = 1 + predicate = true + + @test_throws Exception @prettystring(1, @active_mask()) + + @test_throws Exception @prettystring(1, @shfl_sync(mask, valf, lane)) + @test_throws Exception @prettystring(1, @shfl_sync(mask, valf, lane, width)) + @test_throws Exception @prettystring(1, @shfl_up_sync(mask, valf, delta)) + @test_throws Exception @prettystring(1, @shfl_up_sync(mask, valf, delta, width)) + @test_throws Exception @prettystring(1, @shfl_down_sync(mask, valf, delta)) + @test_throws Exception @prettystring(1, @shfl_down_sync(mask, valf, delta, width)) + @test_throws Exception @prettystring(1, @shfl_xor_sync(mask, valf, lanemask)) + @test_throws Exception @prettystring(1, @shfl_xor_sync(mask, valf, lanemask, width)) + + @test_throws Exception @prettystring(1, @vote_any_sync(mask32, predicate)) + @test_throws Exception @prettystring(1, @vote_all_sync(mask32, predicate)) + @test_throws Exception @prettystring(1, @vote_ballot_sync(mask32, predicate)) + end + end; + end; @testset "@gridDim, @blockIdx, @blockDim, @threadIdx (1D)" begin @static if $package == $PKG_THREADS A = @zeros(4) diff --git a/test/runtests.jl b/test/runtests.jl index 3692d64a..9be36612 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,13 +9,14 @@ import ParallelStencil: SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_METAL excludedfiles = [ "test_excluded.jl", "test_incremental_compilation.jl", "test_revise.jl"]; # TODO: test_incremental_compilation has to be deactivated until Polyester support released -function runtests() +function runtests(testfiles=String[]) exename = joinpath(Sys.BINDIR, Base.julia_exename()) testdir = pwd() istest(f) = endswith(f, ".jl") && startswith(basename(f), "test_") - testfiles = sort(filter(istest, vcat([joinpath.(root, files) for (root, dirs, files) in walkdir(testdir)]...))) + testfiles = isempty(testfiles) ? sort(filter(istest, vcat([joinpath.(root, files) for (root, dirs, files) in walkdir(testdir)]...))) : testfiles - nfail = 0 + nerror = 0 + errorfiles = String[] printstyled("Testing package ParallelStencil.jl\n"; bold=true, color=:white) if (PKG_CUDA in SUPPORTED_PACKAGES && !CUDA.functional()) @@ -37,13 +38,49 @@ function runtests() println("$f") continue end + cmd = `$exename -O3 --startup-file=no $(joinpath(testdir, f))` + stdout_path = tempname() + stderr_path = tempname() + stdout_content = "" + stderr_content = "" try - run(`$exename -O3 --startup-file=no $(joinpath(testdir, f))`) + open(stdout_path, "w") do stdout_io + open(stderr_path, "w") do stderr_io + proc = run(pipeline(Cmd(cmd; ignorestatus=true), stdout=stdout_io, stderr=stderr_io); wait=false) + wait(proc) + end + end + stdout_content = read(stdout_path, String) + stderr_content = read(stderr_path, String) + print(stdout_content) + print(Base.stderr, stderr_content) catch ex - nfail += 1 + println("Test Error: an exception occurred while running the test file $f :") + println(ex) + finally + if ispath(stdout_path) + rm(stdout_path; force=true) + end + if ispath(stderr_path) + rm(stderr_path; force=true) + end + end + if !occursin(r"(?i)test summary", stdout_content) + nerror += 1 + push!(errorfiles, f) + end + end + println("") + if nerror == 0 + printstyled("Test suite: all selected test files executed (see above for results).\n"; bold=true, color=:green) + else + printstyled("Test suite: $nerror test file(s) aborted execution due to error (see above for details); files aborting execution:\n"; bold=true, color=:red) + for f in errorfiles + println(" - $f") end end - return nfail + println("") + return nerror end -exit(runtests()) +exit(runtests(ARGS)) diff --git a/test/test_kernel_language.jl b/test/test_kernel_language.jl new file mode 100644 index 00000000..e0d090e2 --- /dev/null +++ b/test/test_kernel_language.jl @@ -0,0 +1,114 @@ +using Test +using ParallelStencil +import ParallelStencil: @reset_parallel_stencil, @is_initialized, SUPPORTED_PACKAGES, PKG_CUDA, PKG_AMDGPU, PKG_METAL, PKG_THREADS, PKG_POLYESTER +import ParallelStencil: @require, @prettystring, @iscpu + +TEST_PACKAGES = SUPPORTED_PACKAGES +@static if PKG_CUDA in TEST_PACKAGES + import CUDA + if !CUDA.functional() + TEST_PACKAGES = filter!(x -> x ≠ PKG_CUDA, TEST_PACKAGES) + end +end +@static if PKG_AMDGPU in TEST_PACKAGES + import AMDGPU + if !AMDGPU.functional() + TEST_PACKAGES = filter!(x -> x ≠ PKG_AMDGPU, TEST_PACKAGES) + end +end +@static if PKG_METAL in TEST_PACKAGES + @static if Sys.isapple() + import Metal + if !Metal.functional() + TEST_PACKAGES = filter!(x -> x ≠ PKG_METAL, TEST_PACKAGES) + end + else + TEST_PACKAGES = filter!(x -> x ≠ PKG_METAL, TEST_PACKAGES) + end +end +@static if PKG_POLYESTER in TEST_PACKAGES + import Polyester +end +Base.retry_load_extensions() + +@static for package in TEST_PACKAGES + FloatDefault = (package == PKG_METAL) ? Float32 : Float64 + + eval(:( + @testset "$(basename(@__FILE__)) (package: $(nameof($package)))" begin + @require !@is_initialized() + @init_parallel_stencil($package, $FloatDefault, 3) + @require @is_initialized() + + @testset "Pass-through macro mapping" begin + @test @prettystring(1, @gridDim()) == "ParallelStencil.ParallelKernel.@gridDim" + @test @prettystring(1, @blockIdx()) == "ParallelStencil.ParallelKernel.@blockIdx" + @test @prettystring(1, @blockDim()) == "ParallelStencil.ParallelKernel.@blockDim" + @test @prettystring(1, @threadIdx()) == "ParallelStencil.ParallelKernel.@threadIdx" + @test @prettystring(1, @sync_threads()) == "ParallelStencil.ParallelKernel.@sync_threads" + @test @prettystring(1, @sharedMem(T, dims)) == "ParallelStencil.ParallelKernel.@sharedMem T dims" + @test @prettystring(1, @ps_show args) == "ParallelStencil.ParallelKernel.@pk_show args" + @test @prettystring(1, @ps_println args) == "ParallelStencil.ParallelKernel.@pk_println args" + @test @prettystring(1, @∀ i ∈ (x, y) body) == "ParallelStencil.ParallelKernel.@∀ i ∈ (x, y) body" + + @test @prettystring(1, @warpsize()) == "ParallelStencil.ParallelKernel.@warpsize" + @test @prettystring(1, @laneid()) == "ParallelStencil.ParallelKernel.@laneid" + @test @prettystring(1, @active_mask()) == "ParallelStencil.ParallelKernel.@active_mask" + @test @prettystring(1, @shfl_sync(mask, val, lane)) == "ParallelStencil.ParallelKernel.@shfl_sync mask val lane" + @test @prettystring(1, @shfl_sync(mask, val, lane, width)) == "ParallelStencil.ParallelKernel.@shfl_sync mask val lane width" + @test @prettystring(1, @shfl_up_sync(mask, val, delta)) == "ParallelStencil.ParallelKernel.@shfl_up_sync mask val delta" + @test @prettystring(1, @shfl_up_sync(mask, val, delta, width)) == "ParallelStencil.ParallelKernel.@shfl_up_sync mask val delta width" + @test @prettystring(1, @shfl_down_sync(mask, val, delta)) == "ParallelStencil.ParallelKernel.@shfl_down_sync mask val delta" + @test @prettystring(1, @shfl_down_sync(mask, val, delta, width)) == "ParallelStencil.ParallelKernel.@shfl_down_sync mask val delta width" + @test @prettystring(1, @shfl_xor_sync(mask, val, lanemask)) == "ParallelStencil.ParallelKernel.@shfl_xor_sync mask val lanemask" + @test @prettystring(1, @shfl_xor_sync(mask, val, lanemask, width)) == "ParallelStencil.ParallelKernel.@shfl_xor_sync mask val lanemask width" + @test @prettystring(1, @vote_any_sync(mask, predicate)) == "ParallelStencil.ParallelKernel.@vote_any_sync mask predicate" + @test @prettystring(1, @vote_all_sync(mask, predicate)) == "ParallelStencil.ParallelKernel.@vote_all_sync mask predicate" + @test @prettystring(1, @vote_ballot_sync(mask, predicate)) == "ParallelStencil.ParallelKernel.@vote_ballot_sync mask predicate" + end + + @testset "CPU semantic smoke tests" begin + @static if @iscpu($package) + N = 8 + A = @rand(N) + P = [isfinite(A[i]) && (A[i] > zero($FloatDefault)) for i in 1:N] + Bout_any = Vector{Bool}(undef, N) + Bout_all = Vector{Bool}(undef, N) + Bout_ballot = Vector{UInt64}(undef, N) + Bshfl = similar(A) + Bshfl_up = similar(A) + Bshfl_down = similar(A) + Bshfl_xor = similar(A) + + @parallel_indices (ix) function kernel_semantics!(Bout_any, Bout_all, Bout_ballot, Bshfl, Bshfl_up, Bshfl_down, Bshfl_xor, A, P) + m = @active_mask() + w = @warpsize() + l = @laneid() + @test w == 1 + @test l == 1 + Bshfl[ix] = @shfl_sync(m, A[ix], l) + Bshfl_up[ix] = @shfl_up_sync(m, A[ix], 1) + Bshfl_down[ix] = @shfl_down_sync(m, A[ix], 1) + Bshfl_xor[ix] = @shfl_xor_sync(m, A[ix], 1) + pa = P[ix] + Bout_any[ix] = @vote_any_sync(m, pa) + Bout_all[ix] = @vote_all_sync(m, pa) + Bout_ballot[ix] = @vote_ballot_sync(m, pa) + return + end + @parallel (1:N) kernel_semantics!(Bout_any, Bout_all, Bout_ballot, Bshfl, Bshfl_up, Bshfl_down, Bshfl_xor, A, P) + + @test all(Bshfl .== A) + @test all(Bshfl_up .== A) + @test all(Bshfl_down .== A) + @test all(Bshfl_xor .== A) + @test Bout_any == P + @test Bout_all == P + @test Bout_ballot == map(p -> p ? UInt64(0x1) : UInt64(0x0), P) + end + end + + @reset_parallel_stencil() + end + )) +end == nothing || true; \ No newline at end of file