
Commit ff59fd2

[LLVM][CodeGen][SVE] Improve lowering of fixed length masked mem ops.
Converting fixed length masks, as used by MLOAD, to scalable vectors is done by comparing the mask to zero. When the mask is the result of a compare we can instead promote the operands and regenerate the original compare. At worst this reduces the dependency chain and in most cases removes the need for multiple compares.
1 parent b0b97e3 commit ff59fd2
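To illustrate the pattern this targets, here is a hand-written sketch, not taken from the commit itself (the function name and values are invented, modeled on the tests updated below). The mask fed to the masked load is itself the result of an icmp:

define void @masked_load_cmp_mask(ptr %ap, ptr %bp, ptr %c) #0 {
  ; The mask comes from a compare rather than being loaded directly.
  %b = load <8 x i16>, ptr %bp
  %mask = icmp eq <8 x i16> %b, zeroinitializer
  %vals = call <8 x i16> @llvm.masked.load.v8i16.p0(ptr %ap, i32 2, <8 x i1> %mask, <8 x i16> zeroinitializer)
  store <8 x i16> %vals, ptr %c
  ret void
}

declare <8 x i16> @llvm.masked.load.v8i16.p0(ptr, i32, <8 x i1>, <8 x i16>)

Before this change the icmp was materialised as a NEON cmeq whose result was then re-tested against zero with an SVE cmpne to form the predicate; afterwards the compare operands are promoted to scalable vectors and a single predicated cmpeq is emitted, as the test updates below show.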

File tree

4 files changed, +46 -29 lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 30 additions & 5 deletions

@@ -20190,6 +20190,12 @@ performInsertSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
   EVT VecVT = Vec.getValueType();
   EVT SubVT = SubVec.getValueType();
 
+  // Promote fixed length vector zeros.
+  if (VecVT.isScalableVector() && SubVT.isFixedLengthVector() &&
+      Vec.isUndef() && isZerosVector(SubVec.getNode()))
+    return VecVT.isInteger() ? DAG.getConstant(0, DL, VecVT)
+                             : DAG.getConstantFP(0, DL, VecVT);
+
   // Only do this for legal fixed vector types.
   if (!VecVT.isFixedLengthVector() ||
       !DAG.getTargetLoweringInfo().isTypeLegal(VecVT) ||

@@ -28697,17 +28703,36 @@ static SDValue convertFixedMaskToScalableVector(SDValue Mask,
   SDLoc DL(Mask);
   EVT InVT = Mask.getValueType();
   EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
-
-  auto Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
+  SDValue Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
 
   if (ISD::isBuildVectorAllOnes(Mask.getNode()))
     return Pg;
 
-  auto Op1 = convertToScalableVector(DAG, ContainerVT, Mask);
-  auto Op2 = DAG.getConstant(0, DL, ContainerVT);
+  bool InvertCond = false;
+  if (isBitwiseNot(Mask)) {
+    InvertCond = true;
+    Mask = Mask.getOperand(0);
+  }
+
+  SDValue Op1, Op2;
+  ISD::CondCode CC;
+
+  // When Mask is the result of a SETCC, it's better to regenerate the compare.
+  if (Mask.getOpcode() == ISD::SETCC) {
+    Op1 = convertToScalableVector(DAG, ContainerVT, Mask.getOperand(0));
+    Op2 = convertToScalableVector(DAG, ContainerVT, Mask.getOperand(1));
+    CC = cast<CondCodeSDNode>(Mask.getOperand(2))->get();
+  } else {
+    Op1 = convertToScalableVector(DAG, ContainerVT, Mask);
+    Op2 = DAG.getConstant(0, DL, ContainerVT);
+    CC = ISD::SETNE;
+  }
+
+  if (InvertCond)
+    CC = getSetCCInverse(CC, Op1.getValueType());
 
   return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, Pg.getValueType(),
-                     {Pg, Op1, Op2, DAG.getCondCode(ISD::SETNE)});
+                     {Pg, Op1, Op2, DAG.getCondCode(CC)});
 }
 
 // Convert all fixed length vector loads larger than NEON to masked_loads.
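The isBitwiseNot check above covers masks that arrive as the negation of a compare. A minimal sketch of that shape, again with invented names and assuming the xor-with-all-ones reaches lowering as a bitwise-not of the setcc:

define void @masked_load_inverted_mask(ptr %ap, ptr %bp, ptr %c) #0 {
  %b = load <2 x i64>, ptr %bp
  %cmp = icmp eq <2 x i64> %b, zeroinitializer
  ; Inverted mask: active lanes are those where %b is non-zero.
  %mask = xor <2 x i1> %cmp, <i1 true, i1 true>
  %vals = call <2 x i64> @llvm.masked.load.v2i64.p0(ptr %ap, i32 8, <2 x i1> %mask, <2 x i64> zeroinitializer)
  store <2 x i64> %vals, ptr %c
  ret void
}

declare <2 x i64> @llvm.masked.load.v2i64.p0(ptr, i32, <2 x i1>, <2 x i64>)

Here getSetCCInverse turns SETEQ into SETNE, so the lowering still needs only one predicated compare instead of a compare, a negation, and a second compare against zero.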

llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll

Lines changed: 2 additions & 3 deletions

@@ -460,10 +460,9 @@ define void @masked_gather_v1i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
 define void @masked_gather_v2i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
 ; CHECK-LABEL: masked_gather_v2i64:
 ; CHECK: // %bb.0:
-; CHECK-NEXT:    ldr q0, [x0]
 ; CHECK-NEXT:    ptrue p0.d, vl2
-; CHECK-NEXT:    cmeq v0.2d, v0.2d, #0
-; CHECK-NEXT:    cmpne p0.d, p0/z, z0.d, #0
+; CHECK-NEXT:    ldr q0, [x0]
+; CHECK-NEXT:    cmpeq p0.d, p0/z, z0.d, #0
 ; CHECK-NEXT:    ldr q0, [x1]
 ; CHECK-NEXT:    ld1d { z0.d }, p0/z, [z0.d]
 ; CHECK-NEXT:    str q0, [x0]

llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll

Lines changed: 12 additions & 18 deletions

@@ -401,11 +401,10 @@ define void @masked_load_sext_v32i8i16(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_sext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_sext_v16i8i32:
 ; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.b, vl16
+; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #8 // =0x8
-; VBITS_GE_256-NEXT:    cmeq v0.16b, v0.16b, #0
-; VBITS_GE_256-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.b, p0/z, z0.b, #0
 ; VBITS_GE_256-NEXT:    ld1b { z0.b }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.s, vl8
 ; VBITS_GE_256-NEXT:    ext v1.16b, v0.16b, v0.16b, #8

@@ -436,11 +435,10 @@ define void @masked_load_sext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_sext_v8i8i64(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_sext_v8i8i64:
 ; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT:    ldr d0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.b, vl8
+; VBITS_GE_256-NEXT:    ldr d0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT:    cmeq v0.8b, v0.8b, #0
-; VBITS_GE_256-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.b, p0/z, z0.b, #0
 ; VBITS_GE_256-NEXT:    ld1b { z0.b }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    sshll v0.8h, v0.8b, #0

@@ -504,11 +502,10 @@ define void @masked_load_sext_v16i16i32(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_sext_v8i16i64(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_sext_v8i16i64:
 ; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.h, vl8
+; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT:    cmeq v0.8h, v0.8h, #0
-; VBITS_GE_256-NEXT:    cmpne p0.h, p0/z, z0.h, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.h, p0/z, z0.h, #0
 ; VBITS_GE_256-NEXT:    ld1h { z0.h }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    ext v1.16b, v0.16b, v0.16b, #8

@@ -603,11 +600,10 @@ define void @masked_load_zext_v32i8i16(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_zext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_zext_v16i8i32:
 ; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.b, vl16
+; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #8 // =0x8
-; VBITS_GE_256-NEXT:    cmeq v0.16b, v0.16b, #0
-; VBITS_GE_256-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.b, p0/z, z0.b, #0
 ; VBITS_GE_256-NEXT:    ld1b { z0.b }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.s, vl8
 ; VBITS_GE_256-NEXT:    ext v1.16b, v0.16b, v0.16b, #8

@@ -638,11 +634,10 @@ define void @masked_load_zext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_zext_v8i8i64(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_zext_v8i8i64:
 ; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT:    ldr d0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.b, vl8
+; VBITS_GE_256-NEXT:    ldr d0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT:    cmeq v0.8b, v0.8b, #0
-; VBITS_GE_256-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.b, p0/z, z0.b, #0
 ; VBITS_GE_256-NEXT:    ld1b { z0.b }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    ushll v0.8h, v0.8b, #0

@@ -706,11 +701,10 @@ define void @masked_load_zext_v16i16i32(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_zext_v8i16i64(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_zext_v8i16i64:
 ; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.h, vl8
+; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT:    cmeq v0.8h, v0.8h, #0
-; VBITS_GE_256-NEXT:    cmpne p0.h, p0/z, z0.h, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.h, p0/z, z0.h, #0
 ; VBITS_GE_256-NEXT:    ld1h { z0.h }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    ext v1.16b, v0.16b, v0.16b, #8

llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll

Lines changed: 2 additions & 3 deletions

@@ -433,11 +433,10 @@ define void @masked_scatter_v1i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
 define void @masked_scatter_v2i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
 ; CHECK-LABEL: masked_scatter_v2i64:
 ; CHECK: // %bb.0:
-; CHECK-NEXT:    ldr q0, [x0]
 ; CHECK-NEXT:    ptrue p0.d, vl2
-; CHECK-NEXT:    cmeq v1.2d, v0.2d, #0
-; CHECK-NEXT:    cmpne p0.d, p0/z, z1.d, #0
+; CHECK-NEXT:    ldr q0, [x0]
 ; CHECK-NEXT:    ldr q1, [x1]
+; CHECK-NEXT:    cmpeq p0.d, p0/z, z0.d, #0
 ; CHECK-NEXT:    st1d { z0.d }, p0, [z1.d]
 ; CHECK-NEXT:    ret
 %vals = load <2 x i64>, ptr %a
