@@ -1370,7 +1370,7 @@ void SelectionDAG::init(MachineFunction &NewMF,
13701370 const TargetLibraryInfo *LibraryInfo,
13711371 UniformityInfo *NewUA, ProfileSummaryInfo *PSIin,
13721372 BlockFrequencyInfo *BFIin, MachineModuleInfo &MMIin,
1373- FunctionVarLocs const *VarLocs) {
1373+ FunctionVarLocs const *VarLocs, bool HasDivergency ) {
13741374 MF = &NewMF;
13751375 SDAGISelPass = PassPtr;
13761376 ORE = &NewORE;
@@ -1383,6 +1383,7 @@ void SelectionDAG::init(MachineFunction &NewMF,
13831383 BFI = BFIin;
13841384 MMI = &MMIin;
13851385 FnVarLocs = VarLocs;
1386+ DivergentTarget = HasDivergency;
13861387}
13871388
13881389SelectionDAG::~SelectionDAG() {
@@ -2329,7 +2330,8 @@ SDValue SelectionDAG::getRegister(Register Reg, EVT VT) {
23292330 return SDValue(E, 0);
23302331
23312332 auto *N = newSDNode<RegisterSDNode>(Reg, VTs);
2332- N->SDNodeBits.IsDivergent = TLI->isSDNodeSourceOfDivergence(N, FLI, UA);
2333+ N->SDNodeBits.IsDivergent =
2334+ DivergentTarget && TLI->isSDNodeSourceOfDivergence(N, FLI, UA);
23332335 CSEMap.InsertNode(N, IP);
23342336 InsertNode(N);
23352337 return SDValue(N, 0);
@@ -12067,6 +12069,8 @@ static bool gluePropagatesDivergence(const SDNode *Node) {
1206712069}
1206812070
1206912071bool SelectionDAG::calculateDivergence(SDNode *N) {
12072+ if(!DivergentTarget)
12073+ return false;
1207012074 if (TLI->isSDNodeAlwaysUniform(N)) {
1207112075 assert(!TLI->isSDNodeSourceOfDivergence(N, FLI, UA) &&
1207212076 "Conflicting divergence information!");
@@ -12086,6 +12090,8 @@ bool SelectionDAG::calculateDivergence(SDNode *N) {
1208612090}
1208712091
1208812092void SelectionDAG::updateDivergence(SDNode *N) {
12093+ if (!DivergentTarget)
12094+ return;
1208912095 SmallVector<SDNode *, 16> Worklist(1, N);
1209012096 do {
1209112097 N = Worklist.pop_back_val();
@@ -13633,16 +13639,20 @@ void SelectionDAG::createOperands(SDNode *Node, ArrayRef<SDValue> Vals) {
1363313639 Ops[I].setInitial(Vals[I]);
1363413640 EVT VT = Ops[I].getValueType();
1363513641
13642+ // Take care of the Node's operands iff target has divergence
1363613643 // Skip Chain. It does not carry divergence.
13637- if (VT != MVT::Other &&
13644+ if (DivergentTarget && VT != MVT::Other &&
1363813645 (VT != MVT::Glue || gluePropagatesDivergence(Ops[I].getNode())) &&
1363913646 Ops[I].getNode()->isDivergent()) {
13647+ // Node is going to be divergent if at least one of its operand is
13648+ // divergent, unless it belongs to the "AlwaysUniform" exemptions.
1364013649 IsDivergent = true;
1364113650 }
1364213651 }
1364313652 Node->NumOperands = Vals.size();
1364413653 Node->OperandList = Ops;
13645- if (!TLI->isSDNodeAlwaysUniform(Node)) {
13654+ // Check the divergence of the Node itself.
13655+ if (DivergentTarget && !TLI->isSDNodeAlwaysUniform(Node)) {
1364613656 IsDivergent |= TLI->isSDNodeSourceOfDivergence(Node, FLI, UA);
1364713657 Node->SDNodeBits.IsDivergent = IsDivergent;
1364813658 }
0 commit comments