-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[LV]Fix max safe elements calculations #122148
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[LV]Fix max safe elements calculations #122148
Conversation
Created using spr 1.3.5
|
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-backend-risc-v Author: Alexey Bataev (alexey-bataev) ChangesCurrently the number of MaxSafeElements is calculated as Full diff: https://github.com/llvm/llvm-project/pull/122148.diff 4 Files Affected:
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 47866dac9ad913..4a1cf64c9cd100 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -3943,11 +3943,13 @@ FixedScalableVFPair LoopVectorizationCostModel::computeFeasibleMaxVF(
// It is computed by MaxVF * sizeOf(type) * 8, where type is taken from
// the memory accesses that is most restrictive (involved in the smallest
// dependence distance).
- unsigned MaxSafeElements =
- llvm::bit_floor(Legal->getMaxSafeVectorWidthInBits() / WidestType);
+ unsigned MaxSafeElements = Legal->getMaxSafeVectorWidthInBits() / WidestType;
+ if (Legal->isSafeForAnyVectorWidth())
+ MaxSafeElements = bit_ceil(MaxSafeElements);
+ unsigned MaxSafeElementsPowerOf2 = 1ULL << countr_zero(MaxSafeElements);
+ auto MaxSafeFixedVF = ElementCount::getFixed(MaxSafeElementsPowerOf2);
+ auto MaxSafeScalableVF = getMaxLegalScalableVF(MaxSafeElementsPowerOf2);
- auto MaxSafeFixedVF = ElementCount::getFixed(MaxSafeElements);
- auto MaxSafeScalableVF = getMaxLegalScalableVF(MaxSafeElements);
if (!Legal->isSafeForAnyVectorWidth())
this->MaxSafeElements = MaxSafeElements;
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/riscv-vector-reverse.ll b/llvm/test/Transforms/LoopVectorize/RISCV/riscv-vector-reverse.ll
index 951d833fa941e8..65582cdebc7933 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/riscv-vector-reverse.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/riscv-vector-reverse.ll
@@ -21,7 +21,7 @@ define void @vector_reverse_i64(ptr nocapture noundef writeonly %A, ptr nocaptur
; CHECK-NEXT: LV: Found trip count: 0
; CHECK-NEXT: LV: Found maximum trip count: 4294967295
; CHECK-NEXT: LV: Scalable vectorization is available
-; CHECK-NEXT: LV: The max safe fixed VF is: 67108864.
+; CHECK-NEXT: LV: The max safe fixed VF is: 134217728.
; CHECK-NEXT: LV: The max safe scalable VF is: vscale x 4294967295.
; CHECK-NEXT: LV: Found uniform instruction: %cmp = icmp ugt i64 %indvars.iv, 1
; CHECK-NEXT: LV: Found uniform instruction: %arrayidx = getelementptr inbounds i32, ptr %B, i64 %idxprom
@@ -271,7 +271,7 @@ define void @vector_reverse_f32(ptr nocapture noundef writeonly %A, ptr nocaptur
; CHECK-NEXT: LV: Found trip count: 0
; CHECK-NEXT: LV: Found maximum trip count: 4294967295
; CHECK-NEXT: LV: Scalable vectorization is available
-; CHECK-NEXT: LV: The max safe fixed VF is: 67108864.
+; CHECK-NEXT: LV: The max safe fixed VF is: 134217728.
; CHECK-NEXT: LV: The max safe scalable VF is: vscale x 4294967295.
; CHECK-NEXT: LV: Found uniform instruction: %cmp = icmp ugt i64 %indvars.iv, 1
; CHECK-NEXT: LV: Found uniform instruction: %arrayidx = getelementptr inbounds float, ptr %B, i64 %idxprom
diff --git a/llvm/test/Transforms/LoopVectorize/memdep-fold-tail.ll b/llvm/test/Transforms/LoopVectorize/memdep-fold-tail.ll
index d1ad7e3f4fc0d8..ea592c1e1063ac 100644
--- a/llvm/test/Transforms/LoopVectorize/memdep-fold-tail.ll
+++ b/llvm/test/Transforms/LoopVectorize/memdep-fold-tail.ll
@@ -24,57 +24,9 @@ target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f3
define void @maxvf3() {
; CHECK-LABEL: @maxvf3(
; CHECK-NEXT: entry:
-; CHECK-NEXT: br i1 false, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
-; CHECK: vector.ph:
-; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
-; CHECK: vector.body:
-; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[PRED_STORE_CONTINUE6:%.*]] ]
-; CHECK-NEXT: [[VEC_IND:%.*]] = phi <2 x i32> [ <i32 0, i32 1>, [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[PRED_STORE_CONTINUE6]] ]
-; CHECK-NEXT: [[TMP0:%.*]] = icmp ule <2 x i32> [[VEC_IND]], splat (i32 14)
-; CHECK-NEXT: [[TMP1:%.*]] = extractelement <2 x i1> [[TMP0]], i32 0
-; CHECK-NEXT: br i1 [[TMP1]], label [[PRED_STORE_IF:%.*]], label [[PRED_STORE_CONTINUE:%.*]]
-; CHECK: pred.store.if:
-; CHECK-NEXT: [[TMP2:%.*]] = add i32 [[INDEX]], 0
-; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds [18 x i8], ptr @a, i32 0, i32 [[TMP2]]
-; CHECK-NEXT: store i8 69, ptr [[TMP3]], align 8
-; CHECK-NEXT: br label [[PRED_STORE_CONTINUE]]
-; CHECK: pred.store.continue:
-; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x i1> [[TMP0]], i32 1
-; CHECK-NEXT: br i1 [[TMP4]], label [[PRED_STORE_IF1:%.*]], label [[PRED_STORE_CONTINUE2:%.*]]
-; CHECK: pred.store.if1:
-; CHECK-NEXT: [[TMP5:%.*]] = add i32 [[INDEX]], 1
-; CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds [18 x i8], ptr @a, i32 0, i32 [[TMP5]]
-; CHECK-NEXT: store i8 69, ptr [[TMP6]], align 8
-; CHECK-NEXT: br label [[PRED_STORE_CONTINUE2]]
-; CHECK: pred.store.continue2:
-; CHECK-NEXT: [[TMP7:%.*]] = add nuw nsw <2 x i32> splat (i32 3), [[VEC_IND]]
-; CHECK-NEXT: [[TMP8:%.*]] = extractelement <2 x i1> [[TMP0]], i32 0
-; CHECK-NEXT: br i1 [[TMP8]], label [[PRED_STORE_IF3:%.*]], label [[PRED_STORE_CONTINUE4:%.*]]
-; CHECK: pred.store.if3:
-; CHECK-NEXT: [[TMP9:%.*]] = extractelement <2 x i32> [[TMP7]], i32 0
-; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds [18 x i8], ptr @a, i32 0, i32 [[TMP9]]
-; CHECK-NEXT: store i8 7, ptr [[TMP10]], align 8
-; CHECK-NEXT: br label [[PRED_STORE_CONTINUE4]]
-; CHECK: pred.store.continue4:
-; CHECK-NEXT: [[TMP11:%.*]] = extractelement <2 x i1> [[TMP0]], i32 1
-; CHECK-NEXT: br i1 [[TMP11]], label [[PRED_STORE_IF5:%.*]], label [[PRED_STORE_CONTINUE6]]
-; CHECK: pred.store.if5:
-; CHECK-NEXT: [[TMP12:%.*]] = extractelement <2 x i32> [[TMP7]], i32 1
-; CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [18 x i8], ptr @a, i32 0, i32 [[TMP12]]
-; CHECK-NEXT: store i8 7, ptr [[TMP13]], align 8
-; CHECK-NEXT: br label [[PRED_STORE_CONTINUE6]]
-; CHECK: pred.store.continue6:
-; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 2
-; CHECK-NEXT: [[VEC_IND_NEXT]] = add <2 x i32> [[VEC_IND]], splat (i32 2)
-; CHECK-NEXT: [[TMP14:%.*]] = icmp eq i32 [[INDEX_NEXT]], 16
-; CHECK-NEXT: br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
-; CHECK: middle.block:
-; CHECK-NEXT: br i1 true, label [[FOR_END:%.*]], label [[SCALAR_PH]]
-; CHECK: scalar.ph:
-; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i32 [ 16, [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
; CHECK-NEXT: br label [[FOR_BODY:%.*]]
; CHECK: for.body:
-; CHECK-NEXT: [[J:%.*]] = phi i32 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[J_NEXT:%.*]], [[FOR_BODY]] ]
+; CHECK-NEXT: [[J:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[J_NEXT:%.*]], [[FOR_BODY]] ]
; CHECK-NEXT: [[AJ:%.*]] = getelementptr inbounds [18 x i8], ptr @a, i32 0, i32 [[J]]
; CHECK-NEXT: store i8 69, ptr [[AJ]], align 8
; CHECK-NEXT: [[JP3:%.*]] = add nuw nsw i32 3, [[J]]
@@ -82,7 +34,7 @@ define void @maxvf3() {
; CHECK-NEXT: store i8 7, ptr [[AJP3]], align 8
; CHECK-NEXT: [[J_NEXT]] = add nuw nsw i32 [[J]], 1
; CHECK-NEXT: [[EXITCOND:%.*]] = icmp eq i32 [[J_NEXT]], 15
-; CHECK-NEXT: br i1 [[EXITCOND]], label [[FOR_END]], label [[FOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
+; CHECK-NEXT: br i1 [[EXITCOND]], label [[FOR_END:%.*]], label [[FOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
; CHECK: for.end:
; CHECK-NEXT: ret void
;
diff --git a/llvm/test/Transforms/LoopVectorize/memdep.ll b/llvm/test/Transforms/LoopVectorize/memdep.ll
index b891b4312f18d3..28cf3b61b2554a 100644
--- a/llvm/test/Transforms/LoopVectorize/memdep.ll
+++ b/llvm/test/Transforms/LoopVectorize/memdep.ll
@@ -226,7 +226,7 @@ for.end:
;Check the new calculation of the maximum safe distance in bits which can be vectorized.
;The previous behavior did not take account that the stride was 2.
-;Therefore the maxVF was computed as 8 instead of 4, as the dependence distance here is 6 iterations, given by |N-(N-12)|/2.
+;Therefore the maxVF was computed as 8 instead of 2, as the dependence distance here is 6 iterations, given by |N-(N-12)|/2.
;#define M 32
;#define N 2 * M
@@ -242,7 +242,7 @@ for.end:
;}
; RIGHTVF-LABEL: @pr34283
-; RIGHTVF: <4 x i64>
+; RIGHTVF: <2 x i64>
; WRONGVF-LABLE: @pr34283
; WRONGVF-NOT: <8 x i64>
|
david-arm
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like a reasonable change to me. I just had a couple of comments!
| unsigned MaxSafeElements = | ||
| llvm::bit_floor(Legal->getMaxSafeVectorWidthInBits() / WidestType); | ||
| unsigned MaxSafeElements = Legal->getMaxSafeVectorWidthInBits() / WidestType; | ||
| if (Legal->isSafeForAnyVectorWidth()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like MaxSafeElements should only be different after calling bit_ceil if WidestType isn't a power of 2. If you comment out this code do any tests fail? I'm just wondering if we actually have any test coverage where isSafeForAnyVectorWidth returns true and WidestType is non-power-of-2. If not, it would be nice to add a test I think.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, actually WidestType is always of power of 2. Legal->getMaxSafeVectorWidthInBits() can be non-power-of-2 because of load-store forwarding, if there is non-power-of-2 dependency between loads/stores.
Another problem is that if Legal->isSafeForAnyVectorWidth() is true, then Legal->getMaxSafeVectorWidthInBits() returns UMAX_INT, which is non-power-of (2^64-1), so need to use bit_ceil to avoid some (potential) VF dropping.
| unsigned MaxSafeElements = Legal->getMaxSafeVectorWidthInBits() / WidestType; | ||
| if (Legal->isSafeForAnyVectorWidth()) | ||
| MaxSafeElements = bit_ceil(MaxSafeElements); | ||
| unsigned MaxSafeElementsPowerOf2 = 1ULL << countr_zero(MaxSafeElements); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the case where isSafeForAnyVectorWidth returns true doesn't this simply return the same value as MaxSafeElements? I just wonder if you can rewrite the code as:
unsigned MaxSafeElementsPowerOf2;
if (Legal->isSafeForAnyVectorWidth())
MaxSafeElementsPowerOf2 = bit_ceil(MaxSafeElements);
else
MaxSafeElementsPowerOf2 = 1ULL << countr_zero(MaxSafeElements);
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, but need to save modified MaxSafeElements value in the corresponding data member
Created using spr 1.3.5
|
Ping! |
| define void @maxvf3() { | ||
| ; CHECK-LABEL: @maxvf3( | ||
| ; CHECK-NEXT: entry: | ||
| ; CHECK-NEXT: br i1 false, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks like a regression? The maximum safe VF should be 2 in this case I think, as the dependence isn't classified as preventing load-store forwarding?
|
Dropping after some more investigations, that the existing code is correct |
Currently the number of MaxSafeElements is calculated as
bit_floor(safe-vector-factor). But safe-vector-factor might be
non-power-of-2 value, so bit_floor cannot be used here (the resulting
value may be not precise divisor of safe-vector-factor). Instead,
1<<countr_zero(MaxSafeElements) should be used as safe power-of-2
distance.