
Conversation

@paulwalker-arm
Collaborator

Converting fixed length masks, as used by MLOAD, to scalable vectors is done by comparing the mask to zero. When the mask is the result of a compare we can instead promote the operands and regenerate the original compare. At worst this reduces the dependency chain and in most cases removes the need for multiple compares.
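For illustration, a minimal sketch (hypothetical, not taken from the patch) of the kind of input this affects: a fixed-length masked load whose mask is the result of a compare. Per the test diffs below, the i1 mask was previously materialised with a NEON compare and then re-tested against zero via an SVE cmpne; with this change the original compare is regenerated directly as a single SVE predicate compare.

; Hypothetical sketch: assumes an SVE target with a minimum 256-bit vector
; length so the <8 x i32> operations take the fixed-length SVE lowering path.
define void @masked_load_of_cmp(ptr %ap, ptr %bp, ptr %c) vscale_range(2,0) "target-features"="+sve" {
  %b = load <8 x i32>, ptr %bp
  ; The mask fed to the masked load is itself the result of a compare.
  %mask = icmp eq <8 x i32> %b, zeroinitializer
  %vals = call <8 x i32> @llvm.masked.load.v8i32.p0(ptr %ap, i32 4, <8 x i1> %mask, <8 x i32> zeroinitializer)
  store <8 x i32> %vals, ptr %c
  ret void
}

declare <8 x i32> @llvm.masked.load.v8i32.p0(ptr, i32, <8 x i1>, <8 x i32>)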

@llvmbot
Member

llvmbot commented Apr 4, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Paul Walker (paulwalker-arm)

Changes

Converting fixed length masks, as used by MLOAD, to scalable vectors is done by comparing the mask to zero. When the mask is the result of a compare we can instead promote the operands and regenerate the original compare. At worst this reduces the dependency chain and in most cases removes the need for multiple compares.


Full diff: https://github.com/llvm/llvm-project/pull/134402.diff

4 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+30-5)
  • (modified) llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll (+2-3)
  • (modified) llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll (+12-18)
  • (modified) llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll (+2-3)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a1ba3922996a1..57a950cfc702a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20190,6 +20190,12 @@ performInsertSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
   EVT VecVT = Vec.getValueType();
   EVT SubVT = SubVec.getValueType();
 
+  // Promote fixed length vector zeros.
+  if (VecVT.isScalableVector() && SubVT.isFixedLengthVector() &&
+      Vec.isUndef() && isZerosVector(SubVec.getNode()))
+    return VecVT.isInteger() ? DAG.getConstant(0, DL, VecVT)
+                             : DAG.getConstantFP(0, DL, VecVT);
+
   // Only do this for legal fixed vector types.
   if (!VecVT.isFixedLengthVector() ||
       !DAG.getTargetLoweringInfo().isTypeLegal(VecVT) ||
@@ -28697,17 +28703,36 @@ static SDValue convertFixedMaskToScalableVector(SDValue Mask,
   SDLoc DL(Mask);
   EVT InVT = Mask.getValueType();
   EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
-
-  auto Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
+  SDValue Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
 
   if (ISD::isBuildVectorAllOnes(Mask.getNode()))
     return Pg;
 
-  auto Op1 = convertToScalableVector(DAG, ContainerVT, Mask);
-  auto Op2 = DAG.getConstant(0, DL, ContainerVT);
+  bool InvertCond = false;
+  if (isBitwiseNot(Mask)) {
+    InvertCond = true;
+    Mask = Mask.getOperand(0);
+  }
+
+  SDValue Op1, Op2;
+  ISD::CondCode CC;
+
+  // When Mask is the result of a SETCC, it's better to regenerate the compare.
+  if (Mask.getOpcode() == ISD::SETCC) {
+    Op1 = convertToScalableVector(DAG, ContainerVT, Mask.getOperand(0));
+    Op2 = convertToScalableVector(DAG, ContainerVT, Mask.getOperand(1));
+    CC = cast<CondCodeSDNode>(Mask.getOperand(2))->get();
+  } else {
+    Op1 = convertToScalableVector(DAG, ContainerVT, Mask);
+    Op2 = DAG.getConstant(0, DL, ContainerVT);
+    CC = ISD::SETNE;
+  }
+
+  if (InvertCond)
+    CC = getSetCCInverse(CC, Op1.getValueType());
 
   return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, Pg.getValueType(),
-                     {Pg, Op1, Op2, DAG.getCondCode(ISD::SETNE)});
+                     {Pg, Op1, Op2, DAG.getCondCode(CC)});
 }
 
 // Convert all fixed length vector loads larger than NEON to masked_loads.
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
index a50d0dc37eaf6..093e6cd9328c8 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
@@ -460,10 +460,9 @@ define void @masked_gather_v1i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
 define void @masked_gather_v2i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
 ; CHECK-LABEL: masked_gather_v2i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ldr q0, [x0]
 ; CHECK-NEXT:    ptrue p0.d, vl2
-; CHECK-NEXT:    cmeq v0.2d, v0.2d, #0
-; CHECK-NEXT:    cmpne p0.d, p0/z, z0.d, #0
+; CHECK-NEXT:    ldr q0, [x0]
+; CHECK-NEXT:    cmpeq p0.d, p0/z, z0.d, #0
 ; CHECK-NEXT:    ldr q0, [x1]
 ; CHECK-NEXT:    ld1d { z0.d }, p0/z, [z0.d]
 ; CHECK-NEXT:    str q0, [x0]
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll
index 6513b01d00922..34dc0bb5ef2d2 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll
@@ -401,11 +401,10 @@ define void @masked_load_sext_v32i8i16(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_sext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_sext_v16i8i32:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.b, vl16
+; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #8 // =0x8
-; VBITS_GE_256-NEXT:    cmeq v0.16b, v0.16b, #0
-; VBITS_GE_256-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.b, p0/z, z0.b, #0
 ; VBITS_GE_256-NEXT:    ld1b { z0.b }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.s, vl8
 ; VBITS_GE_256-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
@@ -436,11 +435,10 @@ define void @masked_load_sext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_sext_v8i8i64(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_sext_v8i8i64:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    ldr d0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.b, vl8
+; VBITS_GE_256-NEXT:    ldr d0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT:    cmeq v0.8b, v0.8b, #0
-; VBITS_GE_256-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.b, p0/z, z0.b, #0
 ; VBITS_GE_256-NEXT:    ld1b { z0.b }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    sshll v0.8h, v0.8b, #0
@@ -504,11 +502,10 @@ define void @masked_load_sext_v16i16i32(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_sext_v8i16i64(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_sext_v8i16i64:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.h, vl8
+; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT:    cmeq v0.8h, v0.8h, #0
-; VBITS_GE_256-NEXT:    cmpne p0.h, p0/z, z0.h, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.h, p0/z, z0.h, #0
 ; VBITS_GE_256-NEXT:    ld1h { z0.h }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
@@ -603,11 +600,10 @@ define void @masked_load_zext_v32i8i16(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_zext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_zext_v16i8i32:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.b, vl16
+; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #8 // =0x8
-; VBITS_GE_256-NEXT:    cmeq v0.16b, v0.16b, #0
-; VBITS_GE_256-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.b, p0/z, z0.b, #0
 ; VBITS_GE_256-NEXT:    ld1b { z0.b }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.s, vl8
 ; VBITS_GE_256-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
@@ -638,11 +634,10 @@ define void @masked_load_zext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_zext_v8i8i64(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_zext_v8i8i64:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    ldr d0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.b, vl8
+; VBITS_GE_256-NEXT:    ldr d0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT:    cmeq v0.8b, v0.8b, #0
-; VBITS_GE_256-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.b, p0/z, z0.b, #0
 ; VBITS_GE_256-NEXT:    ld1b { z0.b }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    ushll v0.8h, v0.8b, #0
@@ -706,11 +701,10 @@ define void @masked_load_zext_v16i16i32(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_zext_v8i16i64(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_zext_v8i16i64:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.h, vl8
+; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT:    cmeq v0.8h, v0.8h, #0
-; VBITS_GE_256-NEXT:    cmpne p0.h, p0/z, z0.h, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.h, p0/z, z0.h, #0
 ; VBITS_GE_256-NEXT:    ld1h { z0.h }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
index a42fce70f4f15..ed03f9b322432 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
@@ -433,11 +433,10 @@ define void @masked_scatter_v1i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
 define void @masked_scatter_v2i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
 ; CHECK-LABEL: masked_scatter_v2i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ldr q0, [x0]
 ; CHECK-NEXT:    ptrue p0.d, vl2
-; CHECK-NEXT:    cmeq v1.2d, v0.2d, #0
-; CHECK-NEXT:    cmpne p0.d, p0/z, z1.d, #0
+; CHECK-NEXT:    ldr q0, [x0]
 ; CHECK-NEXT:    ldr q1, [x1]
+; CHECK-NEXT:    cmpeq p0.d, p0/z, z0.d, #0
 ; CHECK-NEXT:    st1d { z0.d }, p0, [z1.d]
 ; CHECK-NEXT:    ret
   %vals = load <2 x i64>, ptr %a

Comment on lines +20194 to +20197
if (VecVT.isScalableVector() && SubVT.isFixedLengthVector() &&
Vec.isUndef() && isZerosVector(SubVec.getNode()))
return VecVT.isInteger() ? DAG.getConstant(0, DL, VecVT)
: DAG.getConstantFP(0, DL, VecVT);
Collaborator Author

@paulwalker-arm paulwalker-arm Apr 4, 2025


I was forced to write this to maintain existing code quality. There is no specific reason to limit the combine to zeros but I figured any expansion was best done in a separate PR?

Contributor


Yeah this does look like a useful combine when applied to other constants, but like you say best for another PR. Not sure in practice if we'll actually end up with different assembly or not?

Collaborator Author


Not sure in practice if we'll actually end up with different assembly or not?

We will because NEON only has reg-reg and reg-zero compare instructions whereas SVE has reg-imm as well. You can see this today by changing the existing SVE VLS tests to use non-zero immediates, where the generated code emits an unnecessary splat.
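To make that concrete, a hypothetical sketch (not a test from this patch): with a non-zero constant such as 7, the fixed-length compare operand is currently materialised as a splat in a vector register before the SVE predicate compare, even though SVE's cmpeq can encode a small signed immediate directly, while NEON's cmeq only has register and compare-against-zero forms.

; Hypothetical example: with the combine limited to zeros, codegen would splat
; the constant 7 before the predicate compare rather than use cmpeq ..., #7.
define void @masked_load_of_cmp_imm(ptr %ap, ptr %bp, ptr %c) vscale_range(2,0) "target-features"="+sve" {
  %b = load <8 x i32>, ptr %bp
  %mask = icmp eq <8 x i32> %b, <i32 7, i32 7, i32 7, i32 7, i32 7, i32 7, i32 7, i32 7>
  %vals = call <8 x i32> @llvm.masked.load.v8i32.p0(ptr %ap, i32 4, <8 x i1> %mask, <8 x i32> zeroinitializer)
  store <8 x i32> %vals, ptr %c
  ret void
}

declare <8 x i32> @llvm.masked.load.v8i32.p0(ptr, i32, <8 x i1>, <8 x i32>)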

ISD::CondCode CC;

// When Mask is the result of a SETCC, it's better to regenerate the compare.
if (Mask.getOpcode() == ISD::SETCC) {
Contributor


Nice! Could this be extended to peek through ISD::SIGN_EXTEND too? I'm thinking of cases such as:

t14: v16i8 = sign_extend t6
  t6: v16i1 = setcc t2, t4, seteq:ch

I've seen this come up when using @llvm.experimental.cttz.elts with fixed-length vectors (although presumably it's a general pattern), e.g.:

define i64 @cmpeq_i8(<16 x i8> %a, <16 x i8> %b) {
  %cmp = icmp eq <16 x i8> %a, %b
  %ctz = tail call i64 @llvm.experimental.cttz.elts(<16 x i1> %cmp, i1 1)
  ret i64 %ctz
}

Otherwise I'll have a look into it once this PR lands. :)

Collaborator Author


For operation legalisation I would not expect to see such code because v16i1 is not a legal type. Typically these would be merged so that for operation legalisation you'd just see v16i8 = setcc t2, t4, seteq:ch.

Contributor


Sounds good, I'll have a look later then. Thanks :)

Collaborator

@huntergr-arm huntergr-arm left a comment


LGTM


// When Mask is the result of a SETCC, it's better to regenerate the compare.
if (Mask.getOpcode() == ISD::SETCC) {
Op1 = convertToScalableVector(DAG, ContainerVT, Mask.getOperand(0));
Contributor


Is ContainerVT guaranteed to be correct for the SETCC inputs? It looks like we base ContainerVT on the result. I'm not sure if something like this is legal for NEON:

v4i32 = SETCC NE v4i16 %a, v4i16 %b

since v4i16 is also a legal type. I'm just a bit worried that we're effectively promoting a type here and, if so, is that a problem?

Collaborator Author


I don't believe mixing element sizes like this is legal for NEON. The expectation is that all vector types will have the same element count and bit length. You can see this today (albeit slightly less so since I've refactored the integer side), but for NEON we simply lower SETCC operations onto AArch64ISD::FCM## operations whose definitions have the same requirement.

Contributor


Fair enough! Just wanted to make sure.

@paulwalker-arm paulwalker-arm merged commit e06a9ca into llvm:main Apr 8, 2025
13 checks passed
@paulwalker-arm paulwalker-arm deleted the sve-vls-masked-loads branch April 8, 2025 11:09