@@ -23,7 +23,6 @@ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
2323
2424
2525======================= end_copyright_notice ==================================*/
26-
2726#include " GenISAIntrinsics/GenIntrinsicInst.h"
2827#include " ThreadCombining.hpp"
2928#include " Compiler/IGCPassSupport.h"
@@ -366,16 +365,15 @@ bool ThreadCombining::canDoOptimization(Function* m_kernel, llvm::Module& M)
366365 unsigned int threadGroupSize_Y = GetthreadGroupSize (M, ThreadGroupSize_Y);
367366 unsigned int threadGroupSize_Z = GetthreadGroupSize (M, ThreadGroupSize_Z);
368367
369- if (threadGroupSize_X == 1 ||
370- threadGroupSize_Y == 1 ||
371- threadGroupSize_Z != 1 )
372- {
373- return false ;
374- }
375-
376368 std::vector<llvm::Instruction*> barriers;
377369 PreAnalysis (m_kernel, M, barriers);
378370
371+ // Explicit thread group size shrinking works only for no-barrier no-SLM case
372+ if (IGC_IS_FLAG_ENABLED (EnableForceGroupSize))
373+ {
374+ return barriers.empty () && !m_SLMUsed && (threadGroupSize_Z == 1 );
375+ }
376+
379377 PostDominatorTree* PDT = &getAnalysis<PostDominatorTreeWrapperPass>(*m_kernel).getPostDomTree ();
380378 // Check if any of the barriers are within control flow
381379 bool anyBarrierWithinControlFlow = false ;
@@ -387,13 +385,23 @@ bool ThreadCombining::canDoOptimization(Function* m_kernel, llvm::Module& M)
387385 }
388386 }
389387
390- if ((!m_SLMUsed && IGC_IS_FLAG_DISABLED (EnableThreadCombiningWithNoSLM)) ||
391- anyBarrierWithinControlFlow)
388+ if (anyBarrierWithinControlFlow)
392389 {
393390 return false ;
394391 }
395392
396- FindRegistersAliveAcrossBarriers (m_kernel, M);
393+ if (threadGroupSize_X == 1 ||
394+ threadGroupSize_Y == 1 ||
395+ threadGroupSize_Z != 1 )
396+ {
397+ return false ;
398+ }
399+
400+ if (!m_SLMUsed && IGC_IS_FLAG_DISABLED (EnableThreadCombiningWithNoSLM)
401+ && !IGC_IS_FLAG_ENABLED (EnableForceThreadCombining))
402+ {
403+ return false ;
404+ }
397405
398406 return true ;
399407}
@@ -621,6 +629,83 @@ void ThreadCombining::CreateNewKernel(llvm::Module& M,
621629 }
622630}
623631
632+ // Remap ThreadIDs and GroupIDs to old values
633+ void ThreadCombining::remapThreads (
634+ llvm::Module& M,
635+ unsigned int newSizeX,
636+ unsigned int newSizeY,
637+ unsigned int threadGroupSize_X,
638+ unsigned int threadGroupSize_Y,
639+ llvm::IRBuilder<> builder)
640+ {
641+ unsigned int threadGroupSizeModifier_X = threadGroupSize_X / newSizeX;
642+ unsigned int threadGroupSizeModifier_Y = threadGroupSize_Y / newSizeY;
643+
644+ BasicBlock* oldEntry = &(m_kernel->getEntryBlock ());
645+ BasicBlock* newEntry = BasicBlock::Create (M.getContext (), " ThreadID_remap" , m_kernel, oldEntry);
646+
647+ builder.SetInsertPoint (newEntry);
648+
649+ Function* ThreadIDFN = GenISAIntrinsic::getDeclaration (&M, GenISAIntrinsic::GenISA_DCL_SystemValue, builder.getFloatTy ());
650+ Value* threadID_X = builder.CreateCall (ThreadIDFN, builder.getInt32 (THREAD_ID_IN_GROUP_X));
651+ Value* threadID_Y = builder.CreateCall (ThreadIDFN, builder.getInt32 (THREAD_ID_IN_GROUP_Y));
652+ Value* groupID_X = builder.CreateCall (ThreadIDFN, builder.getInt32 (THREAD_GROUP_ID_X));
653+ Value* groupID_Y = builder.CreateCall (ThreadIDFN, builder.getInt32 (THREAD_GROUP_ID_Y));
654+
655+ threadID_X = builder.CreateBitCast (threadID_X, builder.getInt32Ty ());
656+ threadID_Y = builder.CreateBitCast (threadID_Y, builder.getInt32Ty ());
657+ groupID_X = builder.CreateBitCast (groupID_X, builder.getInt32Ty ());
658+ groupID_Y = builder.CreateBitCast (groupID_Y, builder.getInt32Ty ());
659+
660+ Value* oldGroupID_X = builder.CreateUDiv (groupID_X, builder.getInt32 (threadGroupSizeModifier_X));
661+ Value* oldGroupID_Y = builder.CreateUDiv (groupID_Y, builder.getInt32 (threadGroupSizeModifier_Y));
662+
663+ Value* oldThreadID_X = builder.CreateURem (groupID_X, builder.getInt32 (threadGroupSizeModifier_X));
664+ oldThreadID_X = builder.CreateAdd (threadID_X, builder.CreateMul (builder.getInt32 (newSizeX), oldThreadID_X));
665+ Value* oldThreadID_Y = builder.CreateURem (groupID_Y, builder.getInt32 (threadGroupSizeModifier_Y));
666+ oldThreadID_Y = builder.CreateAdd (threadID_Y, builder.CreateMul (builder.getInt32 (newSizeY), oldThreadID_Y));
667+
668+ for (auto & BI : *m_kernel)
669+ {
670+ for (auto & inst : BI)
671+ {
672+ if (&BI == newEntry)
673+ {
674+ continue ;
675+ }
676+ if (GenIntrinsicInst * b = dyn_cast<GenIntrinsicInst>(&inst))
677+ {
678+ if (b->getIntrinsicID () == GenISAIntrinsic::GenISA_DCL_SystemValue)
679+ {
680+ switch (cast<ConstantInt>(b->getOperand (0 ))->getZExtValue ())
681+ {
682+ case THREAD_ID_IN_GROUP_X:
683+ b->replaceAllUsesWith (builder.CreateBitCast (oldThreadID_X, b->getType ()));
684+ break ;
685+ case THREAD_ID_IN_GROUP_Y:
686+ b->replaceAllUsesWith (builder.CreateBitCast (oldThreadID_Y, b->getType ()));
687+ break ;
688+ case THREAD_GROUP_ID_X:
689+ b->replaceAllUsesWith (builder.CreateBitCast (oldGroupID_X, b->getType ()));
690+ break ;
691+ case THREAD_GROUP_ID_Y:
692+ b->replaceAllUsesWith (builder.CreateBitCast (oldGroupID_Y, b->getType ()));
693+ break ;
694+ default :
695+ break ;
696+ }
697+ }
698+ }
699+ }
700+ }
701+ builder.CreateBr (oldEntry);
702+
703+ // Set in global variable, how many times thread group size was reduced
704+ // It will be used by UMD for increasing dispatch size in the same amount
705+ M.getGlobalVariable (" ThreadGroupModifier_X" )->setInitializer (builder.getInt32 (threadGroupSizeModifier_X));
706+ M.getGlobalVariable (" ThreadGroupModifier_Y" )->setInitializer (builder.getInt32 (threadGroupSizeModifier_Y));
707+ }
708+
624709bool ThreadCombining::runOnModule (llvm::Module& M)
625710{
626711 llvm::IRBuilder<> builder (M.getContext ());
@@ -635,6 +720,8 @@ bool ThreadCombining::runOnModule(llvm::Module& M)
635720 return false ;
636721 }
637722
723+ FindRegistersAliveAcrossBarriers (m_kernel, M);
724+
638725 unsigned int threadGroupSize_X = GetthreadGroupSize (M, ThreadGroupSize_X);
639726 unsigned int threadGroupSize_Y = GetthreadGroupSize (M, ThreadGroupSize_Y);
640727
@@ -680,15 +767,20 @@ bool ThreadCombining::runOnModule(llvm::Module& M)
680767
681768 unsigned int newSizeX = threadGroupSize_X;
682769 unsigned int newSizeY = threadGroupSize_Y;
683- // Heuristic for Threadcombining based on EU Occupancy, if EU occupancy increases with the new
684- // size then combine threads, otherwise skip it
685- if (IGC_IS_FLAG_ENABLED (EnableForceGroupSize))
770+ if (IGC_IS_FLAG_ENABLED (EnableForceGroupSize) || IGC_IS_FLAG_ENABLED (EnableForceThreadCombining))
686771 {
772+ if (IGC_GET_FLAG_VALUE (ForceGroupSizeShaderHash) &&
773+ (IGC_GET_FLAG_VALUE (ForceGroupSizeShaderHash) != (DWORD)csCtx->hash .getAsmHash ()))
774+ {
775+ return false ;
776+ }
687777 newSizeX = IGC_GET_FLAG_VALUE (ForceGroupSizeX);
688778 newSizeY = IGC_GET_FLAG_VALUE (ForceGroupSizeY);
689779 }
690780 else if (x * y >= minTGSizeHeuristic && newThreadOccupancy > currentThreadOccupancy)
691781 {
782+ // Heuristic for Threadcombining based on EU Occupancy, if EU occupancy increases with the new
783+ // size then combine threads, otherwise skip it
692784 newSizeX = x;
693785 newSizeY = y;
694786 currentThreadOccupancy = newThreadOccupancy;
@@ -706,12 +798,31 @@ bool ThreadCombining::runOnModule(llvm::Module& M)
706798 return false ;
707799 }
708800
709- IGC_ASSERT (newSizeX <= threadGroupSize_X);
710- IGC_ASSERT (newSizeY <= threadGroupSize_Y);
801+ if ((newSizeX > threadGroupSize_X) ||
802+ (newSizeY > threadGroupSize_Y) ||
803+ ((threadGroupSize_X % newSizeX) != 0 ) ||
804+ ((threadGroupSize_Y % newSizeY) != 0 ))
805+ {
806+ return false ;
807+ }
711808
712809 SetthreadGroupSize (M, builder.getInt32 (newSizeX), ThreadGroupSize_X);
713810 SetthreadGroupSize (M, builder.getInt32 (newSizeY), ThreadGroupSize_Y);
714811
812+ if (IGC_IS_FLAG_ENABLED (EnableForceGroupSize))
813+ {
814+ // Don't perform thread combining, just remap threads as if thread group size hasn't been changed
815+ remapThreads (
816+ M,
817+ newSizeX,
818+ newSizeY,
819+ threadGroupSize_X,
820+ threadGroupSize_Y,
821+ builder);
822+ return true ;
823+ }
824+
825+ // Perform Thread Combining
715826 // Create a new function with function arguments, New threadIDX, threadIDY,
716827 // a bool variable to indicate if it is kernel section before last barrier or after
717828 // last barrier and all the live variables
@@ -759,6 +870,5 @@ bool ThreadCombining::runOnModule(llvm::Module& M)
759870 builder);
760871
761872 context->m_threadCombiningOptDone = true ;
762-
763873 return true ;
764874}
0 commit comments