Skip to content

Commit c6eb03b

Browse files
Add partitionParticles version that skips reduction (#4803)
## Summary This PR adds an overload of partitionParticles that takes num_left as an input to skip the reduction that would compute num_left in the original function. This can be useful when combining the reduction with other operations in an effort to reduce the overhead from extra kernel launches and stream synchronizations. ## Additional background ## Checklist The proposed changes: - [ ] fix a bug or incorrect behavior in AMReX - [x] add new capabilities to AMReX - [ ] changes answers in the test suite to more than roundoff level - [ ] are likely to significantly affect the results of downstream AMReX users - [ ] include documentation in the code and/or rst files, if appropriate
1 parent ef7ca32 commit c6eb03b

File tree

1 file changed

+65
-0
lines changed

1 file changed

+65
-0
lines changed

Src/Particle/AMReX_ParticleUtil.H

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,71 @@ partitionParticles (PTile& ptile, ParFunc const& is_left)
569569
return num_left;
570570
}
571571

572+
/**
573+
* \brief Reorders the ParticleTile into two partitions
574+
* left [0, num_left-1] and right [num_left, ptile.numParticles()-1].
575+
* This version of the function requires the correct amount for num_left to be passed as an input,
576+
* which allows it to skip a reduction. 
577+
*
578+
* The functor is_left [(ParticleTileData ptd, int index) -> bool] maps each particle to
579+
* either the left [return true] or the right [return false] partition.
580+
* It must return the same result if evaluated multiple times for the same particle.
581+
*
582+
* \param ptile the ParticleTile to partition
583+
* \param num_left number of particles in the left partition
584+
* \param is_left functor to map particles to a partition
585+
*/
586+
template <typename PTile, typename ParFunc>
587+
void
588+
partitionParticles (PTile& ptile, int num_left, ParFunc const& is_left)
589+
{
590+
const int np = ptile.numParticles();
591+
if (np == 0) { return; }
592+
593+
auto ptd = ptile.getParticleTileData();
594+
595+
const int max_num_swaps = std::min(num_left, np - num_left);
596+
if (max_num_swaps == 0) { return; }
597+
598+
Gpu::DeviceVector<int> index_left(max_num_swaps);
599+
Gpu::DeviceVector<int> index_right(max_num_swaps);
600+
int * const p_index_left = index_left.dataPtr();
601+
int * const p_index_right = index_right.dataPtr();
602+
603+
Scan::PrefixSum<int>(np,
604+
[=] AMREX_GPU_DEVICE (int i) -> int
605+
{
606+
return int(!is_left(ptd, i));
607+
},
608+
[=] AMREX_GPU_DEVICE (int i, int const& s)
609+
{
610+
if (!is_left(ptd, i)) {
611+
int dst = s;
612+
if (dst < max_num_swaps) {
613+
p_index_right[dst] = i;
614+
}
615+
} else {
616+
int dst = num_left-1-(i-s); // avoid integer overflow
617+
if (dst < max_num_swaps) {
618+
p_index_left[dst] = i;
619+
}
620+
}
621+
},
622+
Scan::Type::exclusive, Scan::noRetSum);
623+
624+
ParallelFor(max_num_swaps,
625+
[=] AMREX_GPU_DEVICE (int i)
626+
{
627+
int left_i = p_index_left[i];
628+
int right_i = p_index_right[i];
629+
if (right_i < left_i) {
630+
swapParticle(ptd, ptd, left_i, right_i);
631+
}
632+
});
633+
634+
Gpu::streamSynchronize(); // for index_left and index_right deallocation
635+
}
636+
572637
template <typename PTile>
573638
void
574639
removeInvalidParticles (PTile& ptile)

0 commit comments

Comments
 (0)