Skip to content

Commit a21af25

Browse files
authored
Fix GPU communication for non-arithmetic types (#4515)
Tested on Nvidia GTX 1060, GV100, V100, A100 and H100, AMD MI250X and MI300A, and Intel PVC.
1 parent 9a3d449 commit a21af25

File tree

9 files changed

+509
-131
lines changed

9 files changed

+509
-131
lines changed

Src/Base/AMReX_FBI.H

Lines changed: 232 additions & 127 deletions
Large diffs are not rendered by default.

Src/Base/AMReX_FabArrayUtility.H

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2123,6 +2123,76 @@ DistributionMap (Array<MF,N> const& mf)
21232123
return mf[0].DistributionMap();
21242124
}
21252125

2126+
/*
2127+
* \brief Return a mask indicating how many duplicates are in each point
2128+
*
2129+
* \param fa input FabArray
2130+
* \param nghost number of ghost cells included in counting
2131+
* \param period periodicity
2132+
*/
2133+
template <class FAB>
2134+
FabArray<BaseFab<int>>
2135+
OverlapMask (FabArray<FAB> const& fa, IntVect const& nghost, Periodicity const& period)
2136+
{
2137+
BL_PROFILE("OverlapMask()");
2138+
2139+
const BoxArray& ba = fa.boxArray();
2140+
const DistributionMapping& dm = fa.DistributionMap();
2141+
2142+
FabArray<BaseFab<int>> mask(ba, dm, 1, nghost);
2143+
mask.setVal(1);
2144+
2145+
const std::vector<IntVect>& pshifts = period.shiftIntVect();
2146+
2147+
Vector<Array4BoxTag<int> > tags;
2148+
2149+
bool run_on_gpu = Gpu::inLaunchRegion();
2150+
amrex::ignore_unused(run_on_gpu, tags);
2151+
#ifdef AMREX_USE_OMP
2152+
#pragma omp parallel if (!run_on_gpu)
2153+
#endif
2154+
{
2155+
std::vector< std::pair<int,Box> > isects;
2156+
2157+
for (MFIter mfi(mask); mfi.isValid(); ++mfi)
2158+
{
2159+
const Box& bx = mask[mfi].box();
2160+
auto const& arr = mask.array(mfi);
2161+
2162+
for (const auto& iv : pshifts)
2163+
{
2164+
ba.intersections(bx+iv, isects, false, nghost);
2165+
for (const auto& is : isects)
2166+
{
2167+
Box const& b = is.second-iv;
2168+
if (iv == 0 && b == bx) { continue; }
2169+
#ifdef AMREX_USE_GPU
2170+
if (run_on_gpu) {
2171+
tags.push_back({arr,b});
2172+
} else
2173+
#endif
2174+
{
2175+
amrex::LoopConcurrentOnCpu(b, [=] (int i, int j, int k) noexcept
2176+
{
2177+
arr(i,j,k) += 1;
2178+
});
2179+
}
2180+
}
2181+
}
2182+
}
2183+
}
2184+
2185+
#ifdef AMREX_USE_GPU
2186+
amrex::ParallelFor(tags, 1,
2187+
[=] AMREX_GPU_DEVICE (int i, int j, int k, int n, Array4BoxTag<int> const& tag) noexcept
2188+
{
2189+
Gpu::Atomic::AddNoRet(tag.dfab.ptr(i,j,k,n), 1);
2190+
});
2191+
#endif
2192+
2193+
return mask;
2194+
}
2195+
21262196
}
21272197

21282198
#endif

Src/Base/AMReX_MultiFab.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1529,8 +1529,8 @@ MultiFab::OverlapMask (const Periodicity& period) const
15291529
amrex::ParallelFor(tags, 1,
15301530
[=] AMREX_GPU_DEVICE (int i, int j, int k, int n, Array4BoxTag<Real> const& tag) noexcept
15311531
{
1532-
Real* p = tag.dfab.ptr(i,j,k,n);
1533-
Gpu::Atomic::AddNoRet(p, Real(1.0));
1532+
Real* ptr = tag.dfab.ptr(i,j,k,n);
1533+
Gpu::Atomic::AddNoRet(ptr, Real(1.0));
15341534
});
15351535
#endif
15361536

Src/Base/AMReX_PCI.H

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,15 @@ FabArray<FAB>::PC_local_gpu (const CPC& thecpc, FabArray<FAB> const& src,
9797
loc_copy_tags.reserve(N_locs);
9898

9999
Vector<BaseFab<int> > maskfabs;
100+
Vector<Array4Tag<int> > masks_unique;
100101
Vector<Array4Tag<int> > masks;
101102
if (!is_thread_safe)
102103
{
103104
if ((op == FabArrayBase::COPY && !amrex::IsStoreAtomic<value_type>::value) ||
104105
(op == FabArrayBase::ADD && !amrex::HasAtomicAdd <value_type>::value))
105106
{
106107
maskfabs.resize(this->local_size());
108+
masks_unique.reserve(this->local_size());
107109
masks.reserve(N_locs);
108110
}
109111
}
@@ -122,14 +124,15 @@ FabArray<FAB>::PC_local_gpu (const CPC& thecpc, FabArray<FAB> const& src,
122124
if (maskfabs.size() > 0) {
123125
if (!maskfabs[li].isAllocated()) {
124126
maskfabs[li].resize(this->atLocalIdx(li).box());
127+
masks_unique.emplace_back(Array4Tag<int>{maskfabs[li].array()});
125128
}
126129
masks.emplace_back(Array4Tag<int>{maskfabs[li].array()});
127130
}
128131
}
129132
}
130133

