Skip to content

Commit 7272683

Browse files
committed
Add the blocks_with_stride and elements_in_block ranges
`blocks_with_stride(acc, size)` returns a range than spans the (virtual) block indices required to cover the given problem size. For example, if size is 1000 and the block size is 16, it will return the range from 1 to 62. If the work division has more than 63 blocks, only the first 63 will perform one iteration of the loop, and the other will exit immediately. if the work division has less than 63 blocks, some of the blocks will perform more than one iteration, in order to cover then whole problem space. All threads in a block see the same loop iterations, while threads in different blocks may see a different number of iterations. `elements_in_block(acc, block, size)` returns a range that spans all the elements within the given block. Iterating over the range yields values of type ElementIndex, that contain both .global and .local indices of the corresponding element. If the work division has only one element per thread, the loop will perform at most one iteration. If the work division has more than one elements per thread, the loop will perform that number of iterations, or less if it reaches size.
1 parent 3eed8ad commit 7272683

File tree

1 file changed

+170
-0
lines changed

1 file changed

+170
-0
lines changed

HeterogeneousCore/AlpakaInterface/interface/workdivision.h

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,19 @@ namespace cms::alpakatools {
7575
}
7676
}
7777

78+
/* ElementIndex
79+
*
80+
* an aggregate that containes the .global and .local indices of an element; returned by iterating over elements_in_block.
81+
*/
82+
83+
struct ElementIndex {
84+
Idx global;
85+
Idx local;
86+
};
87+
88+
/* elements_with_stride
89+
*/
90+
7891
template <typename TAcc, typename = std::enable_if_t<alpaka::isAccelerator<TAcc> and alpaka::Dim<TAcc>::value == 1>>
7992
class elements_with_stride {
8093
public:
@@ -326,6 +339,163 @@ namespace cms::alpakatools {
326339
const Vec extent_;
327340
};
328341

342+
/* blocks_with_stride
343+
*
344+
* `blocks_with_stride(acc, size)` returns a range than spans the (virtual) block indices required to cover the given
345+
* problem size.
346+
*
347+
* For example, if size is 1000 and the block size is 16, it will return the range from 1 to 62.
348+
* If the work division has more than 63 blocks, only the first 63 will perform one iteration of the loop, and the
349+
* other will exit immediately.
350+
* If the work division has less than 63 blocks, some of the blocks will perform more than one iteration, in order to
351+
* cover then whole problem space.
352+
*
353+
* All threads in a block see the same loop iterations, while threads in different blocks may see a different number
354+
* of iterations.
355+
*/
356+
357+
template <typename TAcc, typename = std::enable_if_t<alpaka::isAccelerator<TAcc> and alpaka::Dim<TAcc>::value == 1>>
358+
class blocks_with_stride {
359+
public:
360+
ALPAKA_FN_ACC inline blocks_with_stride(TAcc const& acc)
361+
: first_{alpaka::getIdx<alpaka::Grid, alpaka::Blocks>(acc)[0u]},
362+
stride_{alpaka::getWorkDiv<alpaka::Grid, alpaka::Blocks>(acc)[0u]},
363+
extent_{stride_} {}
364+
365+
// extent is the total number of elements (not blocks)
366+
ALPAKA_FN_ACC inline blocks_with_stride(TAcc const& acc, Idx extent)
367+
: first_{alpaka::getIdx<alpaka::Grid, alpaka::Blocks>(acc)[0u]},
368+
stride_{alpaka::getWorkDiv<alpaka::Grid, alpaka::Blocks>(acc)[0u]},
369+
extent_{divide_up_by(extent, alpaka::getWorkDiv<alpaka::Block, alpaka::Elems>(acc)[0u])} {}
370+
371+
class iterator {
372+
friend class blocks_with_stride;
373+
374+
ALPAKA_FN_ACC inline iterator(Idx stride, Idx extent, Idx first)
375+
: stride_{stride}, extent_{extent}, first_{std::min(first, extent)} {}
376+
377+
public:
378+
ALPAKA_FN_ACC inline Idx operator*() const { return first_; }
379+
380+
// pre-increment the iterator
381+
ALPAKA_FN_ACC inline iterator& operator++() {
382+
// increment the first-element-in-block index by the grid stride
383+
first_ += stride_;
384+
if (first_ < extent_)
385+
return *this;
386+
387+
// the iterator has reached or passed the end of the extent, clamp it to the extent
388+
first_ = extent_;
389+
return *this;
390+
}
391+
392+
// post-increment the iterator
393+
ALPAKA_FN_ACC inline iterator operator++(int) {
394+
iterator old = *this;
395+
++(*this);
396+
return old;
397+
}
398+
399+
ALPAKA_FN_ACC inline bool operator==(iterator const& other) const { return (first_ == other.first_); }
400+
401+
ALPAKA_FN_ACC inline bool operator!=(iterator const& other) const { return not(*this == other); }
402+
403+
private:
404+
// non-const to support iterator copy and assignment
405+
Idx stride_;
406+
Idx extent_;
407+
// modified by the pre/post-increment operator
408+
Idx first_;
409+
};
410+
411+
ALPAKA_FN_ACC inline iterator begin() const { return iterator(stride_, extent_, first_); }
412+
413+
ALPAKA_FN_ACC inline iterator end() const { return iterator(stride_, extent_, extent_); }
414+
415+
private:
416+
const Idx first_;
417+
const Idx stride_;
418+
const Idx extent_;
419+
};
420+
421+
/* elements_in_block
422+
*
423+
* `elements_in_block(acc, block, size)` returns a range that spans all the elements within the given block.
424+
* Iterating over the range yields values of type ElementIndex, that contain both .global and .local indices
425+
* of the corresponding element.
426+
*
427+
* If the work division has only one element per thread, the loop will perform at most one iteration.
428+
* If the work division has more than one elements per thread, the loop will perform that number of iterations,
429+
* or less if it reaches size.
430+
*/
431+
432+
template <typename TAcc, typename = std::enable_if_t<alpaka::isAccelerator<TAcc> and alpaka::Dim<TAcc>::value == 1>>
433+
class elements_in_block {
434+
public:
435+
ALPAKA_FN_ACC inline elements_in_block(TAcc const& acc, Idx block)
436+
: first_{block * alpaka::getWorkDiv<alpaka::Block, alpaka::Elems>(acc)[0u]},
437+
local_{alpaka::getIdx<alpaka::Block, alpaka::Threads>(acc)[0u] *
438+
alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)[0u]},
439+
range_{local_ + alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)[0u]} {}
440+
441+
ALPAKA_FN_ACC inline elements_in_block(TAcc const& acc, Idx block, Idx extent)
442+
: first_{block * alpaka::getWorkDiv<alpaka::Block, alpaka::Elems>(acc)[0u]},
443+
local_{std::min(extent - first_,
444+
alpaka::getIdx<alpaka::Block, alpaka::Threads>(acc)[0u] *
445+
alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)[0u])},
446+
range_{std::min(extent - first_, local_ + alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)[0u])} {}
447+
448+
class iterator {
449+
friend class elements_in_block;
450+
451+
ALPAKA_FN_ACC inline iterator(Idx local, Idx first, Idx range) : index_{local}, first_{first}, range_{range} {}
452+
453+
public:
454+
ALPAKA_FN_ACC inline ElementIndex operator*() const { return ElementIndex{index_ + first_, index_}; }
455+
456+
// pre-increment the iterator
457+
ALPAKA_FN_ACC inline iterator& operator++() {
458+
if constexpr (requires_single_thread_per_block_v<TAcc>) {
459+
// increment the index along the elements processed by the current thread
460+
++index_;
461+
if (index_ < range_)
462+
return *this;
463+
}
464+
465+
// the iterator has reached or passed the end of the extent, clamp it to the extent
466+
index_ = range_;
467+
return *this;
468+
}
469+
470+
// post-increment the iterator
471+
ALPAKA_FN_ACC inline iterator operator++(int) {
472+
iterator old = *this;
473+
++(*this);
474+
return old;
475+
}
476+
477+
ALPAKA_FN_ACC inline bool operator==(iterator const& other) const { return (index_ == other.index_); }
478+
479+
ALPAKA_FN_ACC inline bool operator!=(iterator const& other) const { return not(*this == other); }
480+
481+
private:
482+
// modified by the pre/post-increment operator
483+
Idx index_;
484+
// non-const to support iterator copy and assignment
485+
Idx first_;
486+
Idx range_;
487+
};
488+
489+
ALPAKA_FN_ACC inline iterator begin() const { return iterator(local_, first_, range_); }
490+
491+
ALPAKA_FN_ACC inline iterator end() const { return iterator(range_, first_, range_); }
492+
493+
private:
494+
const Idx first_;
495+
const Idx local_;
496+
const Idx range_;
497+
};
498+
329499
} // namespace cms::alpakatools
330500

331501
#endif // HeterogeneousCore_AlpakaInterface_interface_workdivision_h

0 commit comments

Comments
 (0)