Skip to content

Commit 289c3e4

Browse files
committed
[LegalizeTypes][RISCV] Control sign-extend for atomicrmw input argument
Similar to the argument sign-extend control provided by getExtendForAtomicCmpSwapArg for `cmpxchg`, this change adds getExtendForAtomicRMWArg for `atomicrmw <op>`. This mechanism is used to correct the RISCV code generation when using atomic pseudo expansions that perform sub-word comparions inside a LR/SC block, and expect a correctly sign-extended input argument.
1 parent 1c06bd8 commit 289c3e4

File tree

6 files changed

+197
-121
lines changed

6 files changed

+197
-121
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2459,6 +2459,12 @@ class LLVM_ABI TargetLoweringBase {
24592459
return ISD::ANY_EXTEND;
24602460
}
24612461

2462+
/// Returns how the platform's atomic rmw operations expect their input
2463+
/// argument to be extended (ZERO_EXTEND, SIGN_EXTEND, or ANY_EXTEND).
2464+
virtual ISD::NodeType getExtendForAtomicRMWArg(unsigned Op) const {
2465+
return ISD::ANY_EXTEND;
2466+
}
2467+
24622468
/// @}
24632469

24642470
/// Returns true if we should normalize

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,20 @@ SDValue DAGTypeLegalizer::PromoteIntRes_Atomic0(AtomicSDNode *N) {
429429
}
430430

431431
SDValue DAGTypeLegalizer::PromoteIntRes_Atomic1(AtomicSDNode *N) {
432-
SDValue Op2 = GetPromotedInteger(N->getOperand(2));
432+
SDValue Op2 = N->getOperand(2);
433+
switch(TLI.getExtendForAtomicRMWArg(N->getOpcode())) {
434+
case ISD::SIGN_EXTEND:
435+
Op2 = SExtPromotedInteger(Op2);
436+
break;
437+
case ISD::ZERO_EXTEND:
438+
Op2 = ZExtPromotedInteger(Op2);
439+
break;
440+
case ISD::ANY_EXTEND:
441+
Op2 = GetPromotedInteger(Op2);
442+
break;
443+
default:
444+
llvm_unreachable("Invalid atomic op extension");
445+
}
433446
SDValue Res = DAG.getAtomic(N->getOpcode(), SDLoc(N),
434447
N->getMemoryVT(),
435448
N->getChain(), N->getBasePtr(),

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24483,6 +24483,26 @@ ISD::NodeType RISCVTargetLowering::getExtendForAtomicCmpSwapArg() const {
2448324483
return Subtarget.hasStdExtZacas() ? ISD::ANY_EXTEND : ISD::SIGN_EXTEND;
2448424484
}
2448524485

24486+
ISD::NodeType RISCVTargetLowering::getExtendForAtomicRMWArg(unsigned Op) const {
24487+
// Zaamo will use amo<op>.w which does not require extension.
24488+
if (Subtarget.hasStdExtZaamo())
24489+
return ISD::ANY_EXTEND;
24490+
24491+
// Zalasr pseudo expansions with comparison
24492+
assert(Subtarget.hasStdExtZalrsc());
24493+
switch (Op) {
24494+
case ISD::ATOMIC_LOAD_MIN:
24495+
case ISD::ATOMIC_LOAD_MAX:
24496+
return ISD::SIGN_EXTEND;
24497+
case ISD::ATOMIC_LOAD_UMIN:
24498+
case ISD::ATOMIC_LOAD_UMAX:
24499+
return ISD::ZERO_EXTEND;
24500+
default:
24501+
break;
24502+
}
24503+
return ISD::ANY_EXTEND;
24504+
}
24505+
2448624506
Register RISCVTargetLowering::getExceptionPointerRegister(
2448724507
const Constant *PersonalityFn) const {
2448824508
return RISCV::X10;

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ class RISCVTargetLowering : public TargetLowering {
245245
}
246246

247247
ISD::NodeType getExtendForAtomicCmpSwapArg() const override;
248+
ISD::NodeType getExtendForAtomicRMWArg(unsigned Op) const override;
248249

249250
bool shouldTransformSignedTruncationCheck(EVT XVT,
250251
unsigned KeptBits) const override;

0 commit comments

Comments
 (0)