131134
if (maskfabs.size() > 0) {
132-
amrex::ParallelFor(masks,
135+
amrex::ParallelFor(masks_unique,
133136
[=] AMREX_GPU_DEVICE (int i, int j, int k, Array4Tag<int> const& msk) noexcept
134137
{
135138
msk.dfab(i,j,k) = 0;

Tests/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ else()
125125
#
126126
# List of subdirectories to search for CMakeLists.
127127
#
128-
set( AMREX_TESTS_SUBDIRS Amr AsyncOut CLZ CTOParFor DeviceGlobal Enum
128+
set( AMREX_TESTS_SUBDIRS Amr AsyncOut CLZ CommType CTOParFor DeviceGlobal Enum
129129
MultiBlock MultiPeriod ParmParse Parser Parser2 ParserUserFn Reinit
130130
RoundoffDomain SmallMatrix)
131131

Tests/CommType/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
foreach(D IN LISTS AMReX_SPACEDIM)
2+
set(_sources main.cpp)
3+
set(_input_files)
4+
5+
setup_test(${D} _sources _input_files)
6+
7+
unset(_sources)
8+
unset(_input_files)
9+
endforeach()

Tests/CommType/GNUmakefile

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
AMREX_HOME = ../../
2+
3+
DEBUG = FALSE
4+
DIM = 3
5+
COMP = gcc
6+
7+
USE_MPI = TRUE
8+
USE_OMP = FALSE
9+
USE_CUDA = FALSE
10+
USE_HIP = FALSE
11+
USE_SYCL = FALSE
12+
13+
BL_NO_FORT = TRUE
14+
15+
TINY_PROFILE = FALSE
16+
17+
include $(AMREX_HOME)/Tools/GNUMake/Make.defs
18+
19+
include ./Make.package
20+
include $(AMREX_HOME)/Src/Base/Make.package
21+
22+
include $(AMREX_HOME)/Tools/GNUMake/Make.rules

Tests/CommType/Make.package

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
CEXE_sources += main.cpp

Tests/CommType/main.cpp

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
#include <AMReX.H>
2+
#include <AMReX_Print.H>
3+
#include <AMReX_MultiFab.H>
4+
#include <AMReX_GpuComplex.H>
5+
6+
using namespace amrex;
7+
8+
int main(int argc, char* argv[])
9+
{
10+
amrex::Initialize(argc,argv);
11+
12+
int ret_code = EXIT_SUCCESS;
13+
14+
{
15+
int ncells = 128;
16+
BoxArray ba(Box(IntVect(0), IntVect(ncells-1)));
17+
ba.maxSize(32);
18+
ba.convert(IntVect(1));
19+
DistributionMapping dm(ba);
20+
21+
constexpr int ncomp = 2;
22+
IntVect nghost(2);
23+
Periodicity period{IntVect(ncells)};
24+
25+
auto value = [=] AMREX_GPU_DEVICE (int i, int j, int k, int n) -> Real
26+
{
27+
if (i < 0) {
28+
i += ncells;
29+
} else if (i >= ncells) {
30+
i -= ncells;
31+
}
32+
if (j < 0) {
33+
j += ncells;
34+
} else if (j >= ncells) {
35+
j -= ncells;
36+
}
37+
if (k < 0) {
38+
k += ncells;
39+
} else if (k >= ncells) {
40+
k -= ncells;
41+
}
42+
return n + i*ncomp + j*ncomp*ncells + k*ncomp*ncells*ncells;
43+
};
44+
45+
// Test GpuArray
46+
{
47+
using T = GpuArray<Real,ncomp>;
48+
FabArray<BaseFab<T>> fa(ba,dm,1,nghost);
49+
FabArray<BaseFab<T>> fa2(ba,dm,1,nghost);
50+
FabArray<BaseFab<T>> fa3(ba,dm,1,nghost);
51+
auto const& ma = fa.arrays();
52+
auto const& ma2 = fa2.arrays();
53+
auto const& ma3 = fa3.arrays();
54+
55+
ParallelFor(fa, IntVect(0),
56+
[=] AMREX_GPU_DEVICE (int b, int i, int j, int k)
57+
{
58+
auto const& a = ma[b];
59+
for (int n = 0; n < ncomp; ++n) {
60+
a(i,j,k)[n] = value(i,j,k,n);
61+
}
62+
});
63+
64+
fa.FillBoundary(period);
65+
66+
fa2.ParallelCopy(fa, 0, 0, 1, IntVect(0), nghost, period);
67+
68+
fa3.setVal(T{});
69+
fa3.ParallelAdd(fa, 0, 0, 1, nghost, nghost, period);
70+
71+
auto mask = OverlapMask(fa3,nghost,period);
72+
auto const& mma = mask.const_arrays();
73+
74+
auto err = ParReduce(TypeList<ReduceOpMax,ReduceOpMax,ReduceOpMax>{},
75+
TypeList<Real,Real,Real>{},
76+
fa, nghost,
77+
[=] AMREX_GPU_DEVICE (int b, int i, int j, int k)
78+
-> GpuTuple<Real,Real,Real>
79+
{
80+
Real r1 = 0, r2 = 0, r3 = 0;
81+
auto const& a1 = ma[b];
82+
auto const& a2 = ma2[b];
83+
auto const& a3 = ma3[b];
84+
auto const& m = mma[b];
85+
for (int n = 0; n < ncomp; ++n) {
86+
auto v = value(i,j,k,n);
87+
r1 = std::max(r1, std::abs(a1(i,j,k)[n] - v));
88+
r2 = std::max(r2, std::abs(a2(i,j,k)[n] - v));
89+
r3 = std::max(r3, std::abs(a3(i,j,k)[n] - v*m(i,j,k)));
90+
}
91+
return {r1, r2, r3};
92+
});
93+
94+
AMREX_ALWAYS_ASSERT(amrex::get<0>(err) == 0);
95+
AMREX_ALWAYS_ASSERT(amrex::get<1>(err) == 0);
96+
AMREX_ALWAYS_ASSERT(amrex::get<2>(err) == 0);
97+
98+
Real errmax = std::max({amrex::get<0>(err),
99+
amrex::get<1>(err),
100+
amrex::get<2>(err)});
101+
ParallelDescriptor::ReduceRealSum(errmax);
102+
if (errmax != 0) {
103+
ret_code = EXIT_FAILURE;
104+
}
105+
}
106+
107+
// Test GpuComplex
108+
{
109+
using T = GpuComplex<Real>;
110+
FabArray<BaseFab<T>> fa(ba,dm,1,nghost);
111+
FabArray<BaseFab<T>> fa2(ba,dm,1,nghost);
112+
FabArray<BaseFab<T>> fa3(ba,dm,1,nghost);
113+
auto const& ma = fa.arrays();
114+
auto const& ma2 = fa2.arrays();
115+
auto const& ma3 = fa3.arrays();
116+
117+
ParallelFor(fa, IntVect(0),
118+
[=] AMREX_GPU_DEVICE (int b, int i, int j, int k)
119+
{
120+
auto const& a = ma[b];
121+
a(i,j,k) = T{value(i,j,k,0),value(i,j,k,1)};
122+
});
123+
124+
fa.FillBoundary(period);
125+
126+
fa2.ParallelCopy(fa, 0, 0, 1, IntVect(0), nghost, period);
127+
128+
fa3.setVal(T{});
129+
fa3.ParallelAdd(fa, 0, 0, 1, nghost, nghost, period);
130+
131+
auto mask = OverlapMask(fa3,nghost,period);
132+
auto const& mma = mask.const_arrays();
133+
134+
auto err = ParReduce(TypeList<ReduceOpMax,ReduceOpMax,ReduceOpMax>{},
135+
TypeList<Real,Real,Real>{},
136+
fa, nghost,
137+
[=] AMREX_GPU_DEVICE (int b, int i, int j, int k)
138+
-> GpuTuple<Real,Real,Real>
139+
{
140+
Real r1 = 0, r2 = 0, r3 = 0;
141+
auto const& a1 = ma[b];
142+
auto const& a2 = ma2[b];
143+
auto const& a3 = ma3[b];
144+
auto const& m = mma[b];
145+
auto v = GpuComplex{value(i,j,k,0), value(i,j,k,1)};
146+
r1 = std::max(r1, amrex::norm(a1(i,j,k) - v));
147+
r2 = std::max(r2, amrex::norm(a2(i,j,k) - v));
148+
r3 = std::max(r3, amrex::norm(a3(i,j,k) - v*Real(m(i,j,k))));
149+
return {r1, r2, r3};
150+
});
151+
152+
AMREX_ALWAYS_ASSERT(amrex::get<0>(err) == 0);
153+
AMREX_ALWAYS_ASSERT(amrex::get<1>(err) == 0);
154+
AMREX_ALWAYS_ASSERT(amrex::get<2>(err) == 0);
155+
156+
Real errmax = std::max({amrex::get<0>(err),
157+
amrex::get<1>(err),
158+
amrex::get<2>(err)});
159+
ParallelDescriptor::ReduceRealSum(errmax);
160+
if (errmax != 0) {
161+
ret_code = EXIT_FAILURE;
162+
}
163+
}
164+
}
165+
amrex::Finalize();
166+
167+
return ret_code;
168+
}

0 commit comments

Comments
 (0)