[LLVM][CodeGen][SVE] Improve lowering of fixed length masked mem ops. #134402
Conversation
Converting fixed-length masks, as used by MLOAD, to scalable vectors is done by comparing the mask against zero. When the mask is itself the result of a compare, we can instead promote the compare's operands and regenerate the original compare on scalable types. At worst this shortens the dependency chain, and in most cases it removes the need for multiple compares.
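A minimal IR sketch of the kind of pattern this targets (hypothetical, modelled on the existing VLS tests; the function name, vector width and attributes are illustrative): a fixed-length masked load whose mask is produced directly by a vector compare. Before this change the mask is converted to a predicate by comparing its vector form against zero, costing an extra compare; with this change the original icmp is regenerated as a single predicated SVE compare.

; Illustrative only: with +sve and vscale_range(2,0), <8 x i32> is wider than
; NEON and is lowered via the SVE fixed-length (VLS) path.
define void @masked_load_icmp_mask(ptr %ap, ptr %bp) vscale_range(2,0) #0 {
  %a = load <8 x i32>, ptr %ap
  %b = load <8 x i32>, ptr %bp
  ; The mask feeding the masked load comes straight from a compare.
  %mask = icmp eq <8 x i32> %a, %b
  %vals = call <8 x i32> @llvm.masked.load.v8i32.p0(ptr %ap, i32 4,
                                                    <8 x i1> %mask,
                                                    <8 x i32> zeroinitializer)
  store <8 x i32> %vals, ptr %bp
  ret void
}

declare <8 x i32> @llvm.masked.load.v8i32.p0(ptr, i32 immarg, <8 x i1>, <8 x i32>)

attributes #0 = { "target-features"="+sve" }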
@llvm/pr-subscribers-backend-aarch64

Author: Paul Walker (paulwalker-arm)

Changes

Converting fixed-length masks, as used by MLOAD, to scalable vectors is done by comparing the mask against zero. When the mask is itself the result of a compare, we can instead promote the compare's operands and regenerate the original compare on scalable types. At worst this shortens the dependency chain, and in most cases it removes the need for multiple compares.

Full diff: https://github.com/llvm/llvm-project/pull/134402.diff

4 Files Affected:
- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
- llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
- llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll
- llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a1ba3922996a1..57a950cfc702a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20190,6 +20190,12 @@ performInsertSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
EVT VecVT = Vec.getValueType();
EVT SubVT = SubVec.getValueType();
+ // Promote fixed length vector zeros.
+ if (VecVT.isScalableVector() && SubVT.isFixedLengthVector() &&
+ Vec.isUndef() && isZerosVector(SubVec.getNode()))
+ return VecVT.isInteger() ? DAG.getConstant(0, DL, VecVT)
+ : DAG.getConstantFP(0, DL, VecVT);
+
// Only do this for legal fixed vector types.
if (!VecVT.isFixedLengthVector() ||
!DAG.getTargetLoweringInfo().isTypeLegal(VecVT) ||
@@ -28697,17 +28703,36 @@ static SDValue convertFixedMaskToScalableVector(SDValue Mask,
SDLoc DL(Mask);
EVT InVT = Mask.getValueType();
EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
-
- auto Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
+ SDValue Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
if (ISD::isBuildVectorAllOnes(Mask.getNode()))
return Pg;
- auto Op1 = convertToScalableVector(DAG, ContainerVT, Mask);
- auto Op2 = DAG.getConstant(0, DL, ContainerVT);
+ bool InvertCond = false;
+ if (isBitwiseNot(Mask)) {
+ InvertCond = true;
+ Mask = Mask.getOperand(0);
+ }
+
+ SDValue Op1, Op2;
+ ISD::CondCode CC;
+
+ // When Mask is the result of a SETCC, it's better to regenerate the compare.
+ if (Mask.getOpcode() == ISD::SETCC) {
+ Op1 = convertToScalableVector(DAG, ContainerVT, Mask.getOperand(0));
+ Op2 = convertToScalableVector(DAG, ContainerVT, Mask.getOperand(1));
+ CC = cast<CondCodeSDNode>(Mask.getOperand(2))->get();
+ } else {
+ Op1 = convertToScalableVector(DAG, ContainerVT, Mask);
+ Op2 = DAG.getConstant(0, DL, ContainerVT);
+ CC = ISD::SETNE;
+ }
+
+ if (InvertCond)
+ CC = getSetCCInverse(CC, Op1.getValueType());
return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, Pg.getValueType(),
- {Pg, Op1, Op2, DAG.getCondCode(ISD::SETNE)});
+ {Pg, Op1, Op2, DAG.getCondCode(CC)});
}
// Convert all fixed length vector loads larger than NEON to masked_loads.
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
index a50d0dc37eaf6..093e6cd9328c8 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
@@ -460,10 +460,9 @@ define void @masked_gather_v1i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
define void @masked_gather_v2i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
; CHECK-LABEL: masked_gather_v2i64:
; CHECK: // %bb.0:
-; CHECK-NEXT: ldr q0, [x0]
; CHECK-NEXT: ptrue p0.d, vl2
-; CHECK-NEXT: cmeq v0.2d, v0.2d, #0
-; CHECK-NEXT: cmpne p0.d, p0/z, z0.d, #0
+; CHECK-NEXT: ldr q0, [x0]
+; CHECK-NEXT: cmpeq p0.d, p0/z, z0.d, #0
; CHECK-NEXT: ldr q0, [x1]
; CHECK-NEXT: ld1d { z0.d }, p0/z, [z0.d]
; CHECK-NEXT: str q0, [x0]
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll
index 6513b01d00922..34dc0bb5ef2d2 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll
@@ -401,11 +401,10 @@ define void @masked_load_sext_v32i8i16(ptr %ap, ptr %bp, ptr %c) #0 {
define void @masked_load_sext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
; VBITS_GE_256-LABEL: masked_load_sext_v16i8i32:
; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT: ldr q0, [x1]
; VBITS_GE_256-NEXT: ptrue p0.b, vl16
+; VBITS_GE_256-NEXT: ldr q0, [x1]
; VBITS_GE_256-NEXT: mov x8, #8 // =0x8
-; VBITS_GE_256-NEXT: cmeq v0.16b, v0.16b, #0
-; VBITS_GE_256-NEXT: cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT: cmpeq p0.b, p0/z, z0.b, #0
; VBITS_GE_256-NEXT: ld1b { z0.b }, p0/z, [x0]
; VBITS_GE_256-NEXT: ptrue p0.s, vl8
; VBITS_GE_256-NEXT: ext v1.16b, v0.16b, v0.16b, #8
@@ -436,11 +435,10 @@ define void @masked_load_sext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
define void @masked_load_sext_v8i8i64(ptr %ap, ptr %bp, ptr %c) #0 {
; VBITS_GE_256-LABEL: masked_load_sext_v8i8i64:
; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT: ldr d0, [x1]
; VBITS_GE_256-NEXT: ptrue p0.b, vl8
+; VBITS_GE_256-NEXT: ldr d0, [x1]
; VBITS_GE_256-NEXT: mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT: cmeq v0.8b, v0.8b, #0
-; VBITS_GE_256-NEXT: cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT: cmpeq p0.b, p0/z, z0.b, #0
; VBITS_GE_256-NEXT: ld1b { z0.b }, p0/z, [x0]
; VBITS_GE_256-NEXT: ptrue p0.d, vl4
; VBITS_GE_256-NEXT: sshll v0.8h, v0.8b, #0
@@ -504,11 +502,10 @@ define void @masked_load_sext_v16i16i32(ptr %ap, ptr %bp, ptr %c) #0 {
define void @masked_load_sext_v8i16i64(ptr %ap, ptr %bp, ptr %c) #0 {
; VBITS_GE_256-LABEL: masked_load_sext_v8i16i64:
; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT: ldr q0, [x1]
; VBITS_GE_256-NEXT: ptrue p0.h, vl8
+; VBITS_GE_256-NEXT: ldr q0, [x1]
; VBITS_GE_256-NEXT: mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT: cmeq v0.8h, v0.8h, #0
-; VBITS_GE_256-NEXT: cmpne p0.h, p0/z, z0.h, #0
+; VBITS_GE_256-NEXT: cmpeq p0.h, p0/z, z0.h, #0
; VBITS_GE_256-NEXT: ld1h { z0.h }, p0/z, [x0]
; VBITS_GE_256-NEXT: ptrue p0.d, vl4
; VBITS_GE_256-NEXT: ext v1.16b, v0.16b, v0.16b, #8
@@ -603,11 +600,10 @@ define void @masked_load_zext_v32i8i16(ptr %ap, ptr %bp, ptr %c) #0 {
define void @masked_load_zext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
; VBITS_GE_256-LABEL: masked_load_zext_v16i8i32:
; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT: ldr q0, [x1]
; VBITS_GE_256-NEXT: ptrue p0.b, vl16
+; VBITS_GE_256-NEXT: ldr q0, [x1]
; VBITS_GE_256-NEXT: mov x8, #8 // =0x8
-; VBITS_GE_256-NEXT: cmeq v0.16b, v0.16b, #0
-; VBITS_GE_256-NEXT: cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT: cmpeq p0.b, p0/z, z0.b, #0
; VBITS_GE_256-NEXT: ld1b { z0.b }, p0/z, [x0]
; VBITS_GE_256-NEXT: ptrue p0.s, vl8
; VBITS_GE_256-NEXT: ext v1.16b, v0.16b, v0.16b, #8
@@ -638,11 +634,10 @@ define void @masked_load_zext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
define void @masked_load_zext_v8i8i64(ptr %ap, ptr %bp, ptr %c) #0 {
; VBITS_GE_256-LABEL: masked_load_zext_v8i8i64:
; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT: ldr d0, [x1]
; VBITS_GE_256-NEXT: ptrue p0.b, vl8
+; VBITS_GE_256-NEXT: ldr d0, [x1]
; VBITS_GE_256-NEXT: mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT: cmeq v0.8b, v0.8b, #0
-; VBITS_GE_256-NEXT: cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT: cmpeq p0.b, p0/z, z0.b, #0
; VBITS_GE_256-NEXT: ld1b { z0.b }, p0/z, [x0]
; VBITS_GE_256-NEXT: ptrue p0.d, vl4
; VBITS_GE_256-NEXT: ushll v0.8h, v0.8b, #0
@@ -706,11 +701,10 @@ define void @masked_load_zext_v16i16i32(ptr %ap, ptr %bp, ptr %c) #0 {
define void @masked_load_zext_v8i16i64(ptr %ap, ptr %bp, ptr %c) #0 {
; VBITS_GE_256-LABEL: masked_load_zext_v8i16i64:
; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT: ldr q0, [x1]
; VBITS_GE_256-NEXT: ptrue p0.h, vl8
+; VBITS_GE_256-NEXT: ldr q0, [x1]
; VBITS_GE_256-NEXT: mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT: cmeq v0.8h, v0.8h, #0
-; VBITS_GE_256-NEXT: cmpne p0.h, p0/z, z0.h, #0
+; VBITS_GE_256-NEXT: cmpeq p0.h, p0/z, z0.h, #0
; VBITS_GE_256-NEXT: ld1h { z0.h }, p0/z, [x0]
; VBITS_GE_256-NEXT: ptrue p0.d, vl4
; VBITS_GE_256-NEXT: ext v1.16b, v0.16b, v0.16b, #8
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
index a42fce70f4f15..ed03f9b322432 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
@@ -433,11 +433,10 @@ define void @masked_scatter_v1i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
define void @masked_scatter_v2i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
; CHECK-LABEL: masked_scatter_v2i64:
; CHECK: // %bb.0:
-; CHECK-NEXT: ldr q0, [x0]
; CHECK-NEXT: ptrue p0.d, vl2
-; CHECK-NEXT: cmeq v1.2d, v0.2d, #0
-; CHECK-NEXT: cmpne p0.d, p0/z, z1.d, #0
+; CHECK-NEXT: ldr q0, [x0]
; CHECK-NEXT: ldr q1, [x1]
+; CHECK-NEXT: cmpeq p0.d, p0/z, z0.d, #0
; CHECK-NEXT: st1d { z0.d }, p0, [z1.d]
; CHECK-NEXT: ret
%vals = load <2 x i64>, ptr %a
  if (VecVT.isScalableVector() && SubVT.isFixedLengthVector() &&
      Vec.isUndef() && isZerosVector(SubVec.getNode()))
    return VecVT.isInteger() ? DAG.getConstant(0, DL, VecVT)
                             : DAG.getConstantFP(0, DL, VecVT);
I was forced to write this to maintain existing code quality. There is no specific reason to limit the combine to zeros but I figured any expansion was best done in a separate PR?
Yeah this does look like a useful combine when applied to other constants, but like you say best for another PR. Not sure in practice if we'll actually end up with different assembly or not?
> Not sure in practice if we'll actually end up with different assembly or not?
We will because NEON only has reg-reg and reg-zero compare instructions whereas SVE has reg-imm as well. You can see this today by changing the existing SVE VLS tests to use non-zero immediates where the generated code emits an unnecessary splat.
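To illustrate (a hypothetical variant of the existing masked-load VLS tests, not part of this PR; the function name, width and constant are mine), a mask compare against a non-zero immediate, where per the above the constant currently gets splatted into a register while SVE's predicated compare could take it as an immediate:

define void @masked_load_v2i64_cmp_nonzero(ptr %ap, ptr %bp) vscale_range(2,0) #0 {
  %b = load <2 x i64>, ptr %bp
  ; Non-zero immediate: NEON only has reg-reg and reg-zero compares, so the 7
  ; is materialised as a splat today; SVE cmpeq could fold it as an immediate.
  %mask = icmp eq <2 x i64> %b, <i64 7, i64 7>
  %vals = call <2 x i64> @llvm.masked.load.v2i64.p0(ptr %ap, i32 8,
                                                    <2 x i1> %mask,
                                                    <2 x i64> zeroinitializer)
  store <2 x i64> %vals, ptr %ap
  ret void
}

declare <2 x i64> @llvm.masked.load.v2i64.p0(ptr, i32 immarg, <2 x i1>, <2 x i64>)

attributes #0 = { "target-features"="+sve" }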
  ISD::CondCode CC;

  // When Mask is the result of a SETCC, it's better to regenerate the compare.
  if (Mask.getOpcode() == ISD::SETCC) {
Nice! Could this be extended to peek through ISD::SIGN_EXTEND too? I'm thinking of cases such as:
t14: v16i8 = sign_extend t6
t6: v16i1 = setcc t2, t4, seteq:ch
I've seen this come up when using @llvm.experimental.cttz.elts with fixed-length vectors (although presumably it's a general pattern), e.g.:
define i64 @cmpeq_i8(<16 x i8> %a, <16 x i8> %b) {
%cmp = icmp eq <16 x i8> %a, %b
%ctz = tail call i64 @llvm.experimental.cttz.elts(<16 x i1> %cmp, i1 1)
ret i64 %ctz
}

Otherwise I'll have a look into it once this PR lands. :)
For operation legalisation I would not expect to see such code because v16i1 is not a legal type. Typically these would be merged so that for operation legalisation you'd just see v16i8 = setcc t2, t4, seteq:ch.
Sounds good, I'll have a look later then. Thanks :)
huntergr-arm left a comment
LGTM
  // When Mask is the result of a SETCC, it's better to regenerate the compare.
  if (Mask.getOpcode() == ISD::SETCC) {
    Op1 = convertToScalableVector(DAG, ContainerVT, Mask.getOperand(0));
Is ContainerVT guaranteed to be correct for the SETCC inputs? It looks like we base ContainerVT on the result. I'm not sure if something like this is legal for NEON:
v4i32 = SETCC NE v4i16 %a, v4i16 %b
since v4i16 is also a legal type. I'm just a bit worried that we're effectively promoting a type here and, if so, is that a problem?
I don't believe mixing element sizes like this is legal for NEON. The expectation is that all vector types will have the same element count and bit length. You can see this today (albeit slightly less so since I've refactored the integer side): for NEON we simply lower SETCC operations onto AArch64ISD::FCM## operations, whose definitions have the same requirement.
Fair enough! Just wanted to make sure.