Skip to content

Commit 19c1381

Browse files
authored
[AArch64][GlobalISel] Fix vecreduce(zext) fold from illegal types. (#167944)
We generate an ADDLV node that incorporates a vecreduce(zext) from elements of half the size. This means that we need the output type to be at least twice the size of the input type. I updated some variable names whilst I was here. Fixes #167935
1 parent 63e6373 commit 19c1381

File tree

2 files changed

+48
-14
lines changed

2 files changed

+48
-14
lines changed

llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,8 @@ bool matchExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
435435
Register ExtSrcReg = ExtMI->getOperand(1).getReg();
436436
LLT ExtSrcTy = MRI.getType(ExtSrcReg);
437437
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
438+
if (ExtSrcTy.getScalarSizeInBits() * 2 > DstTy.getScalarSizeInBits())
439+
return false;
438440
if ((DstTy.getScalarSizeInBits() == 16 &&
439441
ExtSrcTy.getNumElements() % 8 == 0 && ExtSrcTy.getNumElements() < 256) ||
440442
(DstTy.getScalarSizeInBits() == 32 &&
@@ -492,7 +494,7 @@ void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
492494

493495
unsigned MidScalarSize = MainTy.getScalarSizeInBits() * 2;
494496
LLT MidScalarLLT = LLT::scalar(MidScalarSize);
495-
Register zeroReg = B.buildConstant(LLT::scalar(64), 0).getReg(0);
497+
Register ZeroReg = B.buildConstant(LLT::scalar(64), 0).getReg(0);
496498
for (unsigned I = 0; I < WorkingRegisters.size(); I++) {
497499
// If the number of elements is too small to build an instruction, extend
498500
// its size before applying addlv
@@ -508,10 +510,10 @@ void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
508510

509511
// Generate the {U/S}ADDLV instruction, whose output is always double of the
510512
// Src's Scalar size
511-
LLT addlvTy = MidScalarSize <= 32 ? LLT::fixed_vector(4, 32)
513+
LLT AddlvTy = MidScalarSize <= 32 ? LLT::fixed_vector(4, 32)
512514
: LLT::fixed_vector(2, 64);
513-
Register addlvReg =
514-
B.buildInstr(Opc, {addlvTy}, {WorkingRegisters[I]}).getReg(0);
515+
Register AddlvReg =
516+
B.buildInstr(Opc, {AddlvTy}, {WorkingRegisters[I]}).getReg(0);
515517

516518
// The output from {U/S}ADDLV gets placed in the lowest lane of a v4i32 or
517519
// v2i64 register.
@@ -520,36 +522,36 @@ void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
520522
// Therefore we have to extract/truncate the value to the right type
521523
if (MidScalarSize == 32 || MidScalarSize == 64) {
522524
WorkingRegisters[I] = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
523-
{MidScalarLLT}, {addlvReg, zeroReg})
525+
{MidScalarLLT}, {AddlvReg, ZeroReg})
524526
.getReg(0);
525527
} else {
526-
Register extractReg = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
527-
{LLT::scalar(32)}, {addlvReg, zeroReg})
528+
Register ExtractReg = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
529+
{LLT::scalar(32)}, {AddlvReg, ZeroReg})
528530
.getReg(0);
529531
WorkingRegisters[I] =
530-
B.buildTrunc({MidScalarLLT}, {extractReg}).getReg(0);
532+
B.buildTrunc({MidScalarLLT}, {ExtractReg}).getReg(0);
531533
}
532534
}
533535

534-
Register outReg;
536+
Register OutReg;
535537
if (WorkingRegisters.size() > 1) {
536-
outReg = B.buildAdd(MidScalarLLT, WorkingRegisters[0], WorkingRegisters[1])
538+
OutReg = B.buildAdd(MidScalarLLT, WorkingRegisters[0], WorkingRegisters[1])
537539
.getReg(0);
538540
for (unsigned I = 2; I < WorkingRegisters.size(); I++) {
539-
outReg = B.buildAdd(MidScalarLLT, outReg, WorkingRegisters[I]).getReg(0);
541+
OutReg = B.buildAdd(MidScalarLLT, OutReg, WorkingRegisters[I]).getReg(0);
540542
}
541543
} else {
542-
outReg = WorkingRegisters[0];
544+
OutReg = WorkingRegisters[0];
543545
}
544546

545547
if (DstTy.getScalarSizeInBits() > MidScalarSize) {
546548
// Handle the scalar value if the DstTy's Scalar Size is more than double
547549
// Src's ScalarType
548550
B.buildInstr(std::get<1>(MatchInfo) ? TargetOpcode::G_SEXT
549551
: TargetOpcode::G_ZEXT,
550-
{DstReg}, {outReg});
552+
{DstReg}, {OutReg});
551553
} else {
552-
B.buildCopy(DstReg, outReg);
554+
B.buildCopy(DstReg, OutReg);
553555
}
554556

555557
MI.eraseFromParent();

llvm/test/CodeGen/AArch64/vecreduce-add.ll

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4808,6 +4808,38 @@ define i64 @extract_scalable(<2 x i32> %0) "target-features"="+sve2" {
48084808
ret i64 %5
48094809
}
48104810

4811+
define i32 @vecreduce_add_from_i21_zero() {
4812+
; CHECK-SD-LABEL: vecreduce_add_from_i21_zero:
4813+
; CHECK-SD: // %bb.0: // %entry
4814+
; CHECK-SD-NEXT: mov w0, wzr
4815+
; CHECK-SD-NEXT: ret
4816+
;
4817+
; CHECK-GI-LABEL: vecreduce_add_from_i21_zero:
4818+
; CHECK-GI: // %bb.0: // %entry
4819+
; CHECK-GI-NEXT: movi v0.2d, #0000000000000000
4820+
; CHECK-GI-NEXT: addv s0, v0.4s
4821+
; CHECK-GI-NEXT: fmov w0, s0
4822+
; CHECK-GI-NEXT: ret
4823+
entry:
4824+
%0 = zext <4 x i21> zeroinitializer to <4 x i32>
4825+
%1 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %0)
4826+
ret i32 %1
4827+
}
4828+
4829+
define i32 @vecreduce_add_from_i21(<4 x i21> %a) {
4830+
; CHECK-LABEL: vecreduce_add_from_i21:
4831+
; CHECK: // %bb.0: // %entry
4832+
; CHECK-NEXT: movi v1.4s, #31, msl #16
4833+
; CHECK-NEXT: and v0.16b, v0.16b, v1.16b
4834+
; CHECK-NEXT: addv s0, v0.4s
4835+
; CHECK-NEXT: fmov w0, s0
4836+
; CHECK-NEXT: ret
4837+
entry:
4838+
%0 = zext <4 x i21> %a to <4 x i32>
4839+
%1 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %0)
4840+
ret i32 %1
4841+
}
4842+
48114843
declare <8 x i32> @llvm.abs.v8i32(<8 x i32>, i1 immarg) #1
48124844
declare i16 @llvm.vector.reduce.add.v32i16(<32 x i16>)
48134845
declare i16 @llvm.vector.reduce.add.v24i16(<24 x i16>)

0 commit comments

Comments
 (0)