@@ -698,7 +698,7 @@ template <typename Ty> class StaticLoopChunker {
698698 static void NormalizedLoopNestNoChunk (void (*LoopBody)(Ty, void *), void *Arg,
699699 Ty NumBlocks, Ty BId, Ty NumThreads,
700700 Ty TId, Ty NumIters,
701- bool OneIterationPerThread) {
701+ uint8_t OneIterationPerThread) {
702702 Ty KernelIteration = NumBlocks * NumThreads;
703703
704704 // Start index in the normalized space.
@@ -729,7 +729,7 @@ template <typename Ty> class StaticLoopChunker {
729729 Ty BlockChunk, Ty NumBlocks, Ty BId,
730730 Ty ThreadChunk, Ty NumThreads, Ty TId,
731731 Ty NumIters,
732- bool OneIterationPerThread) {
732+ uint8_t OneIterationPerThread) {
733733 Ty KernelIteration = NumBlocks * BlockChunk;
734734
735735 // Start index in the chunked space.
@@ -767,8 +767,18 @@ template <typename Ty> class StaticLoopChunker {
767767
768768public:
769769 /// Worksharing `for`-loop.
770+ /// \param[in] Loc Description of source location
771+ /// \param[in] LoopBody Function which corresponds to loop body
772+ /// \param[in] Arg Pointer to struct which contains loop body args
773+ /// \param[in] NumIters Number of loop iterations
774+ /// \param[in] NumThreads Number of GPU threads
775+ /// \param[in] ThreadChunk Size of thread chunk
776+ /// \param[in] OneIterationPerThread If true/nonzero, each thread executes
777+ /// only one loop iteration or one thread chunk. This avoids an outer loop
778+ /// over all loop iterations/chunks.
770779 static void For (IdentTy *Loc, void (*LoopBody)(Ty, void *), void *Arg,
771- Ty NumIters, Ty NumThreads, Ty ThreadChunk) {
780+ Ty NumIters, Ty NumThreads, Ty ThreadChunk,
781+ uint8_t OneIterationPerThread) {
772782 ASSERT (NumIters >= 0 , " Bad iteration count" );
773783 ASSERT (ThreadChunk >= 0 , " Bad thread count" );
774784
@@ -790,12 +800,13 @@ template <typename Ty> class StaticLoopChunker {
790800
791801 // If we know we have more threads than iterations we can indicate that to
792802 // avoid an outer loop.
793- bool OneIterationPerThread = false ;
794803 if (config::getAssumeThreadsOversubscription ()) {
795- ASSERT (NumThreads >= NumIters, " Broken assumption" );
796804 OneIterationPerThread = true ;
797805 }
798806
807+ if (OneIterationPerThread)
808+ ASSERT (NumThreads >= NumIters, " Broken assumption" );
809+
799810 if (ThreadChunk != 1 )
800811 NormalizedLoopNestChunked (LoopBody, Arg, BlockChunk, NumBlocks, BId,
801812 ThreadChunk, NumThreads, TId, NumIters,
@@ -806,8 +817,17 @@ template <typename Ty> class StaticLoopChunker {
806817 }
807818
808819 /// Worksharing `distribute`-loop.
820+ /// \param[in] Loc Description of source location
821+ /// \param[in] LoopBody Function which corresponds to loop body
822+ /// \param[in] Arg Pointer to struct which contains loop body args
823+ /// \param[in] NumIters Number of loop iterations
824+ /// \param[in] BlockChunk Size of block chunk
825+ /// \param[in] OneIterationPerThread If true/nonzero, each thread executes
826+ /// only one loop iteration or one thread chunk. This avoids an outer loop
827+ /// over all loop iterations/chunks.
809828 static void Distribute (IdentTy *Loc, void (*LoopBody)(Ty, void *), void *Arg,
810- Ty NumIters, Ty BlockChunk) {
829+ Ty NumIters, Ty BlockChunk,
830+ uint8_t OneIterationPerThread) {
811831 ASSERT (icv::Level == 0 , " Bad distribute" );
812832 ASSERT (icv::ActiveLevel == 0 , " Bad distribute" );
813833 ASSERT (state::ParallelRegionFn == nullptr , " Bad distribute" );
@@ -831,12 +851,13 @@ template <typename Ty> class StaticLoopChunker {
831851
832852 // If we know we have more blocks than iterations we can indicate that to
833853 // avoid an outer loop.
834- bool OneIterationPerThread = false ;
835854 if (config::getAssumeTeamsOversubscription ()) {
836- ASSERT (NumBlocks >= NumIters, " Broken assumption" );
837855 OneIterationPerThread = true ;
838856 }
839857
858+ if (OneIterationPerThread)
859+ ASSERT (NumBlocks >= NumIters, " Broken assumption" );
860+
840861 if (BlockChunk != NumThreads)
841862 NormalizedLoopNestChunked (LoopBody, Arg, BlockChunk, NumBlocks, BId,
842863 ThreadChunk, NumThreads, TId, NumIters,
@@ -852,9 +873,20 @@ template <typename Ty> class StaticLoopChunker {
852873 }
853874
853874 /// Worksharing `distribute parallel for`-loop.
876+ /// \param[in] Loc Description of source location
877+ /// \param[in] LoopBody Function which corresponds to loop body
878+ /// \param[in] Arg Pointer to struct which contains loop body args
879+ /// \param[in] NumIters Number of loop iterations
880+ /// \param[in] NumThreads Number of GPU threads
881+ /// \param[in] BlockChunk Size of block chunk
882+ /// \param[in] ThreadChunk Size of thread chunk
883+ /// \param[in] OneIterationPerThread If true/nonzero, each thread executes
884+ /// only one loop iteration or one thread chunk. This avoids an outer loop
885+ /// over all loop iterations/chunks.
855886 static void DistributeFor (IdentTy *Loc, void (*LoopBody)(Ty, void *),
856887 void *Arg, Ty NumIters, Ty NumThreads,
857- Ty BlockChunk, Ty ThreadChunk) {
888+ Ty BlockChunk, Ty ThreadChunk,
889+ uint8_t OneIterationPerThread) {
858890 ASSERT (icv::Level == 1 , " Bad distribute" );
859891 ASSERT (icv::ActiveLevel == 1 , " Bad distribute" );
860892 ASSERT (state::ParallelRegionFn == nullptr , " Bad distribute" );
@@ -882,13 +914,14 @@ template <typename Ty> class StaticLoopChunker {
882914
883915 // If we know we have more threads (across all blocks) than iterations we
884916 // can indicate that to avoid an outer loop.
885- bool OneIterationPerThread = false ;
886917 if (config::getAssumeTeamsOversubscription () &
887918 config::getAssumeThreadsOversubscription ()) {
888919 OneIterationPerThread = true ;
889- ASSERT (NumBlocks * NumThreads >= NumIters, " Broken assumption" );
890920 }
891921
922+ if (OneIterationPerThread)
923+ ASSERT (NumBlocks * NumThreads >= NumIters, " Broken assumption" );
924+
892925 if (BlockChunk != NumThreads || ThreadChunk != 1 )
893926 NormalizedLoopNestChunked (LoopBody, Arg, BlockChunk, NumBlocks, BId,
894927 ThreadChunk, NumThreads, TId, NumIters,
@@ -907,24 +940,26 @@ template <typename Ty> class StaticLoopChunker {
907940
908941#define OMP_LOOP_ENTRY (BW, TY ) \
909942 [[gnu::flatten, clang::always_inline]] void \
910- __kmpc_distribute_for_static_loop##BW( \
911- IdentTy *loc, void (*fn)(TY, void *), void *arg, TY num_iters, \
912- TY num_threads, TY block_chunk, TY thread_chunk) { \
943+ __kmpc_distribute_for_static_loop##BW( \
944+ IdentTy *loc, void (*fn)(TY, void *), void *arg, TY num_iters, \
945+ TY num_threads, TY block_chunk, TY thread_chunk, \
946+ uint8_t one_iteration_per_thread) { \
913947 ompx::StaticLoopChunker<TY>::DistributeFor ( \
914- loc, fn, arg, num_iters, num_threads, block_chunk, thread_chunk); \
948+ loc, fn, arg, num_iters, num_threads, block_chunk, thread_chunk, \
949+ one_iteration_per_thread); \
915950 } \
916951 [[gnu::flatten, clang::always_inline]] void \
917- __kmpc_distribute_static_loop##BW(IdentTy *loc, void (*fn)(TY, void *), \
918- void *arg, TY num_iters, \
919- TY block_chunk ) { \
920- ompx::StaticLoopChunker<TY>::Distribute (loc, fn, arg, num_iters, \
921- block_chunk); \
952+ __kmpc_distribute_static_loop##BW(IdentTy *loc, void (*fn)(TY, void *), \
953+ void *arg, TY num_iters, TY block_chunk, \
954+ uint8_t one_iteration_per_thread ) { \
955+ ompx::StaticLoopChunker<TY>::Distribute ( \
956+ loc, fn, arg, num_iters, block_chunk, one_iteration_per_thread); \
922957 } \
923958 [[gnu::flatten, clang::always_inline]] void __kmpc_for_static_loop##BW( \
924959 IdentTy *loc, void (*fn)(TY, void *), void *arg, TY num_iters, \
925- TY num_threads, TY thread_chunk) { \
960+ TY num_threads, TY thread_chunk, uint8_t one_iteration_per_thread) { \
926961 ompx::StaticLoopChunker<TY>::For (loc, fn, arg, num_iters, num_threads, \
927- thread_chunk); \
962+ thread_chunk, one_iteration_per_thread); \
928963 }
929964
930965 extern "C" {
0 commit comments