Skip to content

Commit a76b02d

Browse files
authored
[AMDGPU] Extending wave reduction intrinsics for i64 types - 2 (#151309)
Supporting Arithmetic Operations: `add`, `sub`
1 parent 94e2c19 commit a76b02d

File tree

4 files changed

+3624
-117
lines changed

4 files changed

+3624
-117
lines changed

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 158 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -5270,6 +5270,57 @@ static MachineBasicBlock *emitIndirectDst(MachineInstr &MI,
52705270
return LoopBB;
52715271
}
52725272

5273+
static MachineBasicBlock *Expand64BitScalarArithmetic(MachineInstr &MI,
5274+
MachineBasicBlock *BB) {
5275+
// For targets older than GFX12, we emit a sequence of 32-bit operations.
5276+
// For GFX12, we emit s_add_u64 and s_sub_u64.
5277+
MachineFunction *MF = BB->getParent();
5278+
const SIInstrInfo *TII = MF->getSubtarget<GCNSubtarget>().getInstrInfo();
5279+
const GCNSubtarget &ST = MF->getSubtarget<GCNSubtarget>();
5280+
MachineRegisterInfo &MRI = BB->getParent()->getRegInfo();
5281+
const DebugLoc &DL = MI.getDebugLoc();
5282+
MachineOperand &Dest = MI.getOperand(0);
5283+
MachineOperand &Src0 = MI.getOperand(1);
5284+
MachineOperand &Src1 = MI.getOperand(2);
5285+
bool IsAdd = (MI.getOpcode() == AMDGPU::S_ADD_U64_PSEUDO);
5286+
if (ST.hasScalarAddSub64()) {
5287+
unsigned Opc = IsAdd ? AMDGPU::S_ADD_U64 : AMDGPU::S_SUB_U64;
5288+
// clang-format off
5289+
BuildMI(*BB, MI, DL, TII->get(Opc), Dest.getReg())
5290+
.add(Src0)
5291+
.add(Src1);
5292+
// clang-format on
5293+
} else {
5294+
const SIRegisterInfo *TRI = ST.getRegisterInfo();
5295+
const TargetRegisterClass *BoolRC = TRI->getBoolRC();
5296+
5297+
Register DestSub0 = MRI.createVirtualRegister(&AMDGPU::SReg_32RegClass);
5298+
Register DestSub1 = MRI.createVirtualRegister(&AMDGPU::SReg_32RegClass);
5299+
5300+
MachineOperand Src0Sub0 = TII->buildExtractSubRegOrImm(
5301+
MI, MRI, Src0, BoolRC, AMDGPU::sub0, &AMDGPU::SReg_32RegClass);
5302+
MachineOperand Src0Sub1 = TII->buildExtractSubRegOrImm(
5303+
MI, MRI, Src0, BoolRC, AMDGPU::sub1, &AMDGPU::SReg_32RegClass);
5304+
5305+
MachineOperand Src1Sub0 = TII->buildExtractSubRegOrImm(
5306+
MI, MRI, Src1, BoolRC, AMDGPU::sub0, &AMDGPU::SReg_32RegClass);
5307+
MachineOperand Src1Sub1 = TII->buildExtractSubRegOrImm(
5308+
MI, MRI, Src1, BoolRC, AMDGPU::sub1, &AMDGPU::SReg_32RegClass);
5309+
5310+
unsigned LoOpc = IsAdd ? AMDGPU::S_ADD_U32 : AMDGPU::S_SUB_U32;
5311+
unsigned HiOpc = IsAdd ? AMDGPU::S_ADDC_U32 : AMDGPU::S_SUBB_U32;
5312+
BuildMI(*BB, MI, DL, TII->get(LoOpc), DestSub0).add(Src0Sub0).add(Src1Sub0);
5313+
BuildMI(*BB, MI, DL, TII->get(HiOpc), DestSub1).add(Src0Sub1).add(Src1Sub1);
5314+
BuildMI(*BB, MI, DL, TII->get(TargetOpcode::REG_SEQUENCE), Dest.getReg())
5315+
.addReg(DestSub0)
5316+
.addImm(AMDGPU::sub0)
5317+
.addReg(DestSub1)
5318+
.addImm(AMDGPU::sub1);
5319+
}
5320+
MI.eraseFromParent();
5321+
return BB;
5322+
}
5323+
52735324
static uint32_t getIdentityValueFor32BitWaveReduction(unsigned Opc) {
52745325
switch (Opc) {
52755326
case AMDGPU::S_MIN_U32:
@@ -5303,6 +5354,9 @@ static uint64_t getIdentityValueFor64BitWaveReduction(unsigned Opc) {
53035354
return std::numeric_limits<uint64_t>::min();
53045355
case AMDGPU::V_CMP_GT_I64_e64: // max.i64
53055356
return std::numeric_limits<int64_t>::min();
5357+
case AMDGPU::S_ADD_U64_PSEUDO:
5358+
case AMDGPU::S_SUB_U64_PSEUDO:
5359+
return std::numeric_limits<uint64_t>::min();
53065360
default:
53075361
llvm_unreachable(
53085362
"Unexpected opcode in getIdentityValueFor64BitWaveReduction");
@@ -5355,51 +5409,54 @@ static MachineBasicBlock *lowerWaveReduce(MachineInstr &MI,
53555409
}
53565410
case AMDGPU::S_XOR_B32:
53575411
case AMDGPU::S_ADD_I32:
5358-
case AMDGPU::S_SUB_I32: {
5412+
case AMDGPU::S_ADD_U64_PSEUDO:
5413+
case AMDGPU::S_SUB_I32:
5414+
case AMDGPU::S_SUB_U64_PSEUDO: {
53595415
const TargetRegisterClass *WaveMaskRegClass = TRI->getWaveMaskRegClass();
53605416
const TargetRegisterClass *DstRegClass = MRI.getRegClass(DstReg);
53615417
Register ExecMask = MRI.createVirtualRegister(WaveMaskRegClass);
5362-
Register ActiveLanes = MRI.createVirtualRegister(DstRegClass);
5418+
Register NumActiveLanes =
5419+
MRI.createVirtualRegister(&AMDGPU::SReg_32RegClass);
53635420

53645421
bool IsWave32 = ST.isWave32();
53655422
unsigned MovOpc = IsWave32 ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64;
53665423
MCRegister ExecReg = IsWave32 ? AMDGPU::EXEC_LO : AMDGPU::EXEC;
5367-
unsigned CountReg =
5424+
unsigned BitCountOpc =
53685425
IsWave32 ? AMDGPU::S_BCNT1_I32_B32 : AMDGPU::S_BCNT1_I32_B64;
53695426

5370-
auto Exec =
5371-
BuildMI(BB, MI, DL, TII->get(MovOpc), ExecMask).addReg(ExecReg);
5427+
BuildMI(BB, MI, DL, TII->get(MovOpc), ExecMask).addReg(ExecReg);
53725428

5373-
auto NewAccumulator = BuildMI(BB, MI, DL, TII->get(CountReg), ActiveLanes)
5374-
.addReg(Exec->getOperand(0).getReg());
5429+
auto NewAccumulator =
5430+
BuildMI(BB, MI, DL, TII->get(BitCountOpc), NumActiveLanes)
5431+
.addReg(ExecMask);
53755432

53765433
switch (Opc) {
53775434
case AMDGPU::S_XOR_B32: {
53785435
// Performing an XOR operation on a uniform value
53795436
// depends on the parity of the number of active lanes.
53805437
// For even parity, the result will be 0, for odd
53815438
// parity the result will be the same as the input value.
5382-
Register ParityRegister = MRI.createVirtualRegister(DstRegClass);
5439+
Register ParityRegister =
5440+
MRI.createVirtualRegister(&AMDGPU::SReg_32RegClass);
53835441

5384-
auto ParityReg =
5385-
BuildMI(BB, MI, DL, TII->get(AMDGPU::S_AND_B32), ParityRegister)
5386-
.addReg(NewAccumulator->getOperand(0).getReg())
5387-
.addImm(1);
5442+
BuildMI(BB, MI, DL, TII->get(AMDGPU::S_AND_B32), ParityRegister)
5443+
.addReg(NewAccumulator->getOperand(0).getReg())
5444+
.addImm(1)
5445+
.setOperandDead(3); // Dead scc
53885446
BuildMI(BB, MI, DL, TII->get(AMDGPU::S_MUL_I32), DstReg)
53895447
.addReg(SrcReg)
5390-
.addReg(ParityReg->getOperand(0).getReg());
5448+
.addReg(ParityRegister);
53915449
break;
53925450
}
53935451
case AMDGPU::S_SUB_I32: {
53945452
Register NegatedVal = MRI.createVirtualRegister(DstRegClass);
53955453

53965454
// Take the negation of the source operand.
5397-
auto InvertedValReg =
5398-
BuildMI(BB, MI, DL, TII->get(AMDGPU::S_MUL_I32), NegatedVal)
5399-
.addImm(-1)
5400-
.addReg(SrcReg);
5455+
BuildMI(BB, MI, DL, TII->get(AMDGPU::S_SUB_I32), NegatedVal)
5456+
.addImm(0)
5457+
.addReg(SrcReg);
54015458
BuildMI(BB, MI, DL, TII->get(AMDGPU::S_MUL_I32), DstReg)
5402-
.addReg(InvertedValReg->getOperand(0).getReg())
5459+
.addReg(NegatedVal)
54035460
.addReg(NewAccumulator->getOperand(0).getReg());
54045461
break;
54055462
}
@@ -5409,6 +5466,75 @@ static MachineBasicBlock *lowerWaveReduce(MachineInstr &MI,
54095466
.addReg(NewAccumulator->getOperand(0).getReg());
54105467
break;
54115468
}
5469+
case AMDGPU::S_ADD_U64_PSEUDO:
5470+
case AMDGPU::S_SUB_U64_PSEUDO: {
5471+
Register DestSub0 = MRI.createVirtualRegister(&AMDGPU::SReg_32RegClass);
5472+
Register DestSub1 = MRI.createVirtualRegister(&AMDGPU::SReg_32RegClass);
5473+
Register Op1H_Op0L_Reg =
5474+
MRI.createVirtualRegister(&AMDGPU::SReg_32RegClass);
5475+
Register Op1L_Op0H_Reg =
5476+
MRI.createVirtualRegister(&AMDGPU::SReg_32RegClass);
5477+
Register CarryReg = MRI.createVirtualRegister(&AMDGPU::SReg_32RegClass);
5478+
Register AddReg = MRI.createVirtualRegister(&AMDGPU::SReg_32RegClass);
5479+
Register NegatedValLo =
5480+
MRI.createVirtualRegister(&AMDGPU::SReg_32RegClass);
5481+
Register NegatedValHi =
5482+
MRI.createVirtualRegister(&AMDGPU::SReg_32RegClass);
5483+
5484+
const TargetRegisterClass *Src1RC = MRI.getRegClass(SrcReg);
5485+
const TargetRegisterClass *Src1SubRC =
5486+
TRI->getSubRegisterClass(Src1RC, AMDGPU::sub0);
5487+
5488+
MachineOperand Op1L = TII->buildExtractSubRegOrImm(
5489+
MI, MRI, MI.getOperand(1), Src1RC, AMDGPU::sub0, Src1SubRC);
5490+
MachineOperand Op1H = TII->buildExtractSubRegOrImm(
5491+
MI, MRI, MI.getOperand(1), Src1RC, AMDGPU::sub1, Src1SubRC);
5492+
5493+
if (Opc == AMDGPU::S_SUB_U64_PSEUDO) {
5494+
BuildMI(BB, MI, DL, TII->get(AMDGPU::S_SUB_I32), NegatedValLo)
5495+
.addImm(0)
5496+
.addReg(NewAccumulator->getOperand(0).getReg())
5497+
.setOperandDead(3); // Dead scc
5498+
BuildMI(BB, MI, DL, TII->get(AMDGPU::S_ASHR_I32), NegatedValHi)
5499+
.addReg(NegatedValLo)
5500+
.addImm(31)
5501+
.setOperandDead(3); // Dead scc
5502+
BuildMI(BB, MI, DL, TII->get(AMDGPU::S_MUL_I32), Op1L_Op0H_Reg)
5503+
.add(Op1L)
5504+
.addReg(NegatedValHi);
5505+
}
5506+
Register LowOpcode = Opc == AMDGPU::S_SUB_U64_PSEUDO
5507+
? NegatedValLo
5508+
: NewAccumulator->getOperand(0).getReg();
5509+
BuildMI(BB, MI, DL, TII->get(AMDGPU::S_MUL_I32), DestSub0)
5510+
.add(Op1L)
5511+
.addReg(LowOpcode);
5512+
BuildMI(BB, MI, DL, TII->get(AMDGPU::S_MUL_HI_U32), CarryReg)
5513+
.add(Op1L)
5514+
.addReg(LowOpcode);
5515+
BuildMI(BB, MI, DL, TII->get(AMDGPU::S_MUL_I32), Op1H_Op0L_Reg)
5516+
.add(Op1H)
5517+
.addReg(LowOpcode);
5518+
5519+
Register HiVal = Opc == AMDGPU::S_SUB_U64_PSEUDO ? AddReg : DestSub1;
5520+
BuildMI(BB, MI, DL, TII->get(AMDGPU::S_ADD_U32), HiVal)
5521+
.addReg(CarryReg)
5522+
.addReg(Op1H_Op0L_Reg)
5523+
.setOperandDead(3); // Dead scc
5524+
5525+
if (Opc == AMDGPU::S_SUB_U64_PSEUDO) {
5526+
BuildMI(BB, MI, DL, TII->get(AMDGPU::S_ADD_U32), DestSub1)
5527+
.addReg(HiVal)
5528+
.addReg(Op1L_Op0H_Reg)
5529+
.setOperandDead(3); // Dead scc
5530+
}
5531+
BuildMI(BB, MI, DL, TII->get(TargetOpcode::REG_SEQUENCE), DstReg)
5532+
.addReg(DestSub0)
5533+
.addImm(AMDGPU::sub0)
5534+
.addReg(DestSub1)
5535+
.addImm(AMDGPU::sub1);
5536+
break;
5537+
}
54125538
}
54135539
RetBB = &BB;
54145540
}
@@ -5555,6 +5681,14 @@ static MachineBasicBlock *lowerWaveReduce(MachineInstr &MI,
55555681
.addReg(Accumulator->getOperand(0).getReg());
55565682
break;
55575683
}
5684+
case AMDGPU::S_ADD_U64_PSEUDO:
5685+
case AMDGPU::S_SUB_U64_PSEUDO: {
5686+
NewAccumulator = BuildMI(*ComputeLoop, I, DL, TII->get(Opc), DstReg)
5687+
.addReg(Accumulator->getOperand(0).getReg())
5688+
.addReg(LaneValue->getOperand(0).getReg());
5689+
ComputeLoop = Expand64BitScalarArithmetic(*NewAccumulator, ComputeLoop);
5690+
break;
5691+
}
55585692
}
55595693
}
55605694
// Manipulate the iterator to get the next active lane
@@ -5565,8 +5699,7 @@ static MachineBasicBlock *lowerWaveReduce(MachineInstr &MI,
55655699
.addReg(ActiveBitsReg);
55665700

55675701
// Add phi nodes
5568-
Accumulator.addReg(NewAccumulator->getOperand(0).getReg())
5569-
.addMBB(ComputeLoop);
5702+
Accumulator.addReg(DstReg).addMBB(ComputeLoop);
55705703
ActiveBits.addReg(NewActiveBitsReg).addMBB(ComputeLoop);
55715704

55725705
// Creating branching
@@ -5610,8 +5743,12 @@ SITargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
56105743
return lowerWaveReduce(MI, *BB, *getSubtarget(), AMDGPU::V_CMP_GT_I64_e64);
56115744
case AMDGPU::WAVE_REDUCE_ADD_PSEUDO_I32:
56125745
return lowerWaveReduce(MI, *BB, *getSubtarget(), AMDGPU::S_ADD_I32);
5746+
case AMDGPU::WAVE_REDUCE_ADD_PSEUDO_U64:
5747+
return lowerWaveReduce(MI, *BB, *getSubtarget(), AMDGPU::S_ADD_U64_PSEUDO);
56135748
case AMDGPU::WAVE_REDUCE_SUB_PSEUDO_I32:
56145749
return lowerWaveReduce(MI, *BB, *getSubtarget(), AMDGPU::S_SUB_I32);
5750+
case AMDGPU::WAVE_REDUCE_SUB_PSEUDO_U64:
5751+
return lowerWaveReduce(MI, *BB, *getSubtarget(), AMDGPU::S_SUB_U64_PSEUDO);
56155752
case AMDGPU::WAVE_REDUCE_AND_PSEUDO_B32:
56165753
return lowerWaveReduce(MI, *BB, *getSubtarget(), AMDGPU::S_AND_B32);
56175754
case AMDGPU::WAVE_REDUCE_OR_PSEUDO_B32:
@@ -5644,55 +5781,7 @@ SITargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
56445781
}
56455782
case AMDGPU::S_ADD_U64_PSEUDO:
56465783
case AMDGPU::S_SUB_U64_PSEUDO: {
5647-
// For targets older than GFX12, we emit a sequence of 32-bit operations.
5648-
// For GFX12, we emit s_add_u64 and s_sub_u64.
5649-
const GCNSubtarget &ST = MF->getSubtarget<GCNSubtarget>();
5650-
MachineRegisterInfo &MRI = BB->getParent()->getRegInfo();
5651-
const DebugLoc &DL = MI.getDebugLoc();
5652-
MachineOperand &Dest = MI.getOperand(0);
5653-
MachineOperand &Src0 = MI.getOperand(1);
5654-
MachineOperand &Src1 = MI.getOperand(2);
5655-
bool IsAdd = (MI.getOpcode() == AMDGPU::S_ADD_U64_PSEUDO);
5656-
if (Subtarget->hasScalarAddSub64()) {
5657-
unsigned Opc = IsAdd ? AMDGPU::S_ADD_U64 : AMDGPU::S_SUB_U64;
5658-
// clang-format off
5659-
BuildMI(*BB, MI, DL, TII->get(Opc), Dest.getReg())
5660-
.add(Src0)
5661-
.add(Src1);
5662-
// clang-format on
5663-
} else {
5664-
const SIRegisterInfo *TRI = ST.getRegisterInfo();
5665-
const TargetRegisterClass *BoolRC = TRI->getBoolRC();
5666-
5667-
Register DestSub0 = MRI.createVirtualRegister(&AMDGPU::SReg_32RegClass);
5668-
Register DestSub1 = MRI.createVirtualRegister(&AMDGPU::SReg_32RegClass);
5669-
5670-
MachineOperand Src0Sub0 = TII->buildExtractSubRegOrImm(
5671-
MI, MRI, Src0, BoolRC, AMDGPU::sub0, &AMDGPU::SReg_32RegClass);
5672-
MachineOperand Src0Sub1 = TII->buildExtractSubRegOrImm(
5673-
MI, MRI, Src0, BoolRC, AMDGPU::sub1, &AMDGPU::SReg_32RegClass);
5674-
5675-
MachineOperand Src1Sub0 = TII->buildExtractSubRegOrImm(
5676-
MI, MRI, Src1, BoolRC, AMDGPU::sub0, &AMDGPU::SReg_32RegClass);
5677-
MachineOperand Src1Sub1 = TII->buildExtractSubRegOrImm(
5678-
MI, MRI, Src1, BoolRC, AMDGPU::sub1, &AMDGPU::SReg_32RegClass);
5679-
5680-
unsigned LoOpc = IsAdd ? AMDGPU::S_ADD_U32 : AMDGPU::S_SUB_U32;
5681-
unsigned HiOpc = IsAdd ? AMDGPU::S_ADDC_U32 : AMDGPU::S_SUBB_U32;
5682-
BuildMI(*BB, MI, DL, TII->get(LoOpc), DestSub0)
5683-
.add(Src0Sub0)
5684-
.add(Src1Sub0);
5685-
BuildMI(*BB, MI, DL, TII->get(HiOpc), DestSub1)
5686-
.add(Src0Sub1)
5687-
.add(Src1Sub1);
5688-
BuildMI(*BB, MI, DL, TII->get(TargetOpcode::REG_SEQUENCE), Dest.getReg())
5689-
.addReg(DestSub0)
5690-
.addImm(AMDGPU::sub0)
5691-
.addReg(DestSub1)
5692-
.addImm(AMDGPU::sub1);
5693-
}
5694-
MI.eraseFromParent();
5695-
return BB;
5784+
return Expand64BitScalarArithmetic(MI, BB);
56965785
}
56975786
case AMDGPU::V_ADD_U64_PSEUDO:
56985787
case AMDGPU::V_SUB_U64_PSEUDO: {

llvm/lib/Target/AMDGPU/SIInstructions.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,8 @@ defvar Operations = [
367367
WaveReduceOp<"min", "I64", i64, SGPR_64, VSrc_b64>,
368368
WaveReduceOp<"umax", "U64", i64, SGPR_64, VSrc_b64>,
369369
WaveReduceOp<"max", "I64", i64, SGPR_64, VSrc_b64>,
370+
WaveReduceOp<"add", "U64", i64, SGPR_64, VSrc_b64>,
371+
WaveReduceOp<"sub", "U64", i64, SGPR_64, VSrc_b64>,
370372
];
371373

372374
foreach Op = Operations in {

0 commit comments

Comments
 (0)