Skip to content

Commit cc1ab2a

Browse files
committed
Set parameters as a function of compute capability and improve mapping to ranges
1 parent f3e8768 commit cc1ab2a

File tree

1 file changed

+24
-16
lines changed

1 file changed

+24
-16
lines changed

src/kernel_language.jl

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ function memopt(metadata_module::Module, is_parallel_kernel::Bool, caller::Modul
7878
offsets, offsets_by_z = extract_offsets(caller, body, indices, int_type, optvars, loopdim)
7979
optvars = remove_single_point_optvars(optvars, optranges, offsets, offsets_by_z)
8080
if (length(optvars)==0) @IncoherentArgumentError("incoherent argument memopt in @parallel[_indices] <kernel>: optimization can only be applied if there is at least one array that is read-only within the kernel (and accessed with a multi-point stencil). Set memopt=false for this kernel.") end
81-
optranges = define_optranges(optranges, optvars, offsets, int_type)
81+
optranges = define_optranges(optranges, optvars, offsets, int_type, package)
8282
regqueue_heads, regqueue_tails, offset_mins, offset_maxs, nb_regs_heads, nb_regs_tails = define_regqueues(offsets, optranges, optvars, indices, int_type, loopdim)
8383

8484
if loopdim == 3
@@ -102,6 +102,7 @@ function memopt(metadata_module::Module, is_parallel_kernel::Bool, caller::Modul
102102
ranges = RANGES_VARNAME
103103
range_z = :(($ranges[3])[$tz_g])
104104
range_z_start = :(($ranges[3])[1])
105+
range_z_end = :(($ranges[3])[end])
105106
i = gensym_world("i", @__MODULE__)
106107
loopoffset = gensym_world("loopoffset", @__MODULE__)
107108

@@ -125,7 +126,7 @@ function memopt(metadata_module::Module, is_parallel_kernel::Bool, caller::Modul
125126

126127
#TODO: replace wrap_if where possible with in-line if - compare performance when doing it
127128
body = quote
128-
$loopoffset = (@blockIdx().z-1)*$loopsize #TODO: MOVE UP - see no perf change! interchange other lines!
129+
$loopoffset = (@blockIdx().z-1)*$loopsize + $range_z_start-1 #TODO: MOVE UP - see no perf change! interchange other lines!
129130
$((quote
130131
$tx = @threadIdx().x + $hx1
131132
$ty = @threadIdx().y + $hy1
@@ -164,9 +165,12 @@ $((:( $reg = 0.0
164165
# for $i = $loopstart:$(mainloopstart-1)
165166
$(wrap_loop(i, loopstart:mainloopstart-1,
166167
quote
167-
$tz_g = $i + $loopoffset
168-
if ($tz_g > $rangelength_z) ParallelStencil.@return_nothing; end
169-
$iz = ($tz_g < 1) ? $range_z_start-(1-$tz_g) : $range_z # TODO: this will probably always be formulated with range_z_start
168+
$iz = $i + $loopoffset
169+
if ($iz > $range_z_end) ParallelStencil.@return_nothing; end
170+
# NOTE: the following is now fully included in the loopoffset (0.25% performance gain measured on H100) but is still of interest if we implement step ranges:
171+
# $tz_g = $i + $loopoffset
172+
# if ($tz_g > $rangelength_z) ParallelStencil.@return_nothing; end
173+
# $iz = ($tz_g < 1) ? $range_z_start-(1-$tz_g) : $range_z # TODO: this will probably always be formulated with range_z_start
170174
$((wrap_if(:($i > $(loopentry-1)),
171175
:( $reg = (0<$ix+$(oxy[1])<=size($A,1) && 0<$iy+$(oxy[2])<=size($A,2) && 0<$iz+$oz<=size($A,3)) ? $(regtarget(A, (oxy...,oz), indices)) : $reg
172176
)
@@ -212,9 +216,12 @@ $(( # NOTE: the if statement is not needed here as we only deal with registers
212216
# for $i = $mainloopstart:$mainloopend # ParallelStencil.@unroll
213217
$(wrap_loop(i, mainloopstart:mainloopend,
214218
quote
215-
$tz_g = $i + $loopoffset
216-
if ($tz_g > $rangelength_z) ParallelStencil.@return_nothing; end
217-
$iz = ($tz_g < 1) ? $range_z_start-(1-$tz_g) : $range_z # TODO: this will probably always be formulated with range_z_start
219+
$iz = $i + $loopoffset
220+
if ($iz > $range_z_end) ParallelStencil.@return_nothing; end
221+
# NOTE: the following is now fully included in the loopoffset (0.25% performance gain measured on H100) but is still of interest if we implement step ranges:
222+
# $tz_g = $i + $loopoffset
223+
# if ($tz_g > $rangelength_z) ParallelStencil.@return_nothing; end
224+
# $iz = ($tz_g < 1) ? $range_z_start-(1-$tz_g) : $range_z # TODO: this will probably always be formulated with range_z_start
218225
$(use_any_shmem ?
219226
:( @sync_threads()
220227
) : NOEXPR
@@ -468,7 +475,7 @@ end
468475

469476
function memopt(metadata_module::Module, is_parallel_kernel::Bool, caller::Module, indices::Union{Symbol,Expr}, optvars::Union{Expr,Symbol}, body::Expr; package::Symbol=get_package(caller))
470477
loopdim = isa(indices,Expr) ? length(indices.args) : 1
471-
loopsize = LOOPSIZE
478+
loopsize = compute_loopsize(package)
472479
optranges = nothing
473480
use_shmemhalos = nothing
474481
optimize_halo_read = true
@@ -545,7 +552,8 @@ function remove_single_point_optvars(optvars, optranges_arg, offsets, offsets_by
545552
return tuple((A for A in optvars if !(length(keys(offsets[A]))==1 && length(keys(offsets_by_z[A]))==1) || (!isnothing(optranges_arg) && A ∈ keys(optranges_arg)))...)
546553
end
547554

548-
function define_optranges(optranges_arg, optvars, offsets, int_type)
555+
function define_optranges(optranges_arg, optvars, offsets, int_type, package)
556+
compute_capability = get_compute_capability(package)
549557
optranges = Dict()
550558
for A in optvars
551559
zspan_max = 0
@@ -560,12 +568,12 @@ function define_optranges(optranges_arg, optvars, offsets, int_type)
560568
fullrange = typemin(int_type):typemax(int_type)
561569
pointrange_x = oxy_zspan_max[1]: oxy_zspan_max[1]
562570
pointrange_y = oxy_zspan_max[2]: oxy_zspan_max[2]
563-
if (!isnothing(optranges_arg) && A ∈ keys(optranges_arg)) optranges[A] = getproperty(optranges_arg, A)
564-
elseif (length(optvars) <= FULLRANGE_THRESHOLD) optranges[A] = (fullrange, fullrange, fullrange)
565-
elseif (USE_FULLRANGE_DEFAULT == (true, true, true)) optranges[A] = (fullrange, fullrange, fullrange)
566-
elseif (USE_FULLRANGE_DEFAULT == (false, true, true)) optranges[A] = (pointrange_x, fullrange, fullrange)
567-
elseif (USE_FULLRANGE_DEFAULT == (true, false, true)) optranges[A] = (fullrange, pointrange_y, fullrange)
568-
elseif (USE_FULLRANGE_DEFAULT == (false, false, true)) optranges[A] = (pointrange_x, pointrange_y, fullrange)
571+
if (!isnothing(optranges_arg) && A ∈ keys(optranges_arg)) optranges[A] = getproperty(optranges_arg, A)
572+
elseif (compute_capability < v"8" && (length(optvars) <= FULLRANGE_THRESHOLD)) optranges[A] = (fullrange, fullrange, fullrange)
573+
elseif (USE_FULLRANGE_DEFAULT == (true, true, true)) optranges[A] = (fullrange, fullrange, fullrange)
574+
elseif (USE_FULLRANGE_DEFAULT == (false, true, true)) optranges[A] = (pointrange_x, fullrange, fullrange)
575+
elseif (USE_FULLRANGE_DEFAULT == (true, false, true)) optranges[A] = (fullrange, pointrange_y, fullrange)
576+
elseif (USE_FULLRANGE_DEFAULT == (false, false, true)) optranges[A] = (pointrange_x, pointrange_y, fullrange)
569577
end
570578
end
571579
return optranges

0 commit comments

Comments
 (0)