Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,10 @@ class VPIRFlags {
AllFlags = Other.AllFlags;
}

/// Only keep flags also present in \p Other. \p Other must have the same
/// OpType as the current object.
void intersectFlags(const VPIRFlags &Other);

/// Drop all poison-generating flags.
void dropPoisonGeneratingFlags() {
// NOTE: This needs to be kept in-sync with
Expand Down
36 changes: 36 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,42 @@ void VPPartialReductionRecipe::print(raw_ostream &O, const Twine &Indent,
}
#endif

void VPIRFlags::intersectFlags(const VPIRFlags &Other) {
assert(OpType == Other.OpType && "OpType must match");
switch (OpType) {
case OperationType::OverflowingBinOp:
WrapFlags.HasNUW &= Other.WrapFlags.HasNUW;
WrapFlags.HasNSW &= Other.WrapFlags.HasNSW;
break;
case OperationType::Trunc:
TruncFlags.HasNUW &= Other.TruncFlags.HasNUW;
TruncFlags.HasNSW &= Other.TruncFlags.HasNSW;
break;
case OperationType::DisjointOp:
DisjointFlags.IsDisjoint &= Other.DisjointFlags.IsDisjoint;
break;
case OperationType::PossiblyExactOp:
ExactFlags.IsExact = Other.ExactFlags.IsExact;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ExactFlags.IsExact = Other.ExactFlags.IsExact;
ExactFlags.IsExact &= Other.ExactFlags.IsExact;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be fixed, thanks. Also added a few missing tests

break;
case OperationType::GEPOp:
GEPFlags &= Other.GEPFlags;
break;
case OperationType::FPMathOp:
FMFs.NoNaNs &= Other.FMFs.NoNaNs;
FMFs.NoInfs &= Other.FMFs.NoInfs;
break;
case OperationType::NonNegOp:
NonNegFlags.NonNeg &= Other.NonNegFlags.NonNeg;
break;
case OperationType::Cmp:
assert(CmpPredicate == Other.CmpPredicate && "Cannot drop CmpPredicate");
break;
case OperationType::Other:
assert(AllFlags == Other.AllFlags && "Cannot drop other flags");
break;
}
}

FastMathFlags VPIRFlags::getFastMathFlags() const {
assert(OpType == OperationType::FPMathOp &&
"recipe doesn't have fast math flags");
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2042,9 +2042,9 @@ void VPlanTransforms::cse(VPlan &Plan) {
// V must dominate Def for a valid replacement.
if (!VPDT.dominates(V->getParent(), VPBB))
continue;
// Drop poison-generating flags when reusing a value.
// Only keep flags present on both V and Def.
if (auto *RFlags = dyn_cast<VPRecipeWithIRFlags>(V))
RFlags->dropPoisonGeneratingFlags();
RFlags->intersectFlags(*cast<VPRecipeWithIRFlags>(Def));
Def->replaceAllUsesWith(V);
continue;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ define dso_local void @test(ptr %Arr, i32 signext %Len) {
; CHECK: vector.body:
; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[TMP1:%.*]] = sext i32 [[INDEX]] to i64
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr i32, ptr [[ARR:%.*]], i64 [[TMP1]]
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[ARR:%.*]], i64 [[TMP1]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP2]], align 4
; CHECK-NEXT: [[TMP4:%.*]] = call <4 x i32> @llvm.bswap.v4i32(<4 x i32> [[WIDE_LOAD]])
; CHECK-NEXT: store <4 x i32> [[TMP4]], ptr [[TMP2]], align 4
Expand Down
8 changes: 4 additions & 4 deletions llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ define void @_Z3fn1v() #0 {
; CHECK-NEXT: [[TMP32:%.*]] = add nsw <16 x i64> [[TMP30]], [[VEC_IND37]]
; CHECK-NEXT: [[TMP33:%.*]] = getelementptr inbounds [10 x i32], <16 x ptr> [[TMP31]], <16 x i64> [[TMP32]], i64 0
; CHECK-NEXT: call void @llvm.masked.scatter.v16i32.v16p0(<16 x i32> splat (i32 8), <16 x ptr> [[TMP33]], i32 16, <16 x i1> [[TMP34]])
; CHECK-NEXT: [[TMP49:%.*]] = or <16 x i64> [[VEC_IND37]], splat (i64 1)
; CHECK-NEXT: [[TMP36:%.*]] = add <16 x i64> [[TMP30]], [[TMP49]]
; CHECK-NEXT: [[TMP49:%.*]] = or disjoint <16 x i64> [[VEC_IND37]], splat (i64 1)
; CHECK-NEXT: [[TMP36:%.*]] = add nsw <16 x i64> [[TMP30]], [[TMP49]]
; CHECK-NEXT: [[TMP37:%.*]] = getelementptr inbounds [10 x i32], <16 x ptr> [[TMP31]], <16 x i64> [[TMP36]], i64 0
; CHECK-NEXT: call void @llvm.masked.scatter.v16i32.v16p0(<16 x i32> splat (i32 8), <16 x ptr> [[TMP37]], i32 8, <16 x i1> [[TMP34]])
; CHECK-NEXT: call void @llvm.masked.scatter.v16i32.v16p0(<16 x i32> splat (i32 7), <16 x ptr> [[TMP33]], i32 16, <16 x i1> [[BROADCAST_SPLAT]])
Expand Down Expand Up @@ -191,8 +191,8 @@ define void @_Z3fn1v() #0 {
; CHECK-NEXT: [[TMP46:%.*]] = add nsw <8 x i64> [[TMP44]], [[VEC_IND70]]
; CHECK-NEXT: [[TMP47:%.*]] = getelementptr inbounds [10 x i32], <8 x ptr> [[TMP45]], <8 x i64> [[TMP46]], i64 0
; CHECK-NEXT: call void @llvm.masked.scatter.v8i32.v8p0(<8 x i32> splat (i32 8), <8 x ptr> [[TMP47]], i32 16, <8 x i1> [[TMP48]])
; CHECK-NEXT: [[TMP54:%.*]] = or <8 x i64> [[VEC_IND70]], splat (i64 1)
; CHECK-NEXT: [[TMP50:%.*]] = add <8 x i64> [[TMP44]], [[TMP54]]
; CHECK-NEXT: [[TMP54:%.*]] = or disjoint <8 x i64> [[VEC_IND70]], splat (i64 1)
; CHECK-NEXT: [[TMP50:%.*]] = add nsw <8 x i64> [[TMP44]], [[TMP54]]
; CHECK-NEXT: [[TMP51:%.*]] = getelementptr inbounds [10 x i32], <8 x ptr> [[TMP45]], <8 x i64> [[TMP50]], i64 0
; CHECK-NEXT: call void @llvm.masked.scatter.v8i32.v8p0(<8 x i32> splat (i32 8), <8 x ptr> [[TMP51]], i32 8, <8 x i1> [[TMP48]])
; CHECK-NEXT: call void @llvm.masked.scatter.v8i32.v8p0(<8 x i32> splat (i32 7), <8 x ptr> [[TMP47]], i32 16, <8 x i1> [[BROADCAST_SPLAT73]])
Expand Down
2 changes: 1 addition & 1 deletion llvm/test/Transforms/LoopVectorize/flags.ll
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ define void @gep_with_shared_nusw_and_others(i64 %n, ptr %A) {
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
; CHECK: [[VECTOR_BODY]]:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr float, ptr [[A]], i64 [[INDEX]]
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr nusw float, ptr [[A]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x float>, ptr [[TMP1]], align 4
; CHECK-NEXT: store <4 x float> [[WIDE_LOAD]], ptr [[TMP1]], align 4
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
Expand Down