diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp index 7ebcc219efc15..cf98b4b012e32 100644 --- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -1841,7 +1841,7 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { // is that enough for *all* side effects? bool HasThreadLocalSideEffects = false; for (BasicBlock *BB : L->blocks()) - for (auto &I : *BB) + for (auto &I : *BB) { // TODO:isGuaranteedToTransfer if (I.mayHaveSideEffects()) { if (!LoopPredicationTraps) @@ -1859,6 +1859,18 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { } } + // Skip if the loop has tokens referenced outside the loop to avoid + // changing convergence behavior. + if (I.getType()->isTokenTy()) { + for (User *U : I.users()) { + Instruction *UserInst = dyn_cast<Instruction>(U); + if (UserInst && !L->contains(UserInst)) { + return false; + } + } + } + } + bool Changed = false; // Finally, do the actual predication for all predicatable blocks. A couple // of notes here: diff --git a/llvm/test/Transforms/IndVarSimplify/skip-predication-convergence.ll b/llvm/test/Transforms/IndVarSimplify/skip-predication-convergence.ll new file mode 100644 index 0000000000000..59b84a3c082c2 --- /dev/null +++ b/llvm/test/Transforms/IndVarSimplify/skip-predication-convergence.ll @@ -0,0 +1,64 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -passes=indvars -indvars-predicate-loops=1 -S | FileCheck %s + +; Loop with body using loop convergence token should be skipped by IndVarSimplify.
+ +declare token @llvm.experimental.convergence.entry() #0 + +define void @loop(i32 %tid, ptr %array) #0 { +; CHECK-LABEL: @loop( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = tail call token @llvm.experimental.convergence.entry() +; CHECK-NEXT: br label [[FOR_COND_I:%.*]] +; CHECK: for.cond.i: +; CHECK-NEXT: [[I_0_I:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[INC_I:%.*]], [[FOR_BODY_I:%.*]] ] +; CHECK-NEXT: [[TMP1:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TMP0]]) ] +; CHECK-NEXT: [[CMP_I:%.*]] = icmp ult i32 [[I_0_I]], 8 +; CHECK-NEXT: br i1 [[CMP_I]], label [[FOR_BODY_I]], label [[EXIT_LOOPEXIT:%.*]] +; CHECK: for.body.i: +; CHECK-NEXT: [[CMP1_I:%.*]] = icmp eq i32 [[I_0_I]], [[TID:%.*]] +; CHECK-NEXT: [[INC_I]] = add nuw nsw i32 [[I_0_I]], 1 +; CHECK-NEXT: br i1 [[CMP1_I]], label [[IF_THEN_I:%.*]], label [[FOR_COND_I]] +; CHECK: exit.loopexit: +; CHECK-NEXT: br label [[EXIT:%.*]] +; CHECK: if.then.i: +; CHECK-NEXT: [[HLSL_WAVE_ACTIVE_MAX2_I:%.*]] = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 [[TID]]) [ "convergencectrl"(token [[TMP1]]) ] +; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[ARRAY:%.*]], i32 [[TID]] +; CHECK-NEXT: store i32 [[HLSL_WAVE_ACTIVE_MAX2_I]], ptr [[TMP2]], align 4 +; CHECK-NEXT: br label [[EXIT]] +; CHECK: exit: +; CHECK-NEXT: ret void +; +entry: + %0 = tail call token @llvm.experimental.convergence.entry() + br label %for.cond.i + +for.cond.i: + %i.0.i = phi i32 [ 0, %entry ], [ %inc.i, %for.body.i ] + %2 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ] + %cmp.i = icmp ult i32 %i.0.i, 8 + br i1 %cmp.i, label %for.body.i, label %exit.loopexit + +for.body.i: + %cmp1.i = icmp eq i32 %i.0.i, %tid + %inc.i = add nuw nsw i32 %i.0.i, 1 + br i1 %cmp1.i, label %if.then.i, label %for.cond.i + +exit.loopexit: + br label %exit + +if.then.i: + %hlsl.wave.active.max2.i = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 %tid) [ 
"convergencectrl"(token %2) ] + %3 = getelementptr inbounds i32, ptr %array, i32 %tid + store i32 %hlsl.wave.active.max2.i, ptr %3, align 4 + br label %exit + +exit: + ret void +} + +declare token @llvm.experimental.convergence.loop() #0 + +declare i32 @llvm.spv.wave.reduce.umax.i32(i32) #0 + +attributes #0 = { convergent } diff --git a/llvm/test/Transforms/IndVarSimplify/skip-predication-nested-convergence.ll b/llvm/test/Transforms/IndVarSimplify/skip-predication-nested-convergence.ll new file mode 100644 index 0000000000000..0944205839aca --- /dev/null +++ b/llvm/test/Transforms/IndVarSimplify/skip-predication-nested-convergence.ll @@ -0,0 +1,95 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -passes=indvars -indvars-predicate-loops=1 -S | FileCheck %s + +; Nested loops with body using loop convergence token should be skipped by IndVarSimplify. + +declare token @llvm.experimental.convergence.entry() #0 + +define void @nested(i32 %tidx, i32 %tidy, ptr %array) #0 { +; CHECK-LABEL: @nested( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = tail call token @llvm.experimental.convergence.entry() +; CHECK-NEXT: [[MUL_I:%.*]] = shl nsw i32 [[TIDX:%.*]], 3 +; CHECK-NEXT: [[ADD_I:%.*]] = add nsw i32 [[MUL_I]], [[TIDY:%.*]] +; CHECK-NEXT: br label [[FOR_COND_I:%.*]] +; CHECK: for.cond.i: +; CHECK-NEXT: [[I_0_I:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[INC10_I:%.*]], [[CLEANUP_I:%.*]] ] +; CHECK-NEXT: [[TMP1:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TMP0]]) ] +; CHECK-NEXT: [[CMP_I:%.*]] = icmp ult i32 [[I_0_I]], 8 +; CHECK-NEXT: br i1 [[CMP_I]], label [[FOR_COND1_I_PREHEADER:%.*]], label [[EXIT:%.*]] +; CHECK: for.cond1.i.preheader: +; CHECK-NEXT: [[CMP5_I:%.*]] = icmp eq i32 [[I_0_I]], [[TIDX]] +; CHECK-NEXT: br label [[FOR_COND1_I:%.*]] +; CHECK: for.cond1.i: +; CHECK-NEXT: [[J_0_I:%.*]] = phi i32 [ [[INC_I:%.*]], [[FOR_BODY4_I:%.*]] ], [ 0, [[FOR_COND1_I_PREHEADER]] ] +; 
CHECK-NEXT: [[TMP2:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TMP1]]) ] +; CHECK-NEXT: [[CMP2_I:%.*]] = icmp ult i32 [[J_0_I]], 8 +; CHECK-NEXT: br i1 [[CMP2_I]], label [[FOR_BODY4_I]], label [[CLEANUP_I_LOOPEXIT:%.*]] +; CHECK: for.body4.i: +; CHECK-NEXT: [[CMP6_I:%.*]] = icmp eq i32 [[J_0_I]], [[TIDY]] +; CHECK-NEXT: [[OR_COND:%.*]] = select i1 [[CMP5_I]], i1 [[CMP6_I]], i1 false +; CHECK-NEXT: [[INC_I]] = add nuw nsw i32 [[J_0_I]], 1 +; CHECK-NEXT: br i1 [[OR_COND]], label [[IF_THEN_I:%.*]], label [[FOR_COND1_I]] +; CHECK: cleanup.i.loopexit: +; CHECK-NEXT: br label [[CLEANUP_I]] +; CHECK: if.then.i: +; CHECK-NEXT: [[HLSL_WAVE_ACTIVE_MAX7_I:%.*]] = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 [[ADD_I]]) [ "convergencectrl"(token [[TMP2]]) ] +; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i32, ptr [[ARRAY:%.*]], i32 [[ADD_I]] +; CHECK-NEXT: store i32 [[HLSL_WAVE_ACTIVE_MAX7_I]], ptr [[TMP3]], align 4 +; CHECK-NEXT: br label [[CLEANUP_I]] +; CHECK: cleanup.i: +; CHECK-NEXT: [[INC10_I]] = add nuw nsw i32 [[I_0_I]], 1 +; CHECK-NEXT: br label [[FOR_COND_I]] +; CHECK: exit: +; CHECK-NEXT: ret void +; +entry: + %0 = tail call token @llvm.experimental.convergence.entry() + %mul.i = shl nsw i32 %tidx, 3 + %add.i = add nsw i32 %mul.i, %tidy + br label %for.cond.i + +for.cond.i: + %i.0.i = phi i32 [ 0, %entry ], [ %inc10.i, %cleanup.i ] + %2 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ] + %cmp.i = icmp ult i32 %i.0.i, 8 + br i1 %cmp.i, label %for.cond1.i.preheader, label %exit + +for.cond1.i.preheader: + %cmp5.i = icmp eq i32 %i.0.i, %tidx + br label %for.cond1.i + +for.cond1.i: + %j.0.i = phi i32 [ %inc.i, %for.body4.i ], [ 0, %for.cond1.i.preheader ] + %3 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %2) ] + %cmp2.i = icmp ult i32 %j.0.i, 8 + br i1 %cmp2.i, label %for.body4.i, label %cleanup.i.loopexit + +for.body4.i: + %cmp6.i = icmp eq i32 %j.0.i, 
%tidy + %or.cond = select i1 %cmp5.i, i1 %cmp6.i, i1 false + %inc.i = add nsw i32 %j.0.i, 1 + br i1 %or.cond, label %if.then.i, label %for.cond1.i + +cleanup.i.loopexit: + br label %cleanup.i + +if.then.i: + %hlsl.wave.active.max7.i = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 %add.i) [ "convergencectrl"(token %3) ] + %4 = getelementptr inbounds i32, ptr %array, i32 %add.i + store i32 %hlsl.wave.active.max7.i, ptr %4, align 4 + br label %cleanup.i + +cleanup.i: + %inc10.i = add nsw i32 %i.0.i, 1 + br label %for.cond.i + +exit: + ret void +} + +declare token @llvm.experimental.convergence.loop() #0 + +declare i32 @llvm.spv.wave.reduce.umax.i32(i32) #0 + +attributes #0 = { convergent }