[X86] Fix assertion in AVX512 setcc combine due to invalid APInt mask width #155775
… width

The AVX512 setcc combine in X86ISelLowering was calling `APInt::getLowBitsSet` with a mask width (`Len`) that could exceed the bit width of the broadcasted scalar operand (`BroadcastOpVT.getSizeInBits()`), leading to assertion failures.

This patch replaces `Len` with the number of defined (non-undef) elements in the constant pool vector, computed using `UndefElts.popcount()`. It also introduces a named variable `BroadcastOpBitWidth` for clarity.

This ensures the generated mask is valid and avoids crashes when the constant pool contains more elements than the scalar bit width can represent.

Fixes llvm#155762
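The failure mode described above can be sketched outside LLVM. This is a minimal illustration, not the LLVM implementation: `lowBitsSet` is a stand-in for `APInt::getLowBitsSet` restricted to widths of at most 64 bits, and `safeLowMask` is a hypothetical helper that mirrors the bail-out guard added in the diff. The example widths (an 8-bit scalar and a 16-element constant) are taken from the review discussion on this PR.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Stand-in for APInt::getLowBitsSet(numBits, loBitsSet), limited to 64 bits.
// As with the LLVM function, asking for more low bits than the value is wide
// is a precondition violation and trips an assertion.
inline uint64_t lowBitsSet(unsigned numBits, unsigned loBitsSet) {
  assert(numBits >= 1 && numBits <= 64 && "sketch only models up to 64 bits");
  assert(loBitsSet <= numBits && "cannot set more bits than the width");
  if (loBitsSet == 0)
    return 0;
  if (loBitsSet == 64)
    return ~uint64_t{0};
  return (uint64_t{1} << loBitsSet) - 1;
}

// Hypothetical helper mirroring the guard in the patch: return nothing
// (the combine gives up) instead of asserting when the number of defined
// elements exceeds the scalar bit width.
inline std::optional<uint64_t> safeLowMask(unsigned broadcastOpBitWidth,
                                           unsigned numDefinedElts) {
  if (numDefinedElts > broadcastOpBitWidth)
    return std::nullopt;
  return lowBitsSet(broadcastOpBitWidth, numDefinedElts);
}
```

With an 8-bit scalar, `safeLowMask(8, 7)` yields the mask `0x7F`, while `safeLowMask(8, 16)` returns an empty optional rather than reaching the assertion.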
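@llvm/pr-subscribers-backend-x86

Author: Abhishek Kaushik (abhishek-kaushik22)

Changes

The AVX512 setcc combine in X86ISelLowering was calling `APInt::getLowBitsSet` with a mask width (`Len`) that could exceed the bit width of the broadcasted scalar operand (`BroadcastOpVT.getSizeInBits()`), leading to assertion failures.

This patch replaces `Len` with the number of defined (non-undef) elements in the constant pool vector.

This ensures the generated mask is valid and avoids crashes when the constant pool contains more elements than the scalar bit width can represent.

Fixes #155762

Full diff: https://github.com/llvm/llvm-project/pull/155775.diff

2 Files Affected: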
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 19131fbd4102b..2d376a434123a 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -56247,7 +56247,13 @@ static SDValue combineAVX512SetCCToKMOV(EVT VT, SDValue Op0, ISD::CondCode CC,
SDValue Masked = BroadcastOp;
if (N != 0) {
- APInt Mask = APInt::getLowBitsSet(BroadcastOpVT.getSizeInBits(), Len);
+ unsigned BroadcastOpBitWidth = BroadcastOpVT.getSizeInBits();
+ unsigned NumDefinedElts = UndefElts.getBitWidth() - UndefElts.popcount();
+
+ if (NumDefinedElts > BroadcastOpBitWidth)
+ return SDValue();
+
+ APInt Mask = APInt::getLowBitsSet(BroadcastOpBitWidth, NumDefinedElts);
SDValue ShiftedValue = DAG.getNode(ISD::SRL, DL, BroadcastOpVT, BroadcastOp,
DAG.getConstant(N, DL, BroadcastOpVT));
Masked = DAG.getNode(ISD::AND, DL, BroadcastOpVT, ShiftedValue,
diff --git a/llvm/test/CodeGen/X86/kmov.ll b/llvm/test/CodeGen/X86/kmov.ll
index cab810d30cd77..8b1e69a97d545 100644
--- a/llvm/test/CodeGen/X86/kmov.ll
+++ b/llvm/test/CodeGen/X86/kmov.ll
@@ -143,6 +143,57 @@ define <8 x i1> @invert_i8_mask_extract_8(i8 %mask) {
ret <8 x i1> %cmp.45
}
+define <8 x i1> @i8_mask_extract_7(i8 %mask) {
+; X64-AVX512-LABEL: i8_mask_extract_7:
+; X64-AVX512: # %bb.0:
+; X64-AVX512-NEXT: shrb %dil
+; X64-AVX512-NEXT: movzbl %dil, %eax
+; X64-AVX512-NEXT: kmovd %eax, %k0
+; X64-AVX512-NEXT: vpmovm2w %k0, %xmm0
+; X64-AVX512-NEXT: retq
+;
+; X64-KNL-LABEL: i8_mask_extract_7:
+; X64-KNL: # %bb.0:
+; X64-KNL-NEXT: vmovd %edi, %xmm0
+; X64-KNL-NEXT: vpbroadcastb %xmm0, %xmm0
+; X64-KNL-NEXT: vpbroadcastq {{.*#+}} xmm1 = [2,4,8,16,32,64,128,0,2,4,8,16,32,64,128,0]
+; X64-KNL-NEXT: vpand %xmm1, %xmm0, %xmm0
+; X64-KNL-NEXT: vpcmpeqb %xmm1, %xmm0, %xmm0
+; X64-KNL-NEXT: vpmovzxbw {{.*#+}} xmm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero
+; X64-KNL-NEXT: retq
+ %.splatinsert = insertelement <8 x i8> poison, i8 %mask, i64 0
+ %.splat = shufflevector <8 x i8> %.splatinsert, <8 x i8> poison, <8 x i32> zeroinitializer
+ %1 = and <8 x i8> %.splat, <i8 2, i8 4, i8 8, i8 16, i8 32, i8 64, i8 128, i8 poison>
+ %cmp.45 = icmp ne <8 x i8> %1, zeroinitializer
+ ret <8 x i1> %cmp.45
+}
+
+define <8 x i1> @invert_i8_mask_extract_7(i8 %mask) {
+; X64-AVX512-LABEL: invert_i8_mask_extract_7:
+; X64-AVX512: # %bb.0:
+; X64-AVX512-NEXT: shrb %dil
+; X64-AVX512-NEXT: movzbl %dil, %eax
+; X64-AVX512-NEXT: kmovd %eax, %k0
+; X64-AVX512-NEXT: knotb %k0, %k0
+; X64-AVX512-NEXT: vpmovm2w %k0, %xmm0
+; X64-AVX512-NEXT: retq
+;
+; X64-KNL-LABEL: invert_i8_mask_extract_7:
+; X64-KNL: # %bb.0:
+; X64-KNL-NEXT: vmovd %edi, %xmm0
+; X64-KNL-NEXT: vpbroadcastb %xmm0, %xmm0
+; X64-KNL-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; X64-KNL-NEXT: vpxor %xmm1, %xmm1, %xmm1
+; X64-KNL-NEXT: vpcmpeqb %xmm1, %xmm0, %xmm0
+; X64-KNL-NEXT: vpmovzxbw {{.*#+}} xmm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero
+; X64-KNL-NEXT: retq
+ %.splatinsert = insertelement <8 x i8> poison, i8 %mask, i64 0
+ %.splat = shufflevector <8 x i8> %.splatinsert, <8 x i8> poison, <8 x i32> zeroinitializer
+ %1 = and <8 x i8> %.splat, <i8 2, i8 4, i8 8, i8 16, i8 32, i8 64, i8 128, i8 poison>
+ %cmp.45 = icmp eq <8 x i8> %1, zeroinitializer
+ ret <8 x i1> %cmp.45
+}
+
define <4 x i1> @i16_mask_extract_4(i16 %mask) {
; X64-AVX512-LABEL: i16_mask_extract_4:
; X64-AVX512: # %bb.0:
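
The new `i8_mask_extract_7` test relies on the equivalence that the combine exploits: comparing a splatted scalar against the constants 2, 4, ..., 128 lane by lane gives the same result as shifting the scalar right by one and reading its low seven bits. The sketch below checks that equivalence in plain C++. It is an illustration under stated assumptions, not LLVM code: the function names are hypothetical, and the poison lane 7 is simply left out of the model.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Reference semantics of the test IR for lanes 0..6:
// lane i is true when (mask & C[i]) != 0, with C = {2, 4, ..., 128}.
inline std::array<bool, 7> lanesByCompare(uint8_t mask) {
  constexpr std::array<uint8_t, 7> C = {2, 4, 8, 16, 32, 64, 128};
  std::array<bool, 7> out{};
  for (std::size_t i = 0; i < out.size(); ++i)
    out[i] = (mask & C[i]) != 0;
  return out;
}

// Shape of the combined form: shift right by N (here 1, the bit position of
// the first constant), keep as many low bits as there are defined elements
// (here 7), then read one bit per lane, as a mask register would.
inline std::array<bool, 7> lanesByShiftAndMask(uint8_t mask) {
  const unsigned N = 1;
  const unsigned numDefinedElts = 7;
  const unsigned lowMask = (1u << numDefinedElts) - 1; // 0x7F
  const unsigned kbits = (static_cast<unsigned>(mask) >> N) & lowMask;
  std::array<bool, 7> out{};
  for (std::size_t i = 0; i < out.size(); ++i)
    out[i] = ((kbits >> i) & 1u) != 0;
  return out;
}
```

Both functions agree for every 8-bit input, which is why the AVX512 check lines reduce to a shift followed by a `kmovd`.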
if (N != 0) {
- APInt Mask = APInt::getLowBitsSet(BroadcastOpVT.getSizeInBits(), Len);
+ unsigned BroadcastOpBitWidth = BroadcastOpVT.getSizeInBits();
+ unsigned NumDefinedElts = UndefElts.getBitWidth() - UndefElts.popcount();
Should it be UndefElts.getActiveBits?
For the target constant
<i8 2, i8 4, i8 8, i8 16, i8 32, i8 64, i8 -128, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef>
`UndefElts` is 1111 1111 1000 0000. `getActiveBits` returns 16, but we want the number of zeros (the defined elements). Since the zeros will always be contiguous and at the low (trailing) end, I've updated the code to use `countTrailingZeros`.
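The three candidate counts from this thread can be compared on the quoted `UndefElts` value. The following is a standalone sketch that uses `std::bitset<16>` in place of `APInt`; the helper names are hypothetical and only imitate what `getActiveBits`, `getBitWidth() - popcount()`, and `countTrailingZeros` compute.

```cpp
#include <bitset>
#include <cstddef>

constexpr std::size_t kNumElts = 16;
using EltMask = std::bitset<kNumElts>; // bit i set means element i is undef

// Index of the highest set bit plus one (the quantity getActiveBits reports).
inline unsigned activeBits(const EltMask &m) {
  for (std::size_t i = kNumElts; i > 0; --i)
    if (m[i - 1])
      return static_cast<unsigned>(i);
  return 0;
}

// Total width minus the number of set bits: counts every defined element,
// wherever it sits in the vector.
inline unsigned widthMinusPopcount(const EltMask &m) {
  return static_cast<unsigned>(kNumElts - m.count());
}

// Number of consecutive zero bits starting at bit 0: counts only the leading
// run of defined elements.
inline unsigned trailingZeros(const EltMask &m) {
  unsigned n = 0;
  while (n < kNumElts && !m[n])
    ++n;
  return n;
}
```

For `EltMask(0xFF80)`, `activeBits` gives 16 while both `widthMinusPopcount` and `trailingZeros` give 7. The last two differ only when an undef element appears before a defined one, for example `EltMask(0x0004)` gives 15 and 2 respectively.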
LGTM - cheers
/cherry-pick 33b2c26
/pull-request #158206
… width (llvm#155775)

The AVX512 setcc combine in X86ISelLowering was calling `APInt::getLowBitsSet` with a mask width (`Len`) that could exceed the bit width of the broadcasted scalar operand (`BroadcastOpVT.getSizeInBits()`), leading to assertion failures.

This patch replaces `Len` with the number of defined (non-undef) elements in the constant pool vector.

This ensures the generated mask is valid and avoids crashes when the constant pool contains more elements than the scalar bit width can represent.

Fixes llvm#155762

(cherry picked from commit 33b2c26)