Skip to content

Commit 53d6db2

Browse files
committed
[LegalizeTypes][RISCV] Control sign-extend for atomicrmw input argument
Similar to the argument sign-extend control provided by getExtendForAtomicCmpSwapArg for `cmpxchg`, this change adds getExtendForAtomicRMWArg for `atomicrmw <op>`. This mechanism is used to correct the RISCV code generation when using atomic pseudo expansions that perform sub-word comparions inside a LR/SC block, and expect a correctly sign-extended input argument. A shorter instruction sequence is possible for umin/max by leveraging the fact that two sign-extended arguments still order correctly during an unsigned comparison.
1 parent 5c976f7 commit 53d6db2

File tree

8 files changed

+190
-160
lines changed

8 files changed

+190
-160
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2459,6 +2459,12 @@ class LLVM_ABI TargetLoweringBase {
24592459
return ISD::ANY_EXTEND;
24602460
}
24612461

2462+
/// Returns how the platform's atomic rmw operations expect their input
2463+
/// argument to be extended (ZERO_EXTEND, SIGN_EXTEND, or ANY_EXTEND).
2464+
virtual ISD::NodeType getExtendForAtomicRMWArg(unsigned Op) const {
2465+
return ISD::ANY_EXTEND;
2466+
}
2467+
24622468
/// @}
24632469

24642470
/// Returns true if we should normalize

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,20 @@ SDValue DAGTypeLegalizer::PromoteIntRes_Atomic0(AtomicSDNode *N) {
429429
}
430430

431431
SDValue DAGTypeLegalizer::PromoteIntRes_Atomic1(AtomicSDNode *N) {
432-
SDValue Op2 = GetPromotedInteger(N->getOperand(2));
432+
SDValue Op2 = N->getOperand(2);
433+
switch(TLI.getExtendForAtomicRMWArg(N->getOpcode())) {
434+
case ISD::SIGN_EXTEND:
435+
Op2 = SExtPromotedInteger(Op2);
436+
break;
437+
case ISD::ZERO_EXTEND:
438+
Op2 = ZExtPromotedInteger(Op2);
439+
break;
440+
case ISD::ANY_EXTEND:
441+
Op2 = GetPromotedInteger(Op2);
442+
break;
443+
default:
444+
llvm_unreachable("Invalid atomic op extension");
445+
}
433446
SDValue Res = DAG.getAtomic(N->getOpcode(), SDLoc(N),
434447
N->getMemoryVT(),
435448
N->getChain(), N->getBasePtr(),

llvm/lib/Target/RISCV/RISCVExpandAtomicPseudoInsts.cpp

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -523,17 +523,6 @@ static void insertSext(const RISCVInstrInfo *TII, DebugLoc DL,
523523
.addReg(ShamtReg);
524524
}
525525

526-
static void insertZext(const RISCVInstrInfo *TII, DebugLoc DL,
527-
MachineBasicBlock *MBB, Register ValReg,
528-
Register SrcReg, int64_t Shamt) {
529-
BuildMI(MBB, DL, TII->get(RISCV::SLLI), ValReg)
530-
.addReg(SrcReg)
531-
.addImm(Shamt);
532-
BuildMI(MBB, DL, TII->get(RISCV::SRLI), ValReg)
533-
.addReg(ValReg)
534-
.addImm(Shamt);
535-
}
536-
537526
static void doAtomicMinMaxOpExpansion(const RISCVInstrInfo *TII, MachineInstr &MI,
538527
DebugLoc DL, MachineBasicBlock *ThisMBB,
539528
MachineBasicBlock *LoopHeadMBB,
@@ -546,9 +535,6 @@ static void doAtomicMinMaxOpExpansion(const RISCVInstrInfo *TII, MachineInstr &M
546535
Register ScratchReg = MI.getOperand(1).getReg();
547536
Register AddrReg = MI.getOperand(2).getReg();
548537
Register IncrReg = MI.getOperand(3).getReg();
549-
bool IsUnsigned = BinOp == AtomicRMWInst::UMin ||
550-
BinOp == AtomicRMWInst::UMax;
551-
bool Zext = IsUnsigned && STI->is64Bit() && Width == 32;
552538
AtomicOrdering Ordering =
553539
static_cast<AtomicOrdering>(MI.getOperand(4).getImm());
554540

@@ -558,12 +544,9 @@ static void doAtomicMinMaxOpExpansion(const RISCVInstrInfo *TII, MachineInstr &M
558544
// ifnochangeneeded scratch, incr, .looptail
559545
BuildMI(LoopHeadMBB, DL, TII->get(getLRForRMW(Ordering, Width, STI)), DestReg)
560546
.addReg(AddrReg);
561-
if (Zext)
562-
insertZext(TII, DL, LoopHeadMBB, ScratchReg, DestReg, 32);
563-
else
564-
BuildMI(LoopHeadMBB, DL, TII->get(RISCV::ADDI), ScratchReg)
565-
.addReg(DestReg)
566-
.addImm(0);
547+
BuildMI(LoopHeadMBB, DL, TII->get(RISCV::ADDI), ScratchReg)
548+
.addReg(DestReg)
549+
.addImm(0);
567550
switch (BinOp) {
568551
default:
569552
llvm_unreachable("Unexpected AtomicRMW BinOp");

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24483,6 +24483,25 @@ ISD::NodeType RISCVTargetLowering::getExtendForAtomicCmpSwapArg() const {
2448324483
return Subtarget.hasStdExtZacas() ? ISD::ANY_EXTEND : ISD::SIGN_EXTEND;
2448424484
}
2448524485

24486+
ISD::NodeType RISCVTargetLowering::getExtendForAtomicRMWArg(unsigned Op) const {
24487+
// Zaamo will use amo<op>.w which does not require extension.
24488+
if (Subtarget.hasStdExtZaamo() || Subtarget.hasForcedAtomics())
24489+
return ISD::ANY_EXTEND;
24490+
24491+
// Zalrsc pseudo expansions with comparison require sign-extension.
24492+
assert(Subtarget.hasStdExtZalrsc());
24493+
switch (Op) {
24494+
case ISD::ATOMIC_LOAD_MIN:
24495+
case ISD::ATOMIC_LOAD_MAX:
24496+
case ISD::ATOMIC_LOAD_UMIN:
24497+
case ISD::ATOMIC_LOAD_UMAX:
24498+
return ISD::SIGN_EXTEND;
24499+
default:
24500+
break;
24501+
}
24502+
return ISD::ANY_EXTEND;
24503+
}
24504+
2448624505
Register RISCVTargetLowering::getExceptionPointerRegister(
2448724506
const Constant *PersonalityFn) const {
2448824507
return RISCV::X10;

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ class RISCVTargetLowering : public TargetLowering {
245245
}
246246

247247
ISD::NodeType getExtendForAtomicCmpSwapArg() const override;
248+
ISD::NodeType getExtendForAtomicRMWArg(unsigned Op) const override;
248249

249250
bool shouldTransformSignedTruncationCheck(EVT XVT,
250251
unsigned KeptBits) const override;

llvm/lib/Target/RISCV/RISCVInstrInfoA.td

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -321,11 +321,9 @@ def PseudoAtomicLoadXor32 : PseudoAMO;
321321
let Size = 24 in {
322322
def PseudoAtomicLoadMax32 : PseudoAMO;
323323
def PseudoAtomicLoadMin32 : PseudoAMO;
324-
} // Size = 24
325-
let Size = 28 in {
326324
def PseudoAtomicLoadUMax32 : PseudoAMO;
327325
def PseudoAtomicLoadUMin32 : PseudoAMO;
328-
} // Size = 28
326+
} // Size = 24
329327

330328
defm : PseudoAMOPat<"atomic_swap_i32", PseudoAtomicSwap32>;
331329
defm : PseudoAMOPat<"atomic_load_add_i32", PseudoAtomicLoadAdd32>;

0 commit comments

Comments
 (0)