Skip to content

Commit 18bde74

Browse files
authored
Merge pull request cms-sw#43205 from fwyzard/implement_blocks_with_stride
Introduce new utilities for writing Alpaka kernels
2 parents 695beba + 8c859bc commit 18bde74

File tree

2 files changed

+469
-194
lines changed

2 files changed

+469
-194
lines changed

HeterogeneousCore/AlpakaInterface/interface/workdivision.h

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,19 @@ namespace cms::alpakatools {
7575
}
7676
}
7777

78+
/* ElementIndex
79+
*
80+
* an aggregate that contains the .global and .local indices of an element; returned by iterating over elements_in_block.
81+
*/
82+
83+
struct ElementIndex {
84+
Idx global;
85+
Idx local;
86+
};
87+
88+
/* elements_with_stride
89+
*/
90+
7891
template <typename TAcc, typename = std::enable_if_t<alpaka::isAccelerator<TAcc> and alpaka::Dim<TAcc>::value == 1>>
7992
class elements_with_stride {
8093
public:
@@ -326,6 +339,187 @@ namespace cms::alpakatools {
326339
const Vec extent_;
327340
};
328341

342+
/* blocks_with_stride
343+
*
344+
* `blocks_with_stride(acc, size)` returns a range that spans the (virtual) block indices required to cover the given
345+
* problem size.
346+
*
347+
* For example, if size is 1000 and the block size is 16, it will return the range from 0 to 62.
348+
* If the work division has more than 63 blocks, only the first 63 will perform one iteration of the loop, and the
349+
* other will exit immediately.
350+
* If the work division has less than 63 blocks, some of the blocks will perform more than one iteration, in order to
351+
* cover the whole problem space.
352+
*
353+
* All threads in a block see the same loop iterations, while threads in different blocks may see a different number
354+
* of iterations.
355+
*/
356+
357+
template <typename TAcc, typename = std::enable_if_t<alpaka::isAccelerator<TAcc> and alpaka::Dim<TAcc>::value == 1>>
358+
class blocks_with_stride {
359+
public:
360+
ALPAKA_FN_ACC inline blocks_with_stride(TAcc const& acc)
361+
: first_{alpaka::getIdx<alpaka::Grid, alpaka::Blocks>(acc)[0u]},
362+
stride_{alpaka::getWorkDiv<alpaka::Grid, alpaka::Blocks>(acc)[0u]},
363+
extent_{stride_} {}
364+
365+
// extent is the total number of elements (not blocks)
366+
ALPAKA_FN_ACC inline blocks_with_stride(TAcc const& acc, Idx extent)
367+
: first_{alpaka::getIdx<alpaka::Grid, alpaka::Blocks>(acc)[0u]},
368+
stride_{alpaka::getWorkDiv<alpaka::Grid, alpaka::Blocks>(acc)[0u]},
369+
extent_{divide_up_by(extent, alpaka::getWorkDiv<alpaka::Block, alpaka::Elems>(acc)[0u])} {}
370+
371+
class iterator {
372+
friend class blocks_with_stride;
373+
374+
ALPAKA_FN_ACC inline iterator(Idx stride, Idx extent, Idx first)
375+
: stride_{stride}, extent_{extent}, first_{std::min(first, extent)} {}
376+
377+
public:
378+
ALPAKA_FN_ACC inline Idx operator*() const { return first_; }
379+
380+
// pre-increment the iterator
381+
ALPAKA_FN_ACC inline iterator& operator++() {
382+
// increment the first-element-in-block index by the grid stride
383+
first_ += stride_;
384+
if (first_ < extent_)
385+
return *this;
386+
387+
// the iterator has reached or passed the end of the extent, clamp it to the extent
388+
first_ = extent_;
389+
return *this;
390+
}
391+
392+
// post-increment the iterator
393+
ALPAKA_FN_ACC inline iterator operator++(int) {
394+
iterator old = *this;
395+
++(*this);
396+
return old;
397+
}
398+
399+
ALPAKA_FN_ACC inline bool operator==(iterator const& other) const { return (first_ == other.first_); }
400+
401+
ALPAKA_FN_ACC inline bool operator!=(iterator const& other) const { return not(*this == other); }
402+
403+
private:
404+
// non-const to support iterator copy and assignment
405+
Idx stride_;
406+
Idx extent_;
407+
// modified by the pre/post-increment operator
408+
Idx first_;
409+
};
410+
411+
ALPAKA_FN_ACC inline iterator begin() const { return iterator(stride_, extent_, first_); }
412+
413+
ALPAKA_FN_ACC inline iterator end() const { return iterator(stride_, extent_, extent_); }
414+
415+
private:
416+
const Idx first_;
417+
const Idx stride_;
418+
const Idx extent_;
419+
};
420+
421+
/* elements_in_block
422+
*
423+
* `elements_in_block(acc, block, size)` returns a range that spans all the elements within the given block.
424+
* Iterating over the range yields values of type ElementIndex, that contain both .global and .local indices
425+
* of the corresponding element.
426+
*
427+
* If the work division has only one element per thread, the loop will perform at most one iteration.
428+
* If the work division has more than one element per thread, the loop will perform that number of iterations,
429+
* or less if it reaches size.
430+
*/
431+
432+
template <typename TAcc, typename = std::enable_if_t<alpaka::isAccelerator<TAcc> and alpaka::Dim<TAcc>::value == 1>>
433+
class elements_in_block {
434+
public:
435+
ALPAKA_FN_ACC inline elements_in_block(TAcc const& acc, Idx block)
436+
: first_{block * alpaka::getWorkDiv<alpaka::Block, alpaka::Elems>(acc)[0u]},
437+
local_{alpaka::getIdx<alpaka::Block, alpaka::Threads>(acc)[0u] *
438+
alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)[0u]},
439+
range_{local_ + alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)[0u]} {}
440+
441+
ALPAKA_FN_ACC inline elements_in_block(TAcc const& acc, Idx block, Idx extent)
442+
: first_{block * alpaka::getWorkDiv<alpaka::Block, alpaka::Elems>(acc)[0u]},
443+
local_{std::min(extent - first_,
444+
alpaka::getIdx<alpaka::Block, alpaka::Threads>(acc)[0u] *
445+
alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)[0u])},
446+
range_{std::min(extent - first_, local_ + alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)[0u])} {}
447+
448+
class iterator {
449+
friend class elements_in_block;
450+
451+
ALPAKA_FN_ACC inline iterator(Idx local, Idx first, Idx range) : index_{local}, first_{first}, range_{range} {}
452+
453+
public:
454+
ALPAKA_FN_ACC inline ElementIndex operator*() const { return ElementIndex{index_ + first_, index_}; }
455+
456+
// pre-increment the iterator
457+
ALPAKA_FN_ACC inline iterator& operator++() {
458+
if constexpr (requires_single_thread_per_block_v<TAcc>) {
459+
// increment the index along the elements processed by the current thread
460+
++index_;
461+
if (index_ < range_)
462+
return *this;
463+
}
464+
465+
// the iterator has reached or passed the end of the extent, clamp it to the extent
466+
index_ = range_;
467+
return *this;
468+
}
469+
470+
// post-increment the iterator
471+
ALPAKA_FN_ACC inline iterator operator++(int) {
472+
iterator old = *this;
473+
++(*this);
474+
return old;
475+
}
476+
477+
ALPAKA_FN_ACC inline bool operator==(iterator const& other) const { return (index_ == other.index_); }
478+
479+
ALPAKA_FN_ACC inline bool operator!=(iterator const& other) const { return not(*this == other); }
480+
481+
private:
482+
// modified by the pre/post-increment operator
483+
Idx index_;
484+
// non-const to support iterator copy and assignment
485+
Idx first_;
486+
Idx range_;
487+
};
488+
489+
ALPAKA_FN_ACC inline iterator begin() const { return iterator(local_, first_, range_); }
490+
491+
ALPAKA_FN_ACC inline iterator end() const { return iterator(range_, first_, range_); }
492+
493+
private:
494+
const Idx first_;
495+
const Idx local_;
496+
const Idx range_;
497+
};
498+
499+
/* once_per_grid
500+
*
501+
* `once_per_grid(acc)` returns true for a single thread within the kernel execution grid.
502+
*
503+
* Usually the condition is true for block 0 and thread 0, but these indices should not be relied upon.
504+
*/
505+
506+
template <typename TAcc, typename = std::enable_if_t<alpaka::isAccelerator<TAcc>>>
507+
ALPAKA_FN_ACC inline constexpr bool once_per_grid(TAcc const& acc) {
508+
return alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc) == Vec<alpaka::Dim<TAcc>>::zeros();
509+
}
510+
511+
/* once_per_block
512+
*
513+
* `once_per_block(acc)` returns true for a single thread within the block.
514+
*
515+
* Usually the condition is true for thread 0, but this index should not be relied upon.
516+
*/
517+
518+
template <typename TAcc, typename = std::enable_if_t<alpaka::isAccelerator<TAcc>>>
519+
ALPAKA_FN_ACC inline constexpr bool once_per_block(TAcc const& acc) {
520+
return alpaka::getIdx<alpaka::Block, alpaka::Threads>(acc) == Vec<alpaka::Dim<TAcc>>::zeros();
521+
}
522+
329523
} // namespace cms::alpakatools
330524

331525
#endif // HeterogeneousCore_AlpakaInterface_interface_workdivision_h

0 commit comments

Comments
 (0)