-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[NVPTX][NFC] Refactoring and cleanup in NVPTXISelLowering #137222
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -493,13 +493,6 @@ VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs, | |
| return VectorInfo; | ||
| } | ||
|
|
||
| static SDValue MaybeBitcast(SelectionDAG &DAG, SDLoc DL, EVT VT, | ||
| SDValue Value) { | ||
| if (Value->getValueType(0) == VT) | ||
| return Value; | ||
| return DAG.getNode(ISD::BITCAST, DL, VT, Value); | ||
| } | ||
|
|
||
| // NVPTXTargetLowering Constructor. | ||
| NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, | ||
| const NVPTXSubtarget &STI) | ||
|
|
@@ -1587,9 +1580,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, | |
|
|
||
| auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign, IsVAArg); | ||
| SmallVector<SDValue, 6> StoreOperands; | ||
| for (unsigned j = 0, je = VTs.size(); j != je; ++j) { | ||
| EVT EltVT = VTs[j]; | ||
| int CurOffset = Offsets[j]; | ||
| for (const unsigned J : llvm::seq(VTs.size())) { | ||
| EVT EltVT = VTs[J]; | ||
| const int CurOffset = Offsets[J]; | ||
| MaybeAlign PartAlign; | ||
| if (NeedAlign) | ||
| PartAlign = commonAlignment(ArgAlign, CurOffset); | ||
|
|
@@ -1629,7 +1622,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, | |
|
|
||
| // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a | ||
| // scalar store. In such cases, fall back to byte stores. | ||
| if (VectorInfo[j] == PVF_SCALAR && !IsVAArg && PartAlign.has_value() && | ||
| if (VectorInfo[J] == PVF_SCALAR && !IsVAArg && PartAlign.has_value() && | ||
| PartAlign.value() < | ||
| DL.getABITypeAlign(EltVT.getTypeForEVT(*DAG.getContext()))) { | ||
| assert(StoreOperands.empty() && "Unfinished preceeding store."); | ||
|
|
@@ -1645,7 +1638,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, | |
| } | ||
|
|
||
| // New store. | ||
| if (VectorInfo[j] & PVF_FIRST) { | ||
| if (VectorInfo[J] & PVF_FIRST) { | ||
| assert(StoreOperands.empty() && "Unfinished preceding store."); | ||
| StoreOperands.push_back(Chain); | ||
| StoreOperands.push_back( | ||
|
|
@@ -1665,8 +1658,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, | |
| // Record the value to store. | ||
| StoreOperands.push_back(StVal); | ||
|
|
||
| if (VectorInfo[j] & PVF_LAST) { | ||
| unsigned NumElts = StoreOperands.size() - 3; | ||
| if (VectorInfo[J] & PVF_LAST) { | ||
| const unsigned NumElts = StoreOperands.size() - 3; | ||
| NVPTXISD::NodeType Op; | ||
| switch (NumElts) { | ||
| case 1: | ||
|
|
@@ -2168,7 +2161,7 @@ SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const { | |
| ISD::OR, DL, MVT::i16, | ||
| {Extend0, DAG.getNode(ISD::SHL, DL, MVT::i16, {Extend1, Const8})}); | ||
| EVT ToVT = Op->getValueType(0); | ||
| return MaybeBitcast(DAG, DL, ToVT, AsInt); | ||
| return DAG.getBitcast(ToVT, AsInt); | ||
| } | ||
|
|
||
| // We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it | ||
|
|
@@ -3367,18 +3360,10 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( | |
| auto PtrVT = getPointerTy(DAG.getDataLayout()); | ||
|
|
||
| const Function *F = &MF.getFunction(); | ||
| const AttributeList &PAL = F->getAttributes(); | ||
| const TargetLowering *TLI = STI.getTargetLowering(); | ||
|
|
||
| SDValue Root = DAG.getRoot(); | ||
| std::vector<SDValue> OutChains; | ||
| SmallVector<SDValue, 16> OutChains; | ||
|
|
||
| std::vector<Type *> argTypes; | ||
| std::vector<const Argument *> theArgs; | ||
| for (const Argument &I : F->args()) { | ||
| theArgs.push_back(&I); | ||
| argTypes.push_back(I.getType()); | ||
| } | ||
| // argTypes.size() (or theArgs.size()) and Ins.size() need not match. | ||
| // Ins.size() will be larger | ||
| // * if there is an aggregate argument with multiple fields (each field | ||
|
|
@@ -3388,75 +3373,91 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( | |
| // individually present in Ins. | ||
| // So a different index should be used for indexing into Ins. | ||
| // See similar issue in LowerCall. | ||
| unsigned InsIdx = 0; | ||
| const auto *In = Ins.begin(); | ||
| auto ConsumeArgIns = [&](const Argument &Arg) { | ||
| const auto *ArgInsBegin = In; | ||
| const auto *ArgInsEnd = In; | ||
| while (ArgInsEnd != Ins.end() && ArgInsEnd->OrigArgIndex == Arg.getArgNo()) | ||
| ++ArgInsEnd; | ||
| In = ArgInsEnd; | ||
| return llvm::ArrayRef(ArgInsBegin, ArgInsEnd); | ||
| }; | ||
|
||
|
|
||
| for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++InsIdx) { | ||
| Type *Ty = argTypes[i]; | ||
| for (const auto &Arg : F->args()) { | ||
| const auto ArgIns = ConsumeArgIns(Arg); | ||
|
|
||
| if (theArgs[i]->use_empty()) { | ||
| // argument is dead | ||
| if (shouldPassAsArray(Ty) && !Ty->isVectorTy()) { | ||
| SmallVector<EVT, 16> vtparts; | ||
| Type *Ty = Arg.getType(); | ||
|
|
||
| ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts); | ||
| if (vtparts.empty()) | ||
| report_fatal_error("Empty parameter types are not supported"); | ||
| if (ArgIns.empty()) | ||
| report_fatal_error("Empty parameter types are not supported"); | ||
|
|
||
| for (unsigned parti = 0, parte = vtparts.size(); parti != parte; | ||
| ++parti) { | ||
| InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT)); | ||
| ++InsIdx; | ||
| } | ||
| if (vtparts.size() > 0) | ||
| --InsIdx; | ||
| continue; | ||
| } | ||
| if (Ty->isVectorTy()) { | ||
| EVT ObjectVT = getValueType(DL, Ty); | ||
| unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT); | ||
| for (unsigned parti = 0; parti < NumRegs; ++parti) { | ||
| InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT)); | ||
| ++InsIdx; | ||
| } | ||
| if (NumRegs > 0) | ||
| --InsIdx; | ||
| continue; | ||
| if (Arg.use_empty()) { | ||
| // argument is dead | ||
| for (const auto &In : ArgIns) { | ||
| assert(!In.Used && "Arg.use_empty() is true but Arg is used?"); | ||
| InVals.push_back(DAG.getUNDEF(In.VT)); | ||
| } | ||
| InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT)); | ||
| continue; | ||
| } | ||
|
|
||
| SDValue ArgSymbol = getParamSymbol(DAG, Arg.getArgNo(), PtrVT); | ||
|
|
||
| // In the following cases, assign a node order of "i+1" | ||
| // to newly created nodes. The SDNodes for params have to | ||
| // appear in the same order as their order of appearance | ||
| // in the original function. "i+1" holds that order. | ||
| if (!PAL.hasParamAttr(i, Attribute::ByVal)) { | ||
| if (Arg.hasByValAttr()) { | ||
| // Param has ByVal attribute | ||
| // Return MoveParam(param symbol). | ||
| // Ideally, the param symbol can be returned directly, | ||
| // but when SDNode builder decides to use it in a CopyToReg(), | ||
| // machine instruction fails because TargetExternalSymbol | ||
| // (not lowered) is target dependent, and CopyToReg assumes | ||
| // the source is lowered. | ||
| assert(ArgIns.size() == 1 && "ByVal argument must be a pointer"); | ||
| const auto &ByvalIn = ArgIns[0]; | ||
| assert(getValueType(DL, Ty) == ByvalIn.VT && | ||
| "Ins type did not match function type"); | ||
| assert(ByvalIn.VT == PtrVT && "ByVal argument must be a pointer"); | ||
|
|
||
| SDValue P; | ||
| if (isKernelFunction(*F)) { | ||
| P = DAG.getNode(NVPTXISD::Wrapper, dl, ByvalIn.VT, ArgSymbol); | ||
| P.getNode()->setIROrder(Arg.getArgNo() + 1); | ||
| } else { | ||
| P = DAG.getNode(NVPTXISD::MoveParam, dl, ByvalIn.VT, ArgSymbol); | ||
| P.getNode()->setIROrder(Arg.getArgNo() + 1); | ||
| P = DAG.getAddrSpaceCast(dl, ByvalIn.VT, P, ADDRESS_SPACE_LOCAL, | ||
| ADDRESS_SPACE_GENERIC); | ||
| } | ||
| InVals.push_back(P); | ||
| } else { | ||
| bool aggregateIsPacked = false; | ||
| if (StructType *STy = dyn_cast<StructType>(Ty)) | ||
| aggregateIsPacked = STy->isPacked(); | ||
|
|
||
| SmallVector<EVT, 16> VTs; | ||
| SmallVector<uint64_t, 16> Offsets; | ||
| ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0); | ||
| if (VTs.empty()) | ||
| report_fatal_error("Empty parameter types are not supported"); | ||
| assert(VTs.size() == ArgIns.size() && "Size mismatch"); | ||
| assert(VTs.size() == Offsets.size() && "Size mismatch"); | ||
|
|
||
| Align ArgAlign = getFunctionArgumentAlignment( | ||
| F, Ty, i + AttributeList::FirstArgIndex, DL); | ||
| F, Ty, Arg.getArgNo() + AttributeList::FirstArgIndex, DL); | ||
| auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign); | ||
| assert(VectorInfo.size() == VTs.size() && "Size mismatch"); | ||
|
|
||
| SDValue Arg = getParamSymbol(DAG, i, PtrVT); | ||
| int VecIdx = -1; // Index of the first element of the current vector. | ||
| for (unsigned parti = 0, parte = VTs.size(); parti != parte; ++parti) { | ||
| if (VectorInfo[parti] & PVF_FIRST) { | ||
| for (const unsigned PartI : llvm::seq(VTs.size())) { | ||
| if (VectorInfo[PartI] & PVF_FIRST) { | ||
| assert(VecIdx == -1 && "Orphaned vector."); | ||
| VecIdx = parti; | ||
| VecIdx = PartI; | ||
| } | ||
|
|
||
| // That's the last element of this store op. | ||
| if (VectorInfo[parti] & PVF_LAST) { | ||
| unsigned NumElts = parti - VecIdx + 1; | ||
| EVT EltVT = VTs[parti]; | ||
| if (VectorInfo[PartI] & PVF_LAST) { | ||
| const unsigned NumElts = PartI - VecIdx + 1; | ||
| EVT EltVT = VTs[PartI]; | ||
| // i1 is loaded/stored as i8. | ||
| EVT LoadVT = EltVT; | ||
| if (EltVT == MVT::i1) | ||
|
|
@@ -3469,10 +3470,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( | |
|
|
||
| EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts); | ||
| SDValue VecAddr = | ||
| DAG.getNode(ISD::ADD, dl, PtrVT, Arg, | ||
| DAG.getNode(ISD::ADD, dl, PtrVT, ArgSymbol, | ||
| DAG.getConstant(Offsets[VecIdx], dl, PtrVT)); | ||
| Value *srcValue = Constant::getNullValue( | ||
| PointerType::get(F->getContext(), ADDRESS_SPACE_PARAM)); | ||
|
|
||
| const MaybeAlign PartAlign = [&]() -> MaybeAlign { | ||
| if (aggregateIsPacked) | ||
|
|
@@ -3481,23 +3480,23 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( | |
| return std::nullopt; | ||
| Align PartAlign = | ||
| DL.getABITypeAlign(EltVT.getTypeForEVT(F->getContext())); | ||
| return commonAlignment(PartAlign, Offsets[parti]); | ||
| return commonAlignment(PartAlign, Offsets[PartI]); | ||
| }(); | ||
| SDValue P = DAG.getLoad(VecVT, dl, Root, VecAddr, | ||
| MachinePointerInfo(srcValue), PartAlign, | ||
| MachineMemOperand::MODereferenceable | | ||
| MachineMemOperand::MOInvariant); | ||
| SDValue P = | ||
| DAG.getLoad(VecVT, dl, Root, VecAddr, | ||
| MachinePointerInfo(ADDRESS_SPACE_PARAM), PartAlign, | ||
| MachineMemOperand::MODereferenceable | | ||
| MachineMemOperand::MOInvariant); | ||
| if (P.getNode()) | ||
| P.getNode()->setIROrder(i + 1); | ||
| for (unsigned j = 0; j < NumElts; ++j) { | ||
| P.getNode()->setIROrder(Arg.getArgNo() + 1); | ||
| for (const unsigned J : llvm::seq(NumElts)) { | ||
| SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P, | ||
| DAG.getIntPtrConstant(j, dl)); | ||
| DAG.getIntPtrConstant(J, dl)); | ||
| // We've loaded i1 as an i8 and now must truncate it back to i1 | ||
| if (EltVT == MVT::i1) | ||
| Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt); | ||
| // v2f16 was loaded as an i32. Now we must bitcast it back. | ||
| else if (EltVT != LoadVT) | ||
| Elt = DAG.getNode(ISD::BITCAST, dl, EltVT, Elt); | ||
| Elt = DAG.getBitcast(EltVT, Elt); | ||
|
|
||
| // If a promoted integer type is used, truncate down to the original | ||
| MVT PromotedVT; | ||
|
|
@@ -3507,53 +3506,25 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( | |
|
|
||
| // Extend the element if necessary (e.g. an i8 is loaded | ||
| // into an i16 register) | ||
| if (Ins[InsIdx].VT.isInteger() && | ||
| Ins[InsIdx].VT.getFixedSizeInBits() > | ||
| LoadVT.getFixedSizeInBits()) { | ||
| unsigned Extend = Ins[InsIdx].Flags.isSExt() ? ISD::SIGN_EXTEND | ||
| : ISD::ZERO_EXTEND; | ||
| Elt = DAG.getNode(Extend, dl, Ins[InsIdx].VT, Elt); | ||
| if (ArgIns[PartI].VT.getFixedSizeInBits() != | ||
| LoadVT.getFixedSizeInBits()) { | ||
| assert(ArgIns[PartI].VT.isInteger() && LoadVT.isInteger() && | ||
| "Non-integer argument type size mismatch"); | ||
| Elt = DAG.getExtOrTrunc(ArgIns[PartI].Flags.isSExt(), Elt, dl, | ||
| ArgIns[PartI].VT); | ||
| } | ||
| InVals.push_back(Elt); | ||
| } | ||
|
|
||
| // Reset vector tracking state. | ||
| VecIdx = -1; | ||
| } | ||
| ++InsIdx; | ||
| } | ||
| if (VTs.size() > 0) | ||
| --InsIdx; | ||
| continue; | ||
| } | ||
|
|
||
| // Param has ByVal attribute | ||
| // Return MoveParam(param symbol). | ||
| // Ideally, the param symbol can be returned directly, | ||
| // but when SDNode builder decides to use it in a CopyToReg(), | ||
| // machine instruction fails because TargetExternalSymbol | ||
| // (not lowered) is target dependent, and CopyToReg assumes | ||
| // the source is lowered. | ||
| EVT ObjectVT = getValueType(DL, Ty); | ||
| assert(ObjectVT == Ins[InsIdx].VT && | ||
| "Ins type did not match function type"); | ||
| SDValue Arg = getParamSymbol(DAG, i, PtrVT); | ||
|
|
||
| SDValue P; | ||
| if (isKernelFunction(*F)) { | ||
| P = DAG.getNode(NVPTXISD::Wrapper, dl, ObjectVT, Arg); | ||
| P.getNode()->setIROrder(i + 1); | ||
| } else { | ||
| P = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg); | ||
| P.getNode()->setIROrder(i + 1); | ||
| P = DAG.getAddrSpaceCast(dl, ObjectVT, P, ADDRESS_SPACE_LOCAL, | ||
| ADDRESS_SPACE_GENERIC); | ||
| } | ||
| InVals.push_back(P); | ||
| } | ||
|
|
||
| if (!OutChains.empty()) | ||
| DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains)); | ||
| DAG.setRoot(DAG.getTokenFactor(dl, OutChains)); | ||
|
|
||
| return Chain; | ||
| } | ||
|
|
@@ -5784,7 +5755,7 @@ static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG, | |
|
|
||
| // Bitcast to i16 and unpack elements into a vector | ||
| SDLoc DL(Node); | ||
| SDValue AsInt = MaybeBitcast(DAG, DL, MVT::i16, Op->getOperand(0)); | ||
| SDValue AsInt = DAG.getBitcast(MVT::i16, Op->getOperand(0)); | ||
| SDValue Vec0 = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, AsInt); | ||
| SDValue Const8 = DAG.getConstant(8, DL, MVT::i16); | ||
| SDValue Vec1 = | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if it would be more convenient to use an index into `Ins` instead of an iterator.
Debugging this code in the past, it was always a bit of a pain figuring out which part advanced how far based on the current pointer. Having an index makes it a bit easier to examine things -- you can easily see where things go off-track.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've refactored things a bit more and I think the latest version shows the benefits of an iterator-based approach. Now, at the start of each loop iteration we update the iterator once and create an ArrayRef to the `Ins` corresponding to this arg. Then in the rest of the loop we can just iterate over or index into this ArrayRef. This is much simpler than what we had previously, where the index was incremented and decremented in confusing, control flow-dependent ways, so hopefully with this change we won't need to debug any further issues. I've also added several asserts to encode our assumptions about how different data structures will line up.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yup. It's a really nice cleanup.