From 4032fc101ba58d577c160d74296f9d01b0bdcc29 Mon Sep 17 00:00:00 2001 From: Alex Maclean Date: Thu, 24 Apr 2025 16:03:03 +0000 Subject: [PATCH 1/3] [NVPTX][NFC] Refactor and cleanup NVPTXISelLowering --- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 186 ++++++++------------ 1 file changed, 73 insertions(+), 113 deletions(-) diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 8dd9bf2876927..b0bffb1345da3 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -493,13 +493,6 @@ VectorizePTXValueVTs(const SmallVectorImpl &ValueVTs, return VectorInfo; } -static SDValue MaybeBitcast(SelectionDAG &DAG, SDLoc DL, EVT VT, - SDValue Value) { - if (Value->getValueType(0) == VT) - return Value; - return DAG.getNode(ISD::BITCAST, DL, VT, Value); -} - // NVPTXTargetLowering Constructor. NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, const NVPTXSubtarget &STI) @@ -1587,9 +1580,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg); SmallVector StoreOperands; - for (unsigned j = 0, je = VTs.size(); j != je; ++j) { - EVT EltVT = VTs[j]; - int CurOffset = Offsets[j]; + for (const unsigned J : llvm::seq(VTs.size())) { + EVT EltVT = VTs[J]; + const int CurOffset = Offsets[J]; MaybeAlign PartAlign; if (NeedAlign) PartAlign = commonAlignment(ArgAlign, CurOffset); @@ -1629,7 +1622,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a // scalar store. In such cases, fall back to byte stores. - if (VectorInfo[j] == PVF_SCALAR && !IsVAArg && PartAlign.has_value() && + if (VectorInfo[J] == PVF_SCALAR && !IsVAArg && PartAlign.has_value() && PartAlign.value() < DL.getABITypeAlign(EltVT.getTypeForEVT(*DAG.getContext()))) { assert(StoreOperands.empty() && "Unfinished preceeding store."); @@ -1645,7 +1638,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, } // New store. - if (VectorInfo[j] & PVF_FIRST) { + if (VectorInfo[J] & PVF_FIRST) { assert(StoreOperands.empty() && "Unfinished preceding store."); StoreOperands.push_back(Chain); StoreOperands.push_back( @@ -1665,8 +1658,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // Record the value to store. StoreOperands.push_back(StVal); - if (VectorInfo[j] & PVF_LAST) { - unsigned NumElts = StoreOperands.size() - 3; + if (VectorInfo[J] & PVF_LAST) { + const unsigned NumElts = StoreOperands.size() - 3; NVPTXISD::NodeType Op; switch (NumElts) { case 1: @@ -2168,7 +2161,7 @@ SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const { ISD::OR, DL, MVT::i16, {Extend0, DAG.getNode(ISD::SHL, DL, MVT::i16, {Extend1, Const8})}); EVT ToVT = Op->getValueType(0); - return MaybeBitcast(DAG, DL, ToVT, AsInt); + return DAG.getBitcast(ToVT, AsInt); } // We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it @@ -3367,18 +3360,10 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( auto PtrVT = getPointerTy(DAG.getDataLayout()); const Function *F = &MF.getFunction(); - const AttributeList &PAL = F->getAttributes(); - const TargetLowering *TLI = STI.getTargetLowering(); SDValue Root = DAG.getRoot(); - std::vector OutChains; + SmallVector OutChains; - std::vector argTypes; - std::vector theArgs; - for (const Argument &I : F->args()) { - theArgs.push_back(&I); - argTypes.push_back(I.getType()); - } // argTypes.size() (or theArgs.size()) and Ins.size() need not match. // Ins.size() will be larger // * if there is an aggregate argument with multiple fields (each field @@ -3388,49 +3373,55 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( // individually present in Ins. // So a different index should be used for indexing into Ins. // See similar issue in LowerCall. - unsigned InsIdx = 0; + const auto *In = Ins.begin(); - for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++InsIdx) { - Type *Ty = argTypes[i]; + for (const auto &Arg : F->args()) { + Type *Ty = Arg.getType(); - if (theArgs[i]->use_empty()) { - // argument is dead - if (shouldPassAsArray(Ty) && !Ty->isVectorTy()) { - SmallVector vtparts; + if (In == Ins.end() || In->OrigArgIndex != Arg.getArgNo()) + report_fatal_error("Empty parameter types are not supported"); - ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts); - if (vtparts.empty()) - report_fatal_error("Empty parameter types are not supported"); - - for (unsigned parti = 0, parte = vtparts.size(); parti != parte; - ++parti) { - InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT)); - ++InsIdx; - } - if (vtparts.size() > 0) - --InsIdx; - continue; - } - if (Ty->isVectorTy()) { - EVT ObjectVT = getValueType(DL, Ty); - unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT); - for (unsigned parti = 0; parti < NumRegs; ++parti) { - InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT)); - ++InsIdx; - } - if (NumRegs > 0) - --InsIdx; - continue; + if (Arg.use_empty()) { + // argument is dead + for (; In != Ins.end() && In->OrigArgIndex == Arg.getArgNo(); ++In) { + assert(!In->Used && "Arg.use_empty() is true but Arg is used?"); + InVals.push_back(DAG.getUNDEF(In->VT)); } - InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT)); continue; } + SDValue ArgSymbol = getParamSymbol(DAG, Arg.getArgNo(), PtrVT); + // In the following cases, assign a node order of "i+1" // to newly created nodes. The SDNodes for params have to // appear in the same order as their order of appearance // in the original function. "i+1" holds that order. - if (!PAL.hasParamAttr(i, Attribute::ByVal)) { + if (Arg.hasByValAttr()) { + // Param has ByVal attribute + // Return MoveParam(param symbol). + // Ideally, the param symbol can be returned directly, + // but when SDNode builder decides to use it in a CopyToReg(), + // machine instruction fails because TargetExternalSymbol + // (not lowered) is target dependent, and CopyToReg assumes + // the source is lowered. + assert(getValueType(DL, Ty) == In->VT && + "Ins type did not match function type"); + assert(In->VT == PtrVT && "ByVal argument must be a pointer"); + + SDValue P; + if (isKernelFunction(*F)) { + P = DAG.getNode(NVPTXISD::Wrapper, dl, In->VT, ArgSymbol); + P.getNode()->setIROrder(Arg.getArgNo() + 1); + } else { + P = DAG.getNode(NVPTXISD::MoveParam, dl, In->VT, ArgSymbol); + P.getNode()->setIROrder(Arg.getArgNo() + 1); + P = DAG.getAddrSpaceCast(dl, In->VT, P, ADDRESS_SPACE_LOCAL, + ADDRESS_SPACE_GENERIC); + } + In++; + InVals.push_back(P); + + } else { bool aggregateIsPacked = false; if (StructType *STy = dyn_cast(Ty)) aggregateIsPacked = STy->isPacked(); @@ -3442,21 +3433,20 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( report_fatal_error("Empty parameter types are not supported"); Align ArgAlign = getFunctionArgumentAlignment( - F, Ty, i + AttributeList::FirstArgIndex, DL); + F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL); auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign); - SDValue Arg = getParamSymbol(DAG, i, PtrVT); int VecIdx = -1; // Index of the first element of the current vector. - for (unsigned parti = 0, parte = VTs.size(); parti != parte; ++parti) { - if (VectorInfo[parti] & PVF_FIRST) { + for (const unsigned PartI : llvm::seq(VTs.size())) { + if (VectorInfo[PartI] & PVF_FIRST) { assert(VecIdx == -1 && "Orphaned vector."); - VecIdx = parti; + VecIdx = PartI; } // That's the last element of this store op. - if (VectorInfo[parti] & PVF_LAST) { - unsigned NumElts = parti - VecIdx + 1; - EVT EltVT = VTs[parti]; + if (VectorInfo[PartI] & PVF_LAST) { + const unsigned NumElts = PartI - VecIdx + 1; + EVT EltVT = VTs[PartI]; // i1 is loaded/stored as i8. EVT LoadVT = EltVT; if (EltVT == MVT::i1) @@ -3469,10 +3459,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts); SDValue VecAddr = - DAG.getNode(ISD::ADD, dl, PtrVT, Arg, + DAG.getNode(ISD::ADD, dl, PtrVT, ArgSymbol, DAG.getConstant(Offsets[VecIdx], dl, PtrVT)); - Value *srcValue = Constant::getNullValue( - PointerType::get(F->getContext(), ADDRESS_SPACE_PARAM)); const MaybeAlign PartAlign = [&]() -> MaybeAlign { if (aggregateIsPacked) @@ -3481,23 +3469,23 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( return std::nullopt; Align PartAlign = DL.getABITypeAlign(EltVT.getTypeForEVT(F->getContext())); - return commonAlignment(PartAlign, Offsets[parti]); + return commonAlignment(PartAlign, Offsets[PartI]); }(); - SDValue P = DAG.getLoad(VecVT, dl, Root, VecAddr, - MachinePointerInfo(srcValue), PartAlign, - MachineMemOperand::MODereferenceable | - MachineMemOperand::MOInvariant); + SDValue P = + DAG.getLoad(VecVT, dl, Root, VecAddr, + MachinePointerInfo(ADDRESS_SPACE_PARAM), PartAlign, + MachineMemOperand::MODereferenceable | + MachineMemOperand::MOInvariant); if (P.getNode()) - P.getNode()->setIROrder(i + 1); - for (unsigned j = 0; j < NumElts; ++j) { + P.getNode()->setIROrder(Arg.getArgNo() + 1); + for (const unsigned J : llvm::seq(NumElts)) { SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P, - DAG.getIntPtrConstant(j, dl)); + DAG.getIntPtrConstant(J, dl)); // We've loaded i1 as an i8 and now must truncate it back to i1 if (EltVT == MVT::i1) Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt); // v2f16 was loaded as an i32. Now we must bitcast it back. - else if (EltVT != LoadVT) - Elt = DAG.getNode(ISD::BITCAST, dl, EltVT, Elt); + Elt = DAG.getBitcast(EltVT, Elt); // If a promoted integer type is used, truncate down to the original MVT PromotedVT; @@ -3507,12 +3495,11 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( // Extend the element if necessary (e.g. an i8 is loaded // into an i16 register) - if (Ins[InsIdx].VT.isInteger() && - Ins[InsIdx].VT.getFixedSizeInBits() > - LoadVT.getFixedSizeInBits()) { - unsigned Extend = Ins[InsIdx].Flags.isSExt() ? ISD::SIGN_EXTEND - : ISD::ZERO_EXTEND; - Elt = DAG.getNode(Extend, dl, Ins[InsIdx].VT, Elt); + if (In->VT.isInteger() && + In->VT.getFixedSizeInBits() > LoadVT.getFixedSizeInBits()) { + unsigned Extend = + In->Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; + Elt = DAG.getNode(Extend, dl, In->VT, Elt); } InVals.push_back(Elt); } @@ -3520,40 +3507,13 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( // Reset vector tracking state. VecIdx = -1; } - ++InsIdx; + ++In; } - if (VTs.size() > 0) - --InsIdx; - continue; - } - - // Param has ByVal attribute - // Return MoveParam(param symbol). - // Ideally, the param symbol can be returned directly, - // but when SDNode builder decides to use it in a CopyToReg(), - // machine instruction fails because TargetExternalSymbol - // (not lowered) is target dependent, and CopyToReg assumes - // the source is lowered. - EVT ObjectVT = getValueType(DL, Ty); - assert(ObjectVT == Ins[InsIdx].VT && - "Ins type did not match function type"); - SDValue Arg = getParamSymbol(DAG, i, PtrVT); - - SDValue P; - if (isKernelFunction(*F)) { - P = DAG.getNode(NVPTXISD::Wrapper, dl, ObjectVT, Arg); - P.getNode()->setIROrder(i + 1); - } else { - P = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg); - P.getNode()->setIROrder(i + 1); - P = DAG.getAddrSpaceCast(dl, ObjectVT, P, ADDRESS_SPACE_LOCAL, - ADDRESS_SPACE_GENERIC); } - InVals.push_back(P); } if (!OutChains.empty()) - DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains)); + DAG.setRoot(DAG.getTokenFactor(dl, OutChains)); return Chain; } @@ -5784,7 +5744,7 @@ static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG, // Bitcast to i16 and unpack elements into a vector SDLoc DL(Node); - SDValue AsInt = MaybeBitcast(DAG, DL, MVT::i16, Op->getOperand(0)); + SDValue AsInt = DAG.getBitcast(MVT::i16, Op->getOperand(0)); SDValue Vec0 = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, AsInt); SDValue Const8 = DAG.getConstant(8, DL, MVT::i16); SDValue Vec1 = From 5c6025dfb789d4c2059e53e8f7527878fbea004b Mon Sep 17 00:00:00 2001 From: Alex Maclean Date: Thu, 24 Apr 2025 22:35:34 +0000 Subject: [PATCH 2/3] more cleanup --- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 49 +++++++++++++-------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index b0bffb1345da3..0854c6e4bb366 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -3374,18 +3374,28 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( // So a different index should be used for indexing into Ins. // See similar issue in LowerCall. const auto *In = Ins.begin(); + auto ConsumeArgIns = [&](const Argument &Arg) { + const auto *ArgInsBegin = In; + const auto *ArgInsEnd = In; + while (ArgInsEnd != Ins.end() && ArgInsEnd->OrigArgIndex == Arg.getArgNo()) + ++ArgInsEnd; + In = ArgInsEnd; + return llvm::ArrayRef(ArgInsBegin, ArgInsEnd); + }; for (const auto &Arg : F->args()) { + const auto ArgIns = ConsumeArgIns(Arg); + Type *Ty = Arg.getType(); - if (In == Ins.end() || In->OrigArgIndex != Arg.getArgNo()) + if (ArgIns.empty()) report_fatal_error("Empty parameter types are not supported"); if (Arg.use_empty()) { // argument is dead - for (; In != Ins.end() && In->OrigArgIndex == Arg.getArgNo(); ++In) { - assert(!In->Used && "Arg.use_empty() is true but Arg is used?"); - InVals.push_back(DAG.getUNDEF(In->VT)); + for (const auto &In : ArgIns) { + assert(!In.Used && "Arg.use_empty() is true but Arg is used?"); + InVals.push_back(DAG.getUNDEF(In.VT)); } continue; } @@ -3404,23 +3414,23 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( // machine instruction fails because TargetExternalSymbol // (not lowered) is target dependent, and CopyToReg assumes // the source is lowered. - assert(getValueType(DL, Ty) == In->VT && + assert(ArgIns.size() == 1 && "ByVal argument must be a pointer"); + const auto &ByvalIn = ArgIns[0]; + assert(getValueType(DL, Ty) == ByvalIn.VT && "Ins type did not match function type"); - assert(In->VT == PtrVT && "ByVal argument must be a pointer"); + assert(ByvalIn.VT == PtrVT && "ByVal argument must be a pointer"); SDValue P; if (isKernelFunction(*F)) { - P = DAG.getNode(NVPTXISD::Wrapper, dl, In->VT, ArgSymbol); + P = DAG.getNode(NVPTXISD::Wrapper, dl, ByvalIn.VT, ArgSymbol); P.getNode()->setIROrder(Arg.getArgNo() + 1); } else { - P = DAG.getNode(NVPTXISD::MoveParam, dl, In->VT, ArgSymbol); + P = DAG.getNode(NVPTXISD::MoveParam, dl, ByvalIn.VT, ArgSymbol); P.getNode()->setIROrder(Arg.getArgNo() + 1); - P = DAG.getAddrSpaceCast(dl, In->VT, P, ADDRESS_SPACE_LOCAL, + P = DAG.getAddrSpaceCast(dl, ByvalIn.VT, P, ADDRESS_SPACE_LOCAL, ADDRESS_SPACE_GENERIC); } - In++; InVals.push_back(P); - } else { bool aggregateIsPacked = false; if (StructType *STy = dyn_cast(Ty)) @@ -3429,12 +3439,13 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( SmallVector VTs; SmallVector Offsets; ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0); - if (VTs.empty()) - report_fatal_error("Empty parameter types are not supported"); + assert(VTs.size() == ArgIns.size() && "Size mismatch"); + assert(VTs.size() == Offsets.size() && "Size mismatch"); Align ArgAlign = getFunctionArgumentAlignment( F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL); auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign); + assert(VectorInfo.size() == VTs.size() && "Size mismatch"); int VecIdx = -1; // Index of the first element of the current vector. for (const unsigned PartI : llvm::seq(VTs.size())) { @@ -3495,11 +3506,12 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( // Extend the element if necessary (e.g. an i8 is loaded // into an i16 register) - if (In->VT.isInteger() && - In->VT.getFixedSizeInBits() > LoadVT.getFixedSizeInBits()) { - unsigned Extend = - In->Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; - Elt = DAG.getNode(Extend, dl, In->VT, Elt); + if (ArgIns[PartI].VT.getFixedSizeInBits() != + LoadVT.getFixedSizeInBits()) { + assert(ArgIns[PartI].VT.isInteger() && LoadVT.isInteger() && + "Non-integer argument type size mismatch"); + Elt = DAG.getExtOrTrunc(ArgIns[PartI].Flags.isSExt(), Elt, dl, + ArgIns[PartI].VT); } InVals.push_back(Elt); } @@ -3507,7 +3519,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( // Reset vector tracking state. VecIdx = -1; } - ++In; } } } From 6a52cd5b7401de506a625a2faa474df7eacd8dbd Mon Sep 17 00:00:00 2001 From: Alex Maclean Date: Thu, 24 Apr 2025 23:12:36 +0000 Subject: [PATCH 3/3] more cleanup --- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 0854c6e4bb366..c41741ed10232 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -3373,18 +3373,12 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( // individually present in Ins. // So a different index should be used for indexing into Ins. // See similar issue in LowerCall. - const auto *In = Ins.begin(); - auto ConsumeArgIns = [&](const Argument &Arg) { - const auto *ArgInsBegin = In; - const auto *ArgInsEnd = In; - while (ArgInsEnd != Ins.end() && ArgInsEnd->OrigArgIndex == Arg.getArgNo()) - ++ArgInsEnd; - In = ArgInsEnd; - return llvm::ArrayRef(ArgInsBegin, ArgInsEnd); - }; + auto AllIns = ArrayRef(Ins); for (const auto &Arg : F->args()) { - const auto ArgIns = ConsumeArgIns(Arg); + const auto ArgIns = AllIns.take_while( + [&](auto I) { return I.OrigArgIndex == Arg.getArgNo(); }); + AllIns = AllIns.drop_front(ArgIns.size()); Type *Ty = Arg.getType();