Skip to content

Commit 9ae5b27

Browse files
committed
[LegalizeTypes][RISCV] Control sign-extend for atomicrmw input argument
Similar to the argument sign-extend control provided by getExtendForAtomicCmpSwapArg for `cmpxchg`, this change adds getExtendForAtomicRMWArg for `atomicrmw <op>`. This mechanism is used to correct the RISCV code generation when using atomic pseudo expansions that perform sub-word comparions inside a LR/SC block, and expect a correctly sign-extended input argument. A shorter instruction sequence is possible for umin/max by leveraging the fact that two sign-extended arguments still order correctly during an unsigned comparison.
1 parent b2436be commit 9ae5b27

File tree

8 files changed

+190
-156
lines changed

8 files changed

+190
-156
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2459,6 +2459,12 @@ class LLVM_ABI TargetLoweringBase {
24592459
return ISD::ANY_EXTEND;
24602460
}
24612461

2462+
/// Returns how the platform's atomic rmw operations expect their input
2463+
/// argument to be extended (ZERO_EXTEND, SIGN_EXTEND, or ANY_EXTEND).
2464+
virtual ISD::NodeType getExtendForAtomicRMWArg(unsigned Op) const {
2465+
return ISD::ANY_EXTEND;
2466+
}
2467+
24622468
/// @}
24632469

24642470
/// Returns true if we should normalize

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,20 @@ SDValue DAGTypeLegalizer::PromoteIntRes_Atomic0(AtomicSDNode *N) {
429429
}
430430

431431
SDValue DAGTypeLegalizer::PromoteIntRes_Atomic1(AtomicSDNode *N) {
432-
SDValue Op2 = GetPromotedInteger(N->getOperand(2));
432+
SDValue Op2 = N->getOperand(2);
433+
switch (TLI.getExtendForAtomicRMWArg(N->getOpcode())) {
434+
case ISD::SIGN_EXTEND:
435+
Op2 = SExtPromotedInteger(Op2);
436+
break;
437+
case ISD::ZERO_EXTEND:
438+
Op2 = ZExtPromotedInteger(Op2);
439+
break;
440+
case ISD::ANY_EXTEND:
441+
Op2 = GetPromotedInteger(Op2);
442+
break;
443+
default:
444+
llvm_unreachable("Invalid atomic op extension");
445+
}
433446
SDValue Res = DAG.getAtomic(N->getOpcode(), SDLoc(N),
434447
N->getMemoryVT(),
435448
N->getChain(), N->getBasePtr(),

llvm/lib/Target/RISCV/RISCVExpandAtomicPseudoInsts.cpp

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -521,13 +521,6 @@ static void insertSext(const RISCVInstrInfo *TII, DebugLoc DL,
521521
.addReg(ShamtReg);
522522
}
523523

524-
static void insertZext(const RISCVInstrInfo *TII, DebugLoc DL,
525-
MachineBasicBlock *MBB, Register ValReg, Register SrcReg,
526-
int64_t Shamt) {
527-
BuildMI(MBB, DL, TII->get(RISCV::SLLI), ValReg).addReg(SrcReg).addImm(Shamt);
528-
BuildMI(MBB, DL, TII->get(RISCV::SRLI), ValReg).addReg(ValReg).addImm(Shamt);
529-
}
530-
531524
static void doAtomicMinMaxOpExpansion(
532525
const RISCVInstrInfo *TII, MachineInstr &MI, DebugLoc DL,
533526
MachineBasicBlock *ThisMBB, MachineBasicBlock *LoopHeadMBB,
@@ -538,9 +531,6 @@ static void doAtomicMinMaxOpExpansion(
538531
Register ScratchReg = MI.getOperand(1).getReg();
539532
Register AddrReg = MI.getOperand(2).getReg();
540533
Register IncrReg = MI.getOperand(3).getReg();
541-
bool IsUnsigned =
542-
BinOp == AtomicRMWInst::UMin || BinOp == AtomicRMWInst::UMax;
543-
bool Zext = IsUnsigned && STI->is64Bit() && Width == 32;
544534
AtomicOrdering Ordering =
545535
static_cast<AtomicOrdering>(MI.getOperand(4).getImm());
546536

@@ -550,12 +540,9 @@ static void doAtomicMinMaxOpExpansion(
550540
// ifnochangeneeded scratch, incr, .looptail
551541
BuildMI(LoopHeadMBB, DL, TII->get(getLRForRMW(Ordering, Width, STI)), DestReg)
552542
.addReg(AddrReg);
553-
if (Zext)
554-
insertZext(TII, DL, LoopHeadMBB, ScratchReg, DestReg, 32);
555-
else
556-
BuildMI(LoopHeadMBB, DL, TII->get(RISCV::ADDI), ScratchReg)
557-
.addReg(DestReg)
558-
.addImm(0);
543+
BuildMI(LoopHeadMBB, DL, TII->get(RISCV::ADDI), ScratchReg)
544+
.addReg(DestReg)
545+
.addImm(0);
559546
switch (BinOp) {
560547
default:
561548
llvm_unreachable("Unexpected AtomicRMW BinOp");

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24485,6 +24485,25 @@ ISD::NodeType RISCVTargetLowering::getExtendForAtomicCmpSwapArg() const {
2448524485
return Subtarget.hasStdExtZacas() ? ISD::ANY_EXTEND : ISD::SIGN_EXTEND;
2448624486
}
2448724487

24488+
ISD::NodeType RISCVTargetLowering::getExtendForAtomicRMWArg(unsigned Op) const {
24489+
// Zaamo will use amo<op>.w which does not require extension.
24490+
if (Subtarget.hasStdExtZaamo() || Subtarget.hasForcedAtomics())
24491+
return ISD::ANY_EXTEND;
24492+
24493+
// Zalrsc pseudo expansions with comparison require sign-extension.
24494+
assert(Subtarget.hasStdExtZalrsc());
24495+
switch (Op) {
24496+
case ISD::ATOMIC_LOAD_MIN:
24497+
case ISD::ATOMIC_LOAD_MAX:
24498+
case ISD::ATOMIC_LOAD_UMIN:
24499+
case ISD::ATOMIC_LOAD_UMAX:
24500+
return ISD::SIGN_EXTEND;
24501+
default:
24502+
break;
24503+
}
24504+
return ISD::ANY_EXTEND;
24505+
}
24506+
2448824507
Register RISCVTargetLowering::getExceptionPointerRegister(
2448924508
const Constant *PersonalityFn) const {
2449024509
return RISCV::X10;

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ class RISCVTargetLowering : public TargetLowering {
245245
}
246246

247247
ISD::NodeType getExtendForAtomicCmpSwapArg() const override;
248+
ISD::NodeType getExtendForAtomicRMWArg(unsigned Op) const override;
248249

249250
bool shouldTransformSignedTruncationCheck(EVT XVT,
250251
unsigned KeptBits) const override;

llvm/lib/Target/RISCV/RISCVInstrInfoA.td

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -321,11 +321,9 @@ def PseudoAtomicLoadXor32 : PseudoAMO;
321321
let Size = 24 in {
322322
def PseudoAtomicLoadMax32 : PseudoAMO;
323323
def PseudoAtomicLoadMin32 : PseudoAMO;
324-
} // Size = 24
325-
let Size = 28 in {
326324
def PseudoAtomicLoadUMax32 : PseudoAMO;
327325
def PseudoAtomicLoadUMin32 : PseudoAMO;
328-
} // Size = 28
326+
} // Size = 24
329327

330328
defm : PseudoAMOPat<"atomic_swap_i32", PseudoAtomicSwap32>;
331329
defm : PseudoAMOPat<"atomic_load_add_i32", PseudoAtomicLoadAdd32>;

0 commit comments

Comments
 (0)