Skip to content

Commit 064db4e

Browse files
authored
Fix Resize Issue of Fab with the Async Arena (#3663)
## Summary

Previously there was an issue with resizing Fabs that use The_Async_Arena: the existing allocation might have been made on a different stream than the one that is current when the Fab is resized. This commit fixes the issue and makes the following patterns work.

```cpp
FArrayBox tmp0(The_Async_Arena());
FArrayBox tmp1(The_Async_Arena());
FArrayBox tmp2;
for (MFIter ...) {
    tmp0.resize(box,ncomp,The_Async_Arena());
    tmp1.resize(box,ncomp);
    tmp2.resize(box,ncomp,The_Async_Arena());
}
```

## Additional background

AMReX-Astro/Castro#2677

## Checklist

The proposed changes:
- [x] fix a bug or incorrect behavior in AMReX
- [ ] add new capabilities to AMReX
- [ ] changes answers in the test suite to more than roundoff level
- [ ] are likely to significantly affect the results of downstream AMReX users
- [ ] include documentation in the code and/or rst files, if appropriate
1 parent ecaa46d commit 064db4e

File tree

4 files changed

+35
-1
lines changed

4 files changed

+35
-1
lines changed

Src/Base/AMReX_Arena.H

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,11 @@ public:
157157
*/
158158
virtual void registerForProfiling (const std::string& memory_name);
159159

160+
#ifdef AMREX_USE_GPU
161+
//! Is this a GPU stream-ordered memory allocator?
162+
[[nodiscard]] virtual bool isStreamOrderedArena () const { return false; }
163+
#endif
164+
160165
/**
161166
* \brief Given a minimum required arena size of sz bytes, this returns
162167
* the next largest arena size that will align to align_size bytes

Src/Base/AMReX_BaseFab.H

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1631,6 +1631,9 @@ protected:
16311631
Long truesize = 0L; //!< nvar*numpts that was allocated on heap.
16321632
bool ptr_owner = false; //!< Owner of T*?
16331633
bool shared_memory = false; //!< Is the memory allocated in shared memory?
1634+
#ifdef AMREX_USE_GPU
1635+
gpuStream_t alloc_stream{};
1636+
#endif
16341637
};
16351638

16361639
template <class T>
@@ -1902,6 +1905,9 @@ BaseFab<T>::define ()
19021905
this->truesize = this->nvar*this->domain.numPts();
19031906
this->ptr_owner = true;
19041907
this->dptr = static_cast<T*>(this->alloc(this->truesize*sizeof(T)));
1908+
#ifdef AMREX_USE_GPU
1909+
this->alloc_stream = Gpu::gpuStream();
1910+
#endif
19051911

19061912
placementNew(this->dptr, this->truesize);
19071913

@@ -2003,6 +2009,9 @@ BaseFab<T>::BaseFab (BaseFab<T>&& rhs) noexcept
20032009
dptr(rhs.dptr), domain(rhs.domain),
20042010
nvar(rhs.nvar), truesize(rhs.truesize),
20052011
ptr_owner(rhs.ptr_owner), shared_memory(rhs.shared_memory)
2012+
#ifdef AMREX_USE_GPU
2013+
, alloc_stream(rhs.alloc_stream)
2014+
#endif
20062015
{
20072016
rhs.dptr = nullptr;
20082017
rhs.ptr_owner = false;
@@ -2021,6 +2030,9 @@ BaseFab<T>::operator= (BaseFab<T>&& rhs) noexcept
20212030
truesize = rhs.truesize;
20222031
ptr_owner = rhs.ptr_owner;
20232032
shared_memory = rhs.shared_memory;
2033+
#ifdef AMREX_USE_GPU
2034+
alloc_stream = rhs.alloc_stream;
2035+
#endif
20242036

20252037
rhs.dptr = nullptr;
20262038
rhs.ptr_owner = false;
@@ -2062,7 +2074,11 @@ BaseFab<T>::resize (const Box& b, int n, Arena* ar)
20622074
this->dptr = nullptr;
20632075
define();
20642076
}
2065-
else if (this->nvar*this->domain.numPts() > this->truesize)
2077+
else if (this->nvar*this->domain.numPts() > this->truesize
2078+
#ifdef AMREX_USE_GPU
2079+
|| (arena()->isStreamOrderedArena() && alloc_stream != Gpu::gpuStream())
2080+
#endif
2081+
)
20662082
{
20672083
if (this->shared_memory) {
20682084
amrex::Abort("BaseFab::resize: BaseFab in shared memory cannot increase size");
@@ -2114,7 +2130,14 @@ BaseFab<T>::clear () noexcept
21142130

21152131
placementDelete(this->dptr, this->truesize);
21162132

2133+
#ifdef AMREX_USE_GPU
2134+
auto current_stream = Gpu::Device::gpuStream();
2135+
Gpu::Device::setStream(alloc_stream);
2136+
#endif
21172137
this->free(this->dptr);
2138+
#ifdef AMREX_USE_GPU
2139+
Gpu::Device::setStream(current_stream);
2140+
#endif
21182141

21192142
if (this->nvar > 1) {
21202143
amrex::update_fab_stats(-this->truesize/this->nvar, -this->truesize, sizeof(T));

Src/Base/AMReX_GpuTypes.H

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ struct Dim1 {
2929
struct gpuStream_t {
3030
sycl::queue* queue = nullptr;
3131
bool operator== (gpuStream_t const& rhs) noexcept { return queue == rhs.queue; }
32+
bool operator!= (gpuStream_t const& rhs) noexcept { return queue != rhs.queue; }
3233
};
3334

3435
#endif

Src/Base/AMReX_PArena.H

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ public:
3838
[[nodiscard]] bool isDevice () const final;
3939
[[nodiscard]] bool isPinned () const final;
4040

41+
#ifdef AMREX_USE_GPU
42+
//! Is this a CUDA stream-ordered memory allocator?
43+
[[nodiscard]] bool isStreamOrderedArena () const final { return true; }
44+
#endif
45+
4146
#ifdef AMREX_CUDA_GE_11_2
4247
private:
4348
cudaMemPool_t m_pool;

0 commit comments

Comments
 (0)