Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
b26f615
[AArch64] Add CodeGen support for FEAT_CPA
rgwott Aug 22, 2024
b99a0b9
Merge branch main into branch cpa
Jan 20, 2025
fb33270
[SelectionDAG] Refactor use of getMemBasePlusOffset() where applicable
rgwott Jan 20, 2025
6e491cd
[SelectionDAG] Inverted args for PTRADD on getMemBasePlusOffset()
rgwott Jan 22, 2025
eebe538
Fix minor lexical problems
rgwott Jan 22, 2025
76dae60
Satisfy undef deprecator
rgwott Jan 28, 2025
10651fa
Remove fold with unhandeable corner case
rgwott Jan 31, 2025
e64ef52
Add comment explaining removed fold for future work
rgwott Jan 31, 2025
d0b2362
Minor lexical fix
rgwott Jan 31, 2025
82d0550
Merge branch main into branch cpa
rgwott Feb 4, 2025
c70948f
Remove getMemBasePlusOffset inversion logic after #125279
rgwott Feb 4, 2025
4330997
Remove inadequate use of getMemBasePlusOffset()
rgwott Feb 4, 2025
0202b39
Remove obsolete AddedCompletixy in FEAT_CPA tablegen patterns
rgwott Feb 6, 2025
15f8b6b
Modify comment in PTRADD declaration
rgwott Feb 6, 2025
a61a19f
Gate FEAT_CPA CodeGen behind -mcpa-codegen flag
rgwott Feb 6, 2025
3a57e4d
Merge remote-tracking branch 'origin/main' into cpa
rgwott Jun 18, 2025
aec4f71
Autogenerate FEAT_CPA tests with utils/update_llc_test_checks.py
rgwott Jun 18, 2025
0935a57
Fix minor typo
rgwott Jun 18, 2025
df8a380
Fix botched merge from upstream
rgwott Jun 18, 2025
1a7982e
Fix remaining botched merge from upstream
rgwott Jun 18, 2025
d049ea5
Adapt FEAT_CPA SDAG/GISel to shouldPreservePtrArith in targetLowering
rgwott Jun 18, 2025
3e6ccee
Fix missing newline
rgwott Jun 18, 2025
b0b575b
Remove user-facing option -mcpa-codegen and use backend option instead
rgwott Jun 18, 2025
8d8cca7
Remove now-superfluous HasCPACodegen() function
rgwott Jun 18, 2025
ca45208
Remove -global-isel-abort=1 from GlobalISel tests
rgwott Jun 20, 2025
fcde60c
Insert printf() call in CPA tests to prevent case from optimizing away
rgwott Jun 20, 2025
029f1c3
Make (ADDPT|SUBPT) shift operand an i64 as opposed to i32
rgwott Jun 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions clang/include/clang/Driver/Options.td
Original file line number Diff line number Diff line change
Expand Up @@ -5025,6 +5025,11 @@ def msve_vector_bits_EQ : Joined<["-"], "msve-vector-bits=">, Group<m_aarch64_Fe
Visibility<[ClangOption, FlangOption]>,
HelpText<"Specify the size in bits of an SVE vector register. Defaults to the"
" vector length agnostic value of \"scalable\". (AArch64 only)">;

def mcpa_codegen : Flag<["-"], "mcpa-codegen">,
Visibility<[ClangOption]>,
Group<m_aarch64_Features_Group>,
HelpText<"Generate scalar FEAT_CPA instructions (AArch64 only)">;
} // let Flags = [TargetSpecific]

def mvscale_min_EQ : Joined<["-"], "mvscale-min=">,
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/Driver/ToolChains/Arch/AArch64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,11 @@ void aarch64::getAArch64TargetFeatures(const Driver &D,
Extensions.disable(llvm::AArch64::AEK_FP);
}

// -mcpa-codegen enables generation of scalar FEAT_CPA instructions
if (Args.getLastArg(options::OPT_mcpa_codegen)) {
Features.push_back("+cpa-codegen");
}

// En/disable crc
if (Arg *A = Args.getLastArg(options::OPT_mcrc, options::OPT_mnocrc)) {
if (A->getOption().matches(options::OPT_mcrc))
Expand Down
5 changes: 5 additions & 0 deletions llvm/include/llvm/CodeGen/ISDOpcodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1462,6 +1462,11 @@ enum NodeType {
// Outputs: [rv], output chain, glue
PATCHPOINT,

// PTRADD represents pointer arithmatic semantics, for targets that opt in
// using shouldPreservePtrArith().
// ptr = PTRADD ptr, offset
PTRADD,

// Vector Predication
#define BEGIN_REGISTER_VP_SDNODE(VPSDID, ...) VPSDID,
#include "llvm/IR/VPIntrinsics.def"
Expand Down
5 changes: 5 additions & 0 deletions llvm/include/llvm/Target/TargetMachine.h
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,11 @@ class TargetMachine {
return false;
}

/// True if target has some particular form of dealing with pointer arithmetic
/// semantics. False if pointer arithmetic should not be preserved for passes
/// such as instruction selection, and can fallback to regular arithmetic.
virtual bool shouldPreservePtrArith(const Function &F) const { return false; }

/// Create a pass configuration object to be used by addPassToEmitX methods
/// for generating a pipeline of CodeGen passes.
virtual TargetPassConfig *createPassConfig(PassManagerBase &PM) {
Expand Down
2 changes: 1 addition & 1 deletion llvm/include/llvm/Target/TargetSelectionDAG.td
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def tblockaddress: SDNode<"ISD::TargetBlockAddress", SDTPtrLeaf, [],

def add : SDNode<"ISD::ADD" , SDTIntBinOp ,
[SDNPCommutative, SDNPAssociative]>;
def ptradd : SDNode<"ISD::ADD" , SDTPtrAddOp, []>;
def ptradd : SDNode<"ISD::PTRADD" , SDTPtrAddOp, []>;
def sub : SDNode<"ISD::SUB" , SDTIntBinOp>;
def mul : SDNode<"ISD::MUL" , SDTIntBinOp,
[SDNPCommutative, SDNPAssociative]>;
Expand Down
103 changes: 100 additions & 3 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,9 @@ namespace {
SDValue visitMERGE_VALUES(SDNode *N);
SDValue visitADD(SDNode *N);
SDValue visitADDLike(SDNode *N);
SDValue visitADDLikeCommutative(SDValue N0, SDValue N1, SDNode *LocReference);
SDValue visitADDLikeCommutative(SDValue N0, SDValue N1,
SDNode *LocReference);
SDValue visitPTRADD(SDNode *N);
SDValue visitSUB(SDNode *N);
SDValue visitADDSAT(SDNode *N);
SDValue visitSUBSAT(SDNode *N);
Expand Down Expand Up @@ -1083,7 +1085,7 @@ bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
// (load/store (add, (add, x, y), offset2)) ->
// (load/store (add, (add, x, offset2), y)).

if (N0.getOpcode() != ISD::ADD)
if (N0.getOpcode() != ISD::ADD && N0.getOpcode() != ISD::PTRADD)
return false;

// Check for vscale addressing modes.
Expand Down Expand Up @@ -1840,6 +1842,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::TokenFactor: return visitTokenFactor(N);
case ISD::MERGE_VALUES: return visitMERGE_VALUES(N);
case ISD::ADD: return visitADD(N);
case ISD::PTRADD: return visitPTRADD(N);
case ISD::SUB: return visitSUB(N);
case ISD::SADDSAT:
case ISD::UADDSAT: return visitADDSAT(N);
Expand Down Expand Up @@ -2373,7 +2376,7 @@ static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG,
}

TargetLowering::AddrMode AM;
if (N->getOpcode() == ISD::ADD) {
if (N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::PTRADD) {
AM.HasBaseReg = true;
ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
if (Offset)
Expand Down Expand Up @@ -2602,6 +2605,100 @@ SDValue DAGCombiner::foldSubToAvg(SDNode *N, const SDLoc &DL) {
return SDValue();
}

/// Try to fold a pointer arithmetic node.
/// This needs to be done separately from normal addition, because pointer
/// addition is not commutative.
/// This function was adapted from DAGCombiner::visitPTRADD() from the Morello
/// project, which is based on CHERI.
SDValue DAGCombiner::visitPTRADD(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
EVT PtrVT = N0.getValueType();
EVT IntVT = N1.getValueType();
SDLoc DL(N);

// fold (ptradd undef, y) -> undef
if (N0.isUndef())
return N0;

// fold (ptradd x, undef) -> undef
if (N1.isUndef())
return DAG.getUNDEF(PtrVT);

// fold (ptradd x, 0) -> x
if (isNullConstant(N1))
return N0;

if (N0.getOpcode() == ISD::PTRADD &&
!reassociationCanBreakAddressingModePattern(ISD::PTRADD, DL, N, N0, N1)) {
SDValue X = N0.getOperand(0);
SDValue Y = N0.getOperand(1);
SDValue Z = N1;
bool N0OneUse = N0.hasOneUse();
bool YIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(Y);
bool ZIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(Z);
bool ZOneUse = Z.hasOneUse();

// (ptradd (ptradd x, y), z) -> (ptradd x, (add y, z)) if:
// * x is a null pointer; or
// * y is a constant and z has one use; or
// * y is a constant and (ptradd x, y) has one use; or
// * (ptradd x, y) and z have one use and z is not a constant.
if (isNullConstant(X) || (YIsConstant && ZOneUse) ||
(YIsConstant && N0OneUse) || (N0OneUse && ZOneUse && !ZIsConstant)) {
SDValue Add = DAG.getNode(ISD::ADD, DL, IntVT, {Y, Z});

// Calling visit() can replace the Add node with ISD::DELETED_NODE if
// there aren't any users, so keep a handle around whilst we visit it.
HandleSDNode ADDHandle(Add);

SDValue VisitedAdd = visit(Add.getNode());
if (VisitedAdd) {
// If visit() returns the same node, it means the SDNode was RAUW'd, and
// therefore we have to load the new value to perform the checks whether
// the reassociation fold is profitable.
if (VisitedAdd.getNode() == Add.getNode())
Add = ADDHandle.getValue();
else
Add = VisitedAdd;
}

return DAG.getMemBasePlusOffset(X, Add, DL, SDNodeFlags());
}

// TODO: There is another possible fold here that was proven useful.
// It would be this:
//
// (ptradd (ptradd x, y), z) -> (ptradd (ptradd x, z), y) if:
// * (ptradd x, y) has one use; and
// * y is a constant; and
// * z is not a constant.
//
// In some cases, specifically in AArch64's FEAT_CPA, it exposes the
// opportunity to select more complex instructions such as SUBPT and
// MSUBPT. However, a hypothetical corner case has been found that we could
// not avoid. Consider this (pseudo-POSIX C):
//
// char *foo(char *x, int z) {return (x + LARGE_CONSTANT) + z;}
// char *p = mmap(LARGE_CONSTANT);
// char *q = foo(p, -LARGE_CONSTANT);
//
// Then x + LARGE_CONSTANT is one-past-the-end, so valid, and a
// further + z takes it back to the start of the mapping, so valid,
// regardless of the address mmap gave back. However, if mmap gives you an
// address < LARGE_CONSTANT (ignoring high bits), x - LARGE_CONSTANT will
// borrow from the high bits (with the subsequent + z carrying back into
// the high bits to give you a well-defined pointer) and thus trip
// FEAT_CPA's pointer corruption checks.
//
// We leave this fold as an opportunity for future work, addressing the
// corner case for FEAT_CPA, as well as reconciling the solution with the
// more general application of pointer arithmetic in other future targets.
}

return SDValue();
}

/// Try to fold a 'not' shifted sign-bit with add/sub with constant operand into
/// a shift and add with a different constant.
static SDValue foldAddSubOfSignBit(SDNode *N, const SDLoc &DL,
Expand Down
10 changes: 8 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5667,7 +5667,8 @@ bool SelectionDAG::isADDLike(SDValue Op, bool NoWrap) const {

bool SelectionDAG::isBaseWithConstantOffset(SDValue Op) const {
return Op.getNumOperands() == 2 && isa<ConstantSDNode>(Op.getOperand(1)) &&
(Op.getOpcode() == ISD::ADD || isADDLike(Op));
(Op.getOpcode() == ISD::ADD || Op.getOpcode() == ISD::PTRADD ||
isADDLike(Op));
}

bool SelectionDAG::isKnownNeverNaN(SDValue Op, bool SNaN, unsigned Depth) const {
Expand Down Expand Up @@ -8071,7 +8072,12 @@ SDValue SelectionDAG::getMemBasePlusOffset(SDValue Ptr, SDValue Offset,
const SDNodeFlags Flags) {
assert(Offset.getValueType().isInteger());
EVT BasePtrVT = Ptr.getValueType();
return getNode(ISD::ADD, DL, BasePtrVT, Ptr, Offset, Flags);
if (!this->getTarget().shouldPreservePtrArith(
this->getMachineFunction().getFunction())) {
return getNode(ISD::ADD, DL, BasePtrVT, Ptr, Offset, Flags);
} else {
return getNode(ISD::PTRADD, DL, BasePtrVT, Ptr, Offset, Flags);
}
}

/// Returns true if memcpy source is constant data.
Expand Down
19 changes: 9 additions & 10 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4335,8 +4335,8 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
(int64_t(Offset) >= 0 && NW.hasNoUnsignedSignedWrap()))
Flags |= SDNodeFlags::NoUnsignedWrap;

N = DAG.getNode(ISD::ADD, dl, N.getValueType(), N,
DAG.getConstant(Offset, dl, N.getValueType()), Flags);
N = DAG.getMemBasePlusOffset(
N, DAG.getConstant(Offset, dl, N.getValueType()), dl, Flags);
}
} else {
// IdxSize is the width of the arithmetic according to IR semantics.
Expand Down Expand Up @@ -4380,7 +4380,7 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {

OffsVal = DAG.getSExtOrTrunc(OffsVal, dl, N.getValueType());

N = DAG.getNode(ISD::ADD, dl, N.getValueType(), N, OffsVal, Flags);
N = DAG.getMemBasePlusOffset(N, OffsVal, dl, Flags);
continue;
}

Expand Down Expand Up @@ -4440,7 +4440,7 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
SDNodeFlags AddFlags;
AddFlags.setNoUnsignedWrap(NW.hasNoUnsignedWrap());

N = DAG.getNode(ISD::ADD, dl, N.getValueType(), N, IdxN, AddFlags);
N = DAG.getMemBasePlusOffset(N, IdxN, dl, AddFlags);
}
}

Expand Down Expand Up @@ -9195,8 +9195,8 @@ bool SelectionDAGBuilder::visitMemPCpyCall(const CallInst &I) {
Size = DAG.getSExtOrTrunc(Size, sdl, Dst.getValueType());

// Adjust return pointer to point just past the last dst byte.
SDValue DstPlusSize = DAG.getNode(ISD::ADD, sdl, Dst.getValueType(),
Dst, Size);
SDNodeFlags Flags;
SDValue DstPlusSize = DAG.getMemBasePlusOffset(Dst, Size, sdl, Flags);
setValue(&I, DstPlusSize);
return true;
}
Expand Down Expand Up @@ -11293,10 +11293,9 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
MachineFunction &MF = CLI.DAG.getMachineFunction();
Align HiddenSRetAlign = MF.getFrameInfo().getObjectAlign(DemoteStackIdx);
for (unsigned i = 0; i < NumValues; ++i) {
SDValue Add =
CLI.DAG.getNode(ISD::ADD, CLI.DL, PtrVT, DemoteStackSlot,
CLI.DAG.getConstant(Offsets[i], CLI.DL, PtrVT),
SDNodeFlags::NoUnsignedWrap);
SDValue Add = CLI.DAG.getMemBasePlusOffset(
DemoteStackSlot, CLI.DAG.getConstant(Offsets[i], CLI.DL, PtrVT),
CLI.DL, SDNodeFlags::NoUnsignedWrap);
SDValue L = CLI.DAG.getLoad(
RetTys[i], CLI.DL, CLI.Chain, Add,
MachinePointerInfo::getFixedStack(CLI.DAG.getMachineFunction(),
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {

// Binary operators
case ISD::ADD: return "add";
case ISD::PTRADD: return "ptradd";
case ISD::SUB: return "sub";
case ISD::MUL: return "mul";
case ISD::MULHU: return "mulhu";
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/AArch64/AArch64Features.td
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,9 @@ def FeaturePAuthLR : ExtensionWithMArch<"pauth-lr", "PAuthLR", "FEAT_PAuth_LR",
def FeatureTLBIW : ExtensionWithMArch<"tlbiw", "TLBIW", "FEAT_TLBIW",
"Enable Armv9.5-A TLBI VMALL for Dirty State">;

def FeatureCPACodegen : SubtargetFeature<"cpa-codegen",
"HasCPACodegen", "true", "Generate scalar FEAT_CPA instructions">;

//===----------------------------------------------------------------------===//
// Armv9.6 Architecture Extensions
//===----------------------------------------------------------------------===//
Expand Down
18 changes: 18 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,9 @@ def HasGCS : Predicate<"Subtarget->hasGCS()">,
AssemblerPredicateWithAll<(all_of FeatureGCS), "gcs">;
def HasCPA : Predicate<"Subtarget->hasCPA()">,
AssemblerPredicateWithAll<(all_of FeatureCPA), "cpa">;
def HasCPACodegen : Predicate<"Subtarget->hasCPACodegen()">,
AssemblerPredicateWithAll<(all_of FeatureCPACodegen),
"cpa-codegen">;
def IsLE : Predicate<"Subtarget->isLittleEndian()">;
def IsBE : Predicate<"!Subtarget->isLittleEndian()">;
def IsWindows : Predicate<"Subtarget->isTargetWindows()">;
Expand Down Expand Up @@ -10397,6 +10400,21 @@ let Predicates = [HasCPA] in {
// Scalar multiply-add/subtract
def MADDPT : MulAccumCPA<0, "maddpt">;
def MSUBPT : MulAccumCPA<1, "msubpt">;

def : Pat<(ptradd GPR64sp:$Rn, GPR64sp:$Rm),
(ADDPT_shift GPR64sp:$Rn, GPR64sp:$Rm, (i32 0))>;
def : Pat<(ptradd GPR64sp:$Rn, (shl GPR64sp:$Rm, (i64 imm0_7:$imm))),
(ADDPT_shift GPR64sp:$Rn, GPR64sp:$Rm,
(i32 (trunc_imm imm0_7:$imm)))>;
def : Pat<(ptradd GPR64sp:$Rn, (ineg GPR64sp:$Rm)),
(SUBPT_shift GPR64sp:$Rn, GPR64sp:$Rm, (i32 0))>;
def : Pat<(ptradd GPR64sp:$Rn, (ineg (shl GPR64sp:$Rm, (i64 imm0_7:$imm)))),
(SUBPT_shift GPR64sp:$Rn, GPR64sp:$Rm,
(i32 (trunc_imm imm0_7:$imm)))>;
def : Pat<(ptradd GPR64:$Ra, (mul GPR64:$Rn, GPR64:$Rm)),
(MADDPT GPR64:$Rn, GPR64:$Rm, GPR64:$Ra)>;
def : Pat<(ptradd GPR64:$Ra, (mul GPR64:$Rn, (ineg GPR64:$Rm))),
(MSUBPT GPR64:$Rn, GPR64:$Rm, GPR64:$Ra)>;
}

def round_v4fp32_to_v4bf16 :
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -928,3 +928,7 @@ bool AArch64TargetMachine::parseMachineFunctionInfo(
MF.getInfo<AArch64FunctionInfo>()->initializeBaseYamlFields(YamlMFI);
return false;
}

bool AArch64TargetMachine::shouldPreservePtrArith(const Function &F) const {
return getSubtargetImpl(F)->hasCPA() && getSubtargetImpl(F)->hasCPACodegen();
}
4 changes: 4 additions & 0 deletions llvm/lib/Target/AArch64/AArch64TargetMachine.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ class AArch64TargetMachine : public CodeGenTargetMachineImpl {
return getPointerSize(SrcAS) == getPointerSize(DestAS);
}

/// In AArch64, true if FEAT_CPA is present. Allows pointer arithmetic
/// semantics to be preserved for instruction selection.
bool shouldPreservePtrArith(const Function &F) const override;

private:
bool isLittle;
};
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2197,6 +2197,10 @@ bool AArch64InstructionSelector::preISelLower(MachineInstr &I) {
return Changed;
}
case TargetOpcode::G_PTR_ADD:
// If Checked Pointer Arithmetic (FEAT_CPA) is present, preserve the pointer
// arithmetic semantics instead of falling back to regular arithmetic.
if (TM.shouldPreservePtrArith(MF.getFunction()))
return false;
return convertPtrAddToAdd(I, MRI);
case TargetOpcode::G_LOAD: {
// For scalar loads of pointers, we try to convert the dest type from p0
Expand Down
Loading
Loading