@@ -75,6 +75,19 @@ namespace cms::alpakatools {
7575 }
7676 }
7777
78+ /* ElementIndex
79+ *
80+ * an aggregate that containes the .global and .local indices of an element; returned by iterating over elements_in_block.
81+ */
82+
83+ struct ElementIndex {
84+ Idx global;
85+ Idx local;
86+ };
87+
88+ /* elements_with_stride
89+ */
90+
7891 template <typename TAcc, typename = std::enable_if_t <alpaka::isAccelerator<TAcc> and alpaka::Dim<TAcc>::value == 1 >>
7992 class elements_with_stride {
8093 public:
@@ -326,6 +339,187 @@ namespace cms::alpakatools {
326339 const Vec extent_;
327340 };
328341
342+ /* blocks_with_stride
343+ *
344+ * `blocks_with_stride(acc, size)` returns a range than spans the (virtual) block indices required to cover the given
345+ * problem size.
346+ *
347+ * For example, if size is 1000 and the block size is 16, it will return the range from 1 to 62.
348+ * If the work division has more than 63 blocks, only the first 63 will perform one iteration of the loop, and the
349+ * other will exit immediately.
350+ * If the work division has less than 63 blocks, some of the blocks will perform more than one iteration, in order to
351+ * cover then whole problem space.
352+ *
353+ * All threads in a block see the same loop iterations, while threads in different blocks may see a different number
354+ * of iterations.
355+ */
356+
357+ template <typename TAcc, typename = std::enable_if_t <alpaka::isAccelerator<TAcc> and alpaka::Dim<TAcc>::value == 1 >>
358+ class blocks_with_stride {
359+ public:
360+ ALPAKA_FN_ACC inline blocks_with_stride (TAcc const & acc)
361+ : first_{alpaka::getIdx<alpaka::Grid, alpaka::Blocks>(acc)[0u ]},
362+ stride_{alpaka::getWorkDiv<alpaka::Grid, alpaka::Blocks>(acc)[0u ]},
363+ extent_{stride_} {}
364+
365+ // extent is the total number of elements (not blocks)
366+ ALPAKA_FN_ACC inline blocks_with_stride (TAcc const & acc, Idx extent)
367+ : first_{alpaka::getIdx<alpaka::Grid, alpaka::Blocks>(acc)[0u ]},
368+ stride_{alpaka::getWorkDiv<alpaka::Grid, alpaka::Blocks>(acc)[0u ]},
369+ extent_{divide_up_by (extent, alpaka::getWorkDiv<alpaka::Block, alpaka::Elems>(acc)[0u ])} {}
370+
371+ class iterator {
372+ friend class blocks_with_stride ;
373+
374+ ALPAKA_FN_ACC inline iterator (Idx stride, Idx extent, Idx first)
375+ : stride_{stride}, extent_{extent}, first_{std::min (first, extent)} {}
376+
377+ public:
378+ ALPAKA_FN_ACC inline Idx operator *() const { return first_; }
379+
380+ // pre-increment the iterator
381+ ALPAKA_FN_ACC inline iterator& operator ++() {
382+ // increment the first-element-in-block index by the grid stride
383+ first_ += stride_;
384+ if (first_ < extent_)
385+ return *this ;
386+
387+ // the iterator has reached or passed the end of the extent, clamp it to the extent
388+ first_ = extent_;
389+ return *this ;
390+ }
391+
392+ // post-increment the iterator
393+ ALPAKA_FN_ACC inline iterator operator ++(int ) {
394+ iterator old = *this ;
395+ ++(*this );
396+ return old;
397+ }
398+
399+ ALPAKA_FN_ACC inline bool operator ==(iterator const & other) const { return (first_ == other.first_ ); }
400+
401+ ALPAKA_FN_ACC inline bool operator !=(iterator const & other) const { return not (*this == other); }
402+
403+ private:
404+ // non-const to support iterator copy and assignment
405+ Idx stride_;
406+ Idx extent_;
407+ // modified by the pre/post-increment operator
408+ Idx first_;
409+ };
410+
411+ ALPAKA_FN_ACC inline iterator begin () const { return iterator (stride_, extent_, first_); }
412+
413+ ALPAKA_FN_ACC inline iterator end () const { return iterator (stride_, extent_, extent_); }
414+
415+ private:
416+ const Idx first_;
417+ const Idx stride_;
418+ const Idx extent_;
419+ };
420+
421+ /* elements_in_block
422+ *
423+ * `elements_in_block(acc, block, size)` returns a range that spans all the elements within the given block.
424+ * Iterating over the range yields values of type ElementIndex, that contain both .global and .local indices
425+ * of the corresponding element.
426+ *
427+ * If the work division has only one element per thread, the loop will perform at most one iteration.
428+ * If the work division has more than one elements per thread, the loop will perform that number of iterations,
429+ * or less if it reaches size.
430+ */
431+
432+ template <typename TAcc, typename = std::enable_if_t <alpaka::isAccelerator<TAcc> and alpaka::Dim<TAcc>::value == 1 >>
433+ class elements_in_block {
434+ public:
435+ ALPAKA_FN_ACC inline elements_in_block (TAcc const & acc, Idx block)
436+ : first_{block * alpaka::getWorkDiv<alpaka::Block, alpaka::Elems>(acc)[0u ]},
437+ local_{alpaka::getIdx<alpaka::Block, alpaka::Threads>(acc)[0u ] *
438+ alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)[0u ]},
439+ range_{local_ + alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)[0u ]} {}
440+
441+ ALPAKA_FN_ACC inline elements_in_block (TAcc const & acc, Idx block, Idx extent)
442+ : first_{block * alpaka::getWorkDiv<alpaka::Block, alpaka::Elems>(acc)[0u ]},
443+ local_{std::min (extent - first_,
444+ alpaka::getIdx<alpaka::Block, alpaka::Threads>(acc)[0u ] *
445+ alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)[0u ])},
446+ range_{std::min (extent - first_, local_ + alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)[0u ])} {}
447+
448+ class iterator {
449+ friend class elements_in_block ;
450+
451+ ALPAKA_FN_ACC inline iterator (Idx local, Idx first, Idx range) : index_{local}, first_{first}, range_{range} {}
452+
453+ public:
454+ ALPAKA_FN_ACC inline ElementIndex operator *() const { return ElementIndex{index_ + first_, index_}; }
455+
456+ // pre-increment the iterator
457+ ALPAKA_FN_ACC inline iterator& operator ++() {
458+ if constexpr (requires_single_thread_per_block_v<TAcc>) {
459+ // increment the index along the elements processed by the current thread
460+ ++index_;
461+ if (index_ < range_)
462+ return *this ;
463+ }
464+
465+ // the iterator has reached or passed the end of the extent, clamp it to the extent
466+ index_ = range_;
467+ return *this ;
468+ }
469+
470+ // post-increment the iterator
471+ ALPAKA_FN_ACC inline iterator operator ++(int ) {
472+ iterator old = *this ;
473+ ++(*this );
474+ return old;
475+ }
476+
477+ ALPAKA_FN_ACC inline bool operator ==(iterator const & other) const { return (index_ == other.index_ ); }
478+
479+ ALPAKA_FN_ACC inline bool operator !=(iterator const & other) const { return not (*this == other); }
480+
481+ private:
482+ // modified by the pre/post-increment operator
483+ Idx index_;
484+ // non-const to support iterator copy and assignment
485+ Idx first_;
486+ Idx range_;
487+ };
488+
489+ ALPAKA_FN_ACC inline iterator begin () const { return iterator (local_, first_, range_); }
490+
491+ ALPAKA_FN_ACC inline iterator end () const { return iterator (range_, first_, range_); }
492+
493+ private:
494+ const Idx first_;
495+ const Idx local_;
496+ const Idx range_;
497+ };
498+
499+ /* once_per_grid
500+ *
501+ * `once_per_grid(acc)` returns true for a single thread within the kernel execution grid.
502+ *
503+ * Usually the condition is true for block 0 and thread 0, but these indices should not be relied upon.
504+ */
505+
506+ template <typename TAcc, typename = std::enable_if_t <alpaka::isAccelerator<TAcc>>>
507+ ALPAKA_FN_ACC inline constexpr bool once_per_grid (TAcc const & acc) {
508+ return alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc) == Vec<alpaka::Dim<TAcc>>::zeros ();
509+ }
510+
511+ /* once_per_block
512+ *
513+ * `once_per_block(acc)` returns true for a single thread within the block.
514+ *
515+ * Usually the condition is true for thread 0, but this index should not be relied upon.
516+ */
517+
518+ template <typename TAcc, typename = std::enable_if_t <alpaka::isAccelerator<TAcc>>>
519+ ALPAKA_FN_ACC inline constexpr bool once_per_block (TAcc const & acc) {
520+ return alpaka::getIdx<alpaka::Block, alpaka::Threads>(acc) == Vec<alpaka::Dim<TAcc>>::zeros ();
521+ }
522+
329523} // namespace cms::alpakatools
330524
331525#endif // HeterogeneousCore_AlpakaInterface_interface_workdivision_h
0 commit comments