@@ -2820,6 +2820,27 @@ void FlowGraph::markDivergentBBs()
28202820 // [(p)] while/goto L
28212821 // <joinBB>
28222822 //
2823+ // The presence of loop makes the checking complicated. Before checking the
2824+ // divergence of a loop head, we need to know if the loop has any non-uniform
2825+ // out-of-loop branch, which includes checking the loop's backedge and any
2826+ // out-going branches inside the loop body. For example,
2827+ // L0:
2828+ // B0 if (uniform cond)
2829+ // B1: else
2830+ // B2
2831+ // L1:
2832+ // B3: goto OUT
2833+ // B4 : if (non-uniform cond) goto L1
2834+ // OUT1:
2835+ // B5
2836+ // B6 : if (uniform cond) goto L0
2837+ //
2838+ // OUT: B6
2839+ //
2840+ // Once scanning B0, we don't know whether it is divergent until we find out "goto OUT"
2841+ // is divergent out-of-loop branch. In turn, in order to know "goto OUT" is divergent,
2842+ // we need to check the back branch of loop L1. For this, we will pre-scan the loop.
2843+ //
28232844 int LastJoinBBId;
28242845
28252846 auto pushJoin = [&](G4_BB* joinBB) {
@@ -2879,55 +2900,85 @@ void FlowGraph::markDivergentBBs()
28792900 }
28802901 }
28812902
2882- // Check if the BB referred to by CurrIT needs to update lastJoin and do
2883- // updating if so. The IterEnd is the end iterator of current function
2884- // that CurrIT refers to.
2903+ // Update LastJoin for the backward branch of BB referred to by IT.
2904+ auto updateLastJoinForBwdBrInBB = [&](BB_LIST_ITER& IT)
2905+ {
2906+ G4_BB* predBB = *IT;
2907+ G4_INST* bInst = predBB->back ();
2908+ assert (bInst->isCFInst () &&
2909+ (bInst->opcode () == G4_while || bInst->asCFInst ()->isBackward ()) &&
2910+ " ICE: expected backward branch/while!" );
2911+
2912+ // joinBB of a loop is the BB right after tail BB
2913+ BB_LIST_ITER loopJoinIter = IT;
2914+ ++loopJoinIter;
2915+ if (loopJoinIter == BBs.end ())
2916+ {
2917+ // Loop end is the last BB (no Join BB). This happens in the
2918+ // following cases:
2919+ // 1. For CM, CM's return is in the middle of code! For IGC,
2920+ // this never happen as a return, if present, must be in
2921+ // the last BB.
2922+ // 2. The last segment code is an infinite loop (static infinite
2923+ // loop so that compiler knows it is an infinite loop).
2924+ // In either case, no join is needed.
2925+ return ;
2926+ }
2927+
2928+ // Update LastJoin if the branch itself is divergent.
2929+ // (todo: checking G4_jmpi is redundant as jmpi must have isUniform()
2930+ // return true.)
2931+ if ((!bInst->asCFInst ()->isUniform () && bInst->opcode () != G4_jmpi))
2932+ {
2933+ G4_BB* joinBB = *loopJoinIter;
2934+ pushJoin (joinBB);
2935+ }
2936+ return ;
2937+ };
2938+
2939+ // Update LastJoin for BB referred to by IT. It is called either from normal
2940+ // scan (setting BB's divergent field) or from loop pre-scan (no setting to
2941+ // BB's divergent field). IsPreScan indicates whether it is called from
2942+ // the pre-scan or not.
28852943 //
2886- // 1. update lastJoinBB if needed
2887- // 2. set up divergence for entry of subroutines if divergent
2944+ // The normal scan (IsPreScan is false), this function also update the divergence
2945+ // info for any called subroutines.
28882946 //
2889- auto updateLastJoinForBB = [&](BB_LIST_ITER& CurrIT, BB_LIST_ITER& IterEnd )
2947+ auto updateLastJoinForFwdBrInBB = [&](BB_LIST_ITER& IT, bool IsPreScan )
28902948 {
2891- G4_BB* aBB = *CurrIT ;
2949+ G4_BB* aBB = *IT ;
28922950 if (aBB->size () == 0 )
28932951 {
28942952 return ;
28952953 }
28962954
2955+ // Using isPriorToLastJoin() works for both loop pre-scan and normal scan.
2956+ // (aBB->isDivergent() works for normal scan only.)
2957+ bool isBBDivergent = isPriorToLastJoin (aBB);
2958+
28972959 G4_INST* lastInst = aBB->back ();
2898- if ((lastInst->opcode () == G4_while ||
2899- (lastInst->opcode () == G4_goto && lastInst->asCFInst ()->isBackward ())) &&
2900- (!lastInst->asCFInst ()->isUniform () || aBB->isDivergent ()))
2901- {
2902- if (CurrIT != IterEnd) {
2903- auto NI = CurrIT;
2904- ++NI;
2905- G4_BB* joinBB = *NI;
2906- pushJoin (joinBB);
2907- }
2908- }
2909- else if (((lastInst->opcode () == G4_goto && !lastInst->asCFInst ()->isBackward ()) ||
2960+ if (((lastInst->opcode () == G4_goto && !lastInst->asCFInst ()->isBackward ()) ||
29102961 lastInst->opcode () == G4_break) &&
2911- (!lastInst->asCFInst ()->isUniform () || aBB-> isDivergent () ))
2962+ (!lastInst->asCFInst ()->isUniform () || isBBDivergent ))
29122963 {
29132964 // forward goto/break : the last Succ BB is our target BB
29142965 // For break, it should be the BB right after while inst.
29152966 G4_BB* joinBB = aBB->Succs .back ();
29162967 pushJoin (joinBB);
29172968 }
29182969 else if (lastInst->opcode () == G4_if &&
2919- (!lastInst->asCFInst ()->isUniform () || aBB-> isDivergent () ))
2970+ (!lastInst->asCFInst ()->isUniform () || isBBDivergent ))
29202971 {
29212972 G4_Label* labelInst = lastInst->asCFInst ()->getUip ();
2922- G4_BB* joinBB = findLabelBB (CurrIT, IterEnd , labelInst->getLabel ());
2973+ G4_BB* joinBB = findLabelBB (IT, BBs. end () , labelInst->getLabel ());
29232974 assert (joinBB && " ICE(vISA) : missing endif label!" );
29242975 pushJoin (joinBB);
29252976 }
2926- else if (lastInst->opcode () == G4_call)
2977+ else if (!IsPreScan && lastInst->opcode () == G4_call)
29272978 {
29282979 // If this function is already in divergent branch, the callee
29292980 // must be in a divergent branch!.
2930- if (aBB-> isDivergent () || lastInst->getPredicate () != nullptr )
2981+ if (isBBDivergent || lastInst->getPredicate () != nullptr )
29312982 {
29322983 FuncInfo* calleeFunc = aBB->getCalleeInfo ();
29332984 if (funcInfoIndex.count (calleeFunc))
@@ -2940,6 +2991,8 @@ void FlowGraph::markDivergentBBs()
29402991 return ;
29412992 };
29422993
2994+
2995+ // Now, scan all subroutines (kernels or functions)
29432996 for (int i = 0 ; i < numFuncs; ++i)
29442997 {
29452998 // each function: [IT, IE)
@@ -2977,10 +3030,11 @@ void FlowGraph::markDivergentBBs()
29773030 }
29783031 }
29793032 }
2980- // continue for next func
3033+ // continue for next subroutine
29813034 continue ;
29823035 }
29833036
3037+ // Scaning all BBs of a single subroutine (or kernel, or function).
29843038 LastJoinBBId = -1 ;
29853039 for (auto IT = IS; IT != IE; ++IT)
29863040 {
@@ -2991,38 +3045,23 @@ void FlowGraph::markDivergentBBs()
29913045
29923046 // Handle loop
29933047 // Loop needs to be scanned twice in order to get an accurate marking.
2994- // For example,
2995- // L:
2996- // B1:
2997- // if (p0) goto B2
2998- // ...
2999- // else
3000- // if (p1) goto OUT
3001- // endif
3002- // B2:
3003- // (p2) goto L
3004- // B3:
3005- // OUT:
3006- //
3007- // We don't know whether B1 is divergent until the entire loop body has been
3008- // scanned, so that we know any out-of-loop gotos (goto out and goto L in this
3009- // case). This require scanning the loop twice.
3010- //
3011- for (auto iter = BB->Preds .begin (), iterEnd = BB->Preds .end (); iter != iterEnd; ++iter)
3012- {
3013- G4_BB* predBB = *iter;
3014- if (predBB->getId () < BB->getId ())
3015- continue ;
3048+ // In pre-scan (1st scan), it finds out divergent out-of-loop branch.
3049+ // If found, it updates LastJoin to the target of that out-of-loop branch
3050+ // and restart the normal scan. If not, it restarts the normal scan with
3051+ // the original LastJoin unchanged.
30163052
3017- assert (predBB->size () > 0 && " ICE: missing branch inst!" );
3018- G4_INST* bInst = predBB->back ();
3019- if (bInst->opcode () != G4_goto &&
3020- bInst->opcode () != G4_while &&
3021- bInst->opcode () != G4_jmpi)
3022- {
3023- // not loop
3053+ // BB could be head of several loops, and the following does pre-scan for
3054+ // every one of those loops.
3055+ for (auto PI0 = BB->Preds .begin (), PI0E = BB->Preds .end (); PI0 != PI0E; ++PI0)
3056+ {
3057+ G4_BB* predBB = *PI0;
3058+ if (!isBackwardBranch (predBB, BB)) {
30243059 continue ;
30253060 }
3061+ assert (BB->getId () <= predBB->getId () && " Branch incorrectly set to be backward!" );
3062+
3063+ BB_LIST_ITER LoopITEnd = std::find (BBs.begin (), BBs.end (), predBB);
3064+ updateLastJoinForBwdBrInBB (LoopITEnd);
30263065
30273066 // If lastJoin is already after loop end, no need to scan loop twice
30283067 // as all BBs in the loop must be divergent
@@ -3031,82 +3070,87 @@ void FlowGraph::markDivergentBBs()
30313070 continue ;
30323071 }
30333072
3034- BB_LIST_ITER LoopIterEnd = std::find (BBs.begin (), BBs.end (), predBB);
3035-
3036- // joinBB of a loop is the BB right after backward-goto/while
3037- BB_LIST_ITER loopJoinIter = LoopIterEnd;
3038- ++loopJoinIter;
3039- if (loopJoinIter == BBs.end ())
3040- {
3041- // Loop end is the last BB (no Join BB). This happens for CM
3042- // in which CM's return is in the middle of code! Note that
3043- // IGC's return is the last BB always.
3044- if (builder->kernel .getOptions ()->getTarget () != VISA_CM)
3045- {
3046- // IGC: loop end should not be the last BB as the last
3047- // BB must be the return.
3048- assert (false && " ICE: return must be the last BB!" );
3049- }
3050- continue ;
3051- }
3052-
3053- // If backward goto/while is divergent, update lastJoin with
3054- // loop's join BB. No need to pre-scan loop!
3055- G4_BB* joinBB = *loopJoinIter;
3056- if (!bInst->asCFInst ()->isUniform () &&
3057- bInst->opcode () != G4_jmpi)
3058- {
3059- pushJoin (joinBB);
3060- continue ;
3061- }
3062-
3063- // pre-scan loop to find any out-of-loop branch, set join if found
3073+ // pre-scan loop (BB, predBB)
3074+ //
3075+ // Observation:
3076+ // pre-scan loop once and as a BB is scanned. Each backward
3077+ // branch to this BB and a forward branch from this BB are processed.
3078+ // Doing so finds out any divergent out-of-loop branch iff the
3079+ // loop has one. In addition, if a loop has more than one divergent
3080+ // out-of-loop branches, using any of those branches would get us
3081+ // precise divergence during the normal scan.
30643082 //
3065- // LastJoinBBId will be updated iff there is a branch out of loop.
3066- // For example,
3083+ // LastJoinBBId will be updated iff there is an out-of-loop branch that is
3084+ // is also divergent. For example,
3085+ // a) "goto OUT" is out-of-loop branch, but not divergent. Thus, LastJoinBB
3086+ // will not be updated.
30673087 //
30683088 // L :
30693089 // B0
3070- // (p ) goto OUT; // uniform
3071- // if // divergent
3090+ // if (uniform cond ) goto OUT;
3091+ // if (non-uniform cond)
30723092 // B1
30733093 // else
30743094 // B2
30753095 // B3
30763096 // (p) jmpi L
30773097 // OUT:
3078- // Assume LastJoinBBId = -1 (no join) right before this loop, pre-scanning
3079- // loop will set LastJoinBBId = B3, since it is not out of the loop, we will
3080- // need to reset it to -1. Otherwise, normal scan of loop will make B0 as
3081- // divergent due to LastJoinBBId = B3.
30823098 //
3083- // The reason for updating LastJoinBBId during pre-scanning is to check if
3084- // out-of-loop goto is uniform or not. The above "goto OUT" is uniform, thus
3085- // it does not make B0 divergent. Without updating LastJoinBBId, this goto
3086- // will be conservatively treated as divergent goto.
3099+ // b) "goto OUT" is out-of-loop branch and divergent. Thus, update LastJoinBB.
3100+ // Note that "goto OUT" is divergent because loop L1 is divergent, which
3101+ // makes every BB in L1 divergent. And any branch from a divergent BB
3102+ // must be divergent.
3103+ //
3104+ // L :
3105+ // B0
3106+ // L1:
3107+ // B1
3108+ // if (uniform cond) goto OUT;
3109+ // if (cond)
3110+ // B2
3111+ // else
3112+ // B3
3113+ // if (non-uniform cond) goto L1
3114+ // B4
3115+ // (uniform cond) while L
3116+ // OUT:
3117+ //
30873118 int orig_LastJoinBBId = LastJoinBBId;
3088- for (auto LoopIter = IT; LoopIter != LoopIterEnd ; ++LoopIter )
3119+ for (auto LoopIT = IT; LoopIT != LoopITEnd ; ++LoopIT )
30893120 {
3090- updateLastJoinForBB (LoopIter, IE);
3091- if (isPriorToLastJoin (predBB))
3092- { // Once found, no need to pre-scan anymore.
3093- break ;
3121+ if (LoopIT != IT)
3122+ {
3123+ // Check loops that are fully inside the current loop.
3124+ G4_BB* H = *LoopIT;
3125+ for (auto PI1 = H->Preds .begin (), PI1E = H->Preds .end (); PI1 != PI1E; ++PI1)
3126+ {
3127+ G4_BB* T = *PI1;
3128+ if (!isBackwardBranch (T, H)) {
3129+ continue ;
3130+ }
3131+ assert (H->getId () <= T->getId () && " Branch incorrectly set to be backward!" );
3132+ BB_LIST_ITER TIter = std::find (BBs.begin (), BBs.end (), T);
3133+ updateLastJoinForBwdBrInBB (TIter);
3134+ }
30943135 }
3136+ updateLastJoinForFwdBrInBB (LoopIT, true );
30953137 }
3096- // If no branch out of loop, restore the original LastJoinBBId
3138+
3139+ // After scan, if no branch out of loop, restore the original LastJoinBBId
30973140 if (!isPriorToLastJoin (predBB))
3098- {
3141+ { // case a) above.
30993142 LastJoinBBId = orig_LastJoinBBId;
31003143 }
31013144 }
31023145
3146+ // normal scan of BB
31033147 if (isPriorToLastJoin (BB)) {
31043148 BB->setDivergent (true );
31053149 // set InSIMDFlow as well, will merge these two fields gradually
31063150 BB->setInSimdFlow (true );
31073151 }
31083152
3109- updateLastJoinForBB (IT, IE );
3153+ updateLastJoinForFwdBrInBB (IT, false );
31103154 }
31113155 }
31123156 return ;
0 commit comments