3030
3131#include < alpaka/core/Common.hpp>
3232#include < alpaka/core/Positioning.hpp>
33+ #include < alpaka/extent/Traits.hpp>
3334#include < alpaka/idx/Accessors.hpp>
35+ #include < alpaka/idx/MapIdx.hpp>
3436#include < alpaka/kernel/Traits.hpp>
3537#include < alpaka/mem/fence/Traits.hpp>
3638#include < alpaka/mem/view/Traits.hpp>
3739#include < alpaka/mem/view/ViewPlainPtr.hpp>
3840#include < alpaka/vec/Vec.hpp>
41+ #include < alpaka/workdiv/Traits.hpp>
42+ #include < alpaka/workdiv/WorkDivHelpers.hpp>
3943#include < alpaka/workdiv/WorkDivMembers.hpp>
4044
4145#include < sys/types.h>
@@ -86,14 +90,38 @@ namespace mallocMC::CreationPolicies::FlatterScatterAlloc
8690 MyAccessBlock* accessBlocks{};
8791 uint32_t volatile block = 0U ;
8892
89- ALPAKA_FN_INLINE ALPAKA_FN_ACC auto init () -> void
93+ ALPAKA_FN_INLINE ALPAKA_FN_ACC static auto init (auto const & acc, void * accessBlocksPointer, auto heapSize)
94+ -> void
9095 {
91- for (uint32_t i = 0 ; i < numBlocks (); ++i)
96+ auto threadsInGrid = alpaka::getWorkDiv<alpaka::Grid, alpaka::Threads>(acc);
97+ auto numThreads = threadsInGrid.prod ();
98+ auto const [idx] = alpaka::mapIdx<1U >(alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc), threadsInGrid);
99+ auto * accessBlocks = static_cast <MyAccessBlock*>(accessBlocksPointer);
100+
101+ for (uint32_t i = idx; i < numBlocks (heapSize) * MyAccessBlock::numPages (); i += numThreads)
92102 {
93- accessBlocks[i].init ();
103+ auto blockIdx = i / MyAccessBlock::numPages ();
104+ auto pageIdx = i % MyAccessBlock::numPages ();
105+
106+ accessBlocks[blockIdx].init (acc, pageIdx);
94107 }
95108 }
96109
110+ ALPAKA_FN_INLINE ALPAKA_FN_ACC auto init (auto const & acc) -> void
111+ {
112+ init (acc, accessBlocks, heapSize);
113+ }
114+
115+ /* *
116+ * @brief Number of access blocks assuming the given heapSize.
117+ *
118+ * @return Number of access blocks in the heap.
119+ */
120+ ALPAKA_FN_INLINE ALPAKA_FN_ACC static constexpr auto numBlocks (auto heapSize) -> uint32_t
121+ {
122+ return heapSize / T_HeapConfig::accessblocksize;
123+ }
124+
97125 /* *
98126 * @brief Number of access blocks in the heap. This is a runtime quantity because it depends on the given heap
99127 * size.
@@ -102,7 +130,7 @@ namespace mallocMC::CreationPolicies::FlatterScatterAlloc
102130 */
103131 ALPAKA_FN_INLINE ALPAKA_FN_ACC auto numBlocks () const -> uint32_t
104132 {
105- return heapSize / T_HeapConfig::accessblocksize ;
133+ return numBlocks ( heapSize) ;
106134 }
107135
108136 /* *
@@ -307,15 +335,22 @@ namespace mallocMC::CreationPolicies::FlatterScatterAlloc
307335 {
308336 template <typename T_HeapConfig, typename T_HashConfig, typename T_AlignmentPolicy>
309337 ALPAKA_FN_INLINE ALPAKA_FN_ACC auto operator ()(
310- auto const & /* unused */ ,
338+ auto const & acc ,
311339 Heap<T_HeapConfig, T_HashConfig, T_AlignmentPolicy>* m_heap,
312340 void * m_heapmem,
313341 size_t const m_memsize) const
314342 {
315- m_heap->accessBlocks
316- = static_cast <Heap<T_HeapConfig, T_HashConfig, T_AlignmentPolicy>::MyAccessBlock*>(m_heapmem);
317- m_heap->heapSize = m_memsize;
318- m_heap->init ();
343+ auto const idx = alpaka::mapIdx<1U >(
344+ alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc),
345+ alpaka::getWorkDiv<alpaka::Grid, alpaka::Threads>(acc));
346+ if (idx == 0 )
347+ {
348+ m_heap->accessBlocks
349+ = static_cast <Heap<T_HeapConfig, T_HashConfig, T_AlignmentPolicy>::MyAccessBlock*>(m_heapmem);
350+ m_heap->heapSize = m_memsize;
351+ }
352+ // We can't rely on thread 0 to finish the above before we start, so we use the static version:
353+ Heap<T_HeapConfig, T_HashConfig, T_AlignmentPolicy>::init (acc, m_heapmem, m_memsize);
319354 }
320355 };
321356
@@ -374,13 +409,15 @@ namespace mallocMC::CreationPolicies
374409 template <typename TAcc>
375410 static void initHeap ([[maybe_unused]] auto & dev, auto & queue, auto * heap, void * pool, size_t memsize)
376411 {
377- using Dim = typename alpaka::trait::DimType<TAcc>::type;
378- using Idx = typename alpaka::trait::IdxType<TAcc>::type;
379- using VecType = alpaka::Vec<Dim, Idx>;
380-
381- auto workDivSingleThread
382- = alpaka::WorkDivMembers<Dim, Idx>{VecType::ones (), VecType::ones (), VecType::ones ()};
383- alpaka::exec<TAcc>(queue, workDivSingleThread, FlatterScatterAlloc::InitKernel{}, heap, pool, memsize);
412+ using MyHeap = FlatterScatterAlloc::Heap<T_HeapConfig, T_HashConfig, T_AlignmentPolicy>;
413+ auto numBlocks = MyHeap::numBlocks (memsize);
414+ auto numPagesPerBlock = MyHeap::MyAccessBlock::numPages ();
415+
416+ alpaka::KernelCfg<TAcc> const kernelCfg
417+ = {numBlocks * numPagesPerBlock, 1U , false , alpaka::GridBlockExtentSubDivRestrictions::Unrestricted};
418+ auto workDiv
419+ = alpaka::getValidWorkDiv (kernelCfg, dev, FlatterScatterAlloc::InitKernel{}, heap, pool, memsize);
420+ alpaka::exec<TAcc>(queue, workDiv, FlatterScatterAlloc::InitKernel{}, heap, pool, memsize);
384421 alpaka::wait (queue);
385422 }
386423
0 commit comments