@@ -370,7 +370,7 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
370370 } else if (EltVT.getSimpleVT () == MVT::i8 && NumElts == 2 ) {
371371 // v2i8 is promoted to v2i16
372372 NumElts = 1 ;
373- EltVT = MVT::v2i16 ;
373+ EltVT = MVT::v2i8 ;
374374 }
375375 for (unsigned j = 0 ; j != NumElts; ++j) {
376376 ValueVTs.push_back (EltVT);
@@ -1065,9 +1065,6 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
10651065 MAKE_CASE (NVPTXISD::StoreParamV2)
10661066 MAKE_CASE (NVPTXISD::StoreParamV4)
10671067 MAKE_CASE (NVPTXISD::MoveParam)
1068- MAKE_CASE (NVPTXISD::StoreRetval)
1069- MAKE_CASE (NVPTXISD::StoreRetvalV2)
1070- MAKE_CASE (NVPTXISD::StoreRetvalV4)
10711068 MAKE_CASE (NVPTXISD::UNPACK_VECTOR)
10721069 MAKE_CASE (NVPTXISD::BUILD_VECTOR)
10731070 MAKE_CASE (NVPTXISD::CallPrototype)
@@ -1438,7 +1435,11 @@ static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
14381435}
14391436
// Picks the extension opcode for a value from its ArgFlagsTy:
//   - ISD::SIGN_EXTEND when Flags.isSExt() is set,
//   - ISD::ZERO_EXTEND when Flags.isZExt() is set,
//   - ISD::ANY_EXTEND when neither flag is set.
// The line marked '-' is the previous body, which returned ZERO_EXTEND
// whenever isSExt() was false; the lines marked '+' replace it so that a value
// with no extension flag is no longer zero-extended.
14401437static ISD::NodeType getExtOpcode (const ISD::ArgFlagsTy &Flags) {
1441- return Flags.isSExt () ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
1438+ if (Flags.isSExt ())
1439+ return ISD::SIGN_EXTEND;
1440+ if (Flags.isZExt ())
1441+ return ISD::ZERO_EXTEND;
1442+ return ISD::ANY_EXTEND;
14421443}
14431444
14441445SDValue NVPTXTargetLowering::LowerCall (TargetLowering::CallLoweringInfo &CLI,
@@ -3373,10 +3374,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
33733374 }
33743375 InVals.push_back (P);
33753376 } else {
3376- bool aggregateIsPacked = false ;
3377- if (StructType *STy = dyn_cast<StructType>(Ty))
3378- aggregateIsPacked = STy->isPacked ();
3379-
33803377 SmallVector<EVT, 16 > VTs;
33813378 SmallVector<uint64_t , 16 > Offsets;
33823379 ComputePTXValueVTs (*this , DL, Ty, VTs, &Offsets, 0 );
@@ -3389,9 +3386,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
33893386 const auto VectorInfo = VectorizePTXValueVTs (VTs, Offsets, ArgAlign);
33903387 unsigned I = 0 ;
33913388 for (const unsigned NumElts : VectorInfo) {
3392- const EVT EltVT = VTs[I];
33933389 // i1 is loaded/stored as i8
3394- const EVT LoadVT = EltVT == MVT::i1 ? MVT::i8 : EltVT ;
3390+ const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I] ;
33953391 // If the element is a packed type (ex. v2f16, v4i8, etc) holding
33963392 // multiple elements.
33973393 const unsigned PackingAmt =
@@ -3403,14 +3399,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
34033399 SDValue VecAddr = DAG.getObjectPtrOffset (
34043400 dl, ArgSymbol, TypeSize::getFixed (Offsets[I]));
34053401
3406- const MaybeAlign PartAlign = [&]() -> MaybeAlign {
3407- if (aggregateIsPacked)
3408- return Align (1 );
3409- if (NumElts != 1 )
3410- return std::nullopt ;
3411- Align PartAlign = DAG.getEVTAlign (EltVT);
3412- return commonAlignment (PartAlign, Offsets[I]);
3413- }();
3402+ const MaybeAlign PartAlign = commonAlignment (ArgAlign, Offsets[I]);
34143403 SDValue P =
34153404 DAG.getLoad (VecVT, dl, Root, VecAddr,
34163405 MachinePointerInfo (ADDRESS_SPACE_PARAM), PartAlign,
@@ -3419,23 +3408,22 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
34193408 if (P.getNode ())
34203409 P.getNode ()->setIROrder (Arg.getArgNo () + 1 );
34213410 for (const unsigned J : llvm::seq (NumElts)) {
3422- SDValue Elt = DAG.getNode (LoadVT. isVector () ? ISD::EXTRACT_SUBVECTOR
3423- : ISD::EXTRACT_VECTOR_ELT,
3424- dl, LoadVT, P ,
3425- DAG.getIntPtrConstant (J * PackingAmt, dl));
3411+ SDValue Elt = DAG.getNode (
3412+ LoadVT. isVector () ? ISD::EXTRACT_SUBVECTOR
3413+ : ISD::EXTRACT_VECTOR_ELT ,
3414+ dl, LoadVT, P, DAG.getVectorIdxConstant (J * PackingAmt, dl));
34263415
34273416 // Extend or truncate the element if necessary (e.g. an i8 is loaded
34283417 // into an i16 register)
3429- const EVT ExpactedVT = ArgIns[I + J].VT ;
3430- assert ((Elt.getValueType ().bitsEq (ExpactedVT) ||
3431- (ExpactedVT.isScalarInteger () &&
3432- Elt.getValueType ().isScalarInteger ())) &&
3418+ const EVT ExpectedVT = ArgIns[I + J].VT ;
3419+ assert ((Elt.getValueType () == ExpectedVT ||
3420+ (ExpectedVT.isInteger () && Elt.getValueType ().isInteger ())) &&
34333421 " Non-integer argument type size mismatch" );
3434- if (ExpactedVT .bitsGT (Elt.getValueType ()))
3435- Elt = DAG.getNode (getExtOpcode (ArgIns[I + J].Flags ), dl, ExpactedVT ,
3422+ if (ExpectedVT .bitsGT (Elt.getValueType ()))
3423+ Elt = DAG.getNode (getExtOpcode (ArgIns[I + J].Flags ), dl, ExpectedVT ,
34363424 Elt);
3437- else if (ExpactedVT .bitsLT (Elt.getValueType ()))
3438- Elt = DAG.getNode (ISD::TRUNCATE, dl, ExpactedVT , Elt);
3425+ else if (ExpectedVT .bitsLT (Elt.getValueType ()))
3426+ Elt = DAG.getNode (ISD::TRUNCATE, dl, ExpectedVT , Elt);
34393427 InVals.push_back (Elt);
34403428 }
34413429 I += NumElts;
@@ -3449,33 +3437,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
34493437 return Chain;
34503438}
34513439
// NOTE: every line of LowerUnalignedStoreRet below carries the '-' marker,
// i.e. this listing shows the helper being removed.
//
// Stores RetVal one byte at a time, emitting one NVPTXISD::StoreRetval node
// (memory type i8) per byte and threading the chain through each store.
//   DAG         - SelectionDAG in which the nodes are created.
//   Chain       - incoming chain; each byte store is chained after the last.
//   Offset      - base offset operand; byte i is stored at Offset + i.
//   ElementType - type of RetVal; getSizeInBits() / 8 gives the store count.
//   RetVal      - the value to store.
//   dl          - SDLoc attached to every node created here.
// Returns the chain of the final store. If ElementType is narrower than 8
// bits the loop runs zero times and the incoming Chain is returned unchanged.
3452- // Use byte-store when the param address of the return value is unaligned.
3453- // This may happen when the return value is a field of a packed structure.
3454- static SDValue LowerUnalignedStoreRet (SelectionDAG &DAG, SDValue Chain,
3455- uint64_t Offset, EVT ElementType,
3456- SDValue RetVal, const SDLoc &dl) {
3457- // Bit logic only works on integer types. NOTE(review): adjustElementType presumably rewrites ElementType to a same-sized integer type and returns true when it did - confirm against its definition.
3458- if (adjustElementType (ElementType))
3459- RetVal = DAG.getNode (ISD::BITCAST, dl, ElementType, RetVal);
3460-
3461- // Store each byte, least-significant byte first (byte i goes to Offset + i)
3462- for (unsigned i = 0 , n = ElementType.getSizeInBits () / 8 ; i < n; i++) {
3463- // Shift byte i down to the least-significant byte position
3464- SDValue ShiftVal = DAG.getNode (ISD::SRL, dl, ElementType, RetVal,
3465- DAG.getConstant (i * 8 , dl, MVT::i32 ));
3466- SDValue StoreOperands[] = {Chain, DAG.getConstant (Offset + i, dl, MVT::i32 ),
3467- ShiftVal};
3468- // Trunc store only the last byte by using
3469- // st.param.b8
3470- // The register type can be larger than b8.
3471- Chain = DAG.getMemIntrinsicNode (NVPTXISD::StoreRetval, dl,
3472- DAG.getVTList (MVT::Other), StoreOperands,
3473- MVT::i8 , MachinePointerInfo (), std::nullopt ,
3474- MachineMemOperand::MOStore);
3475- }
3476- return Chain;
3477- }
3478-
34793440SDValue
34803441NVPTXTargetLowering::LowerReturn (SDValue Chain, CallingConv::ID CallConv,
34813442 bool isVarArg,
@@ -3497,10 +3458,6 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
34973458 ComputePTXValueVTs (*this , DL, RetTy, VTs, &Offsets);
34983459 assert (VTs.size () == OutVals.size () && " Bad return value decomposition" );
34993460
3500- for (const unsigned I : llvm::seq (VTs.size ()))
3501- if (const auto PromotedVT = PromoteScalarIntegerPTX (VTs[I]))
3502- VTs[I] = *PromotedVT;
3503-
35043461 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
35053462 // 32-bits are sign extended or zero extended, depending on whether
35063463 // they are signed or unsigned types.
@@ -3512,12 +3469,20 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
35123469 assert (!PromoteScalarIntegerPTX (RetVal.getValueType ()) &&
35133470 " OutVal type should always be legal" );
35143471
3515- if (ExtendIntegerRetVal) {
3516- RetVal = DAG.getNode (getExtOpcode (Outs[I].Flags ), dl, MVT::i32 , RetVal);
3517- } else if (RetVal.getValueSizeInBits () < 16 ) {
3518- // Use 16-bit registers for small load-stores as it's the
3519- // smallest general purpose register size supported by NVPTX.
3520- RetVal = DAG.getNode (ISD::ANY_EXTEND, dl, MVT::i16 , RetVal);
3472+ EVT VTI = VTs[I];
3473+ if (const auto PromotedVT = PromoteScalarIntegerPTX (VTI))
3474+ VTI = *PromotedVT;
3475+
3476+ const EVT StoreVT =
3477+ ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
3478+
3479+ assert ((RetVal.getValueType () == StoreVT ||
3480+ (StoreVT.isInteger () && RetVal.getValueType ().isInteger ())) &&
3481+ " Non-integer argument type size mismatch" );
3482+ if (StoreVT.bitsGT (RetVal.getValueType ())) {
3483+ RetVal = DAG.getNode (getExtOpcode (Outs[I].Flags ), dl, StoreVT, RetVal);
3484+ } else if (StoreVT.bitsLT (RetVal.getValueType ())) {
3485+ RetVal = DAG.getNode (ISD::TRUNCATE, dl, StoreVT, RetVal);
35213486 }
35223487 return RetVal;
35233488 };
@@ -3526,45 +3491,34 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
35263491 const auto VectorInfo = VectorizePTXValueVTs (VTs, Offsets, RetAlign);
35273492 unsigned I = 0 ;
35283493 for (const unsigned NumElts : VectorInfo) {
3529- const Align CurrentAlign = commonAlignment (RetAlign, Offsets[I]);
3530- if (NumElts == 1 && RetTy->isAggregateType () &&
3531- CurrentAlign < DAG.getEVTAlign (VTs[I])) {
3532- Chain = LowerUnalignedStoreRet (DAG, Chain, Offsets[I], VTs[I],
3533- GetRetVal (I), dl);
3534-
3535- // The call to LowerUnalignedStoreRet inserted the necessary SDAG nodes
3536- // into the graph, so just move on to the next element.
3537- I++;
3538- continue ;
3539- }
3494+ const MaybeAlign CurrentAlign = ExtendIntegerRetVal
3495+ ? MaybeAlign (std::nullopt )
3496+ : commonAlignment (RetAlign, Offsets[I]);
35403497
3541- SmallVector<SDValue, 6 > StoreOperands{
3542- Chain, DAG.getConstant (Offsets[I], dl, MVT::i32 )};
3543-
3544- for (const unsigned J : llvm::seq (NumElts))
3545- StoreOperands.push_back (GetRetVal (I + J));
3498+ SDValue Val;
3499+ if (NumElts == 1 ) {
3500+ Val = GetRetVal (I);
3501+ } else {
3502+ SmallVector<SDValue, 6 > StoreVals;
3503+ for (const unsigned J : llvm::seq (NumElts)) {
3504+ SDValue ValJ = GetRetVal (I + J);
3505+ if (ValJ.getValueType ().isVector ())
3506+ DAG.ExtractVectorElements (ValJ, StoreVals);
3507+ else
3508+ StoreVals.push_back (ValJ);
3509+ }
35463510
3547- NVPTXISD::NodeType Op;
3548- switch (NumElts) {
3549- case 1 :
3550- Op = NVPTXISD::StoreRetval;
3551- break ;
3552- case 2 :
3553- Op = NVPTXISD::StoreRetvalV2;
3554- break ;
3555- case 4 :
3556- Op = NVPTXISD::StoreRetvalV4;
3557- break ;
3558- default :
3559- llvm_unreachable (" Invalid vector info." );
3511+ EVT VT = EVT::getVectorVT (F.getContext (), StoreVals[0 ].getValueType (),
3512+ StoreVals.size ());
3513+ Val = DAG.getBuildVector (VT, dl, StoreVals);
35603514 }
35613515
3562- // Adjust type of load/store op if we've extended the scalar
3563- // return value.
3564- EVT TheStoreType = ExtendIntegerRetVal ? MVT:: i32 : VTs [I];
3565- Chain = DAG. getMemIntrinsicNode (
3566- Op, dl, DAG.getVTList (MVT::Other), StoreOperands, TheStoreType ,
3567- MachinePointerInfo (), CurrentAlign, MachineMemOperand::MOStore );
3516+ SDValue RetSymbol = DAG. getExternalSymbol ( " func_retval0 " , MVT:: i32 );
3517+ SDValue Ptr =
3518+ DAG. getObjectPtrOffset (dl, RetSymbol, TypeSize::getFixed (Offsets [I])) ;
3519+
3520+ Chain = DAG.getStore (Chain, dl, Val, Ptr ,
3521+ MachinePointerInfo (ADDRESS_SPACE_PARAM ), CurrentAlign);
35683522
35693523 I += NumElts;
35703524 }
@@ -5120,19 +5074,12 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
51205074 case NVPTXISD::StoreParamV2:
51215075 Opcode = NVPTXISD::StoreParamV4;
51225076 break ;
5123- case NVPTXISD::StoreRetval:
5124- Opcode = NVPTXISD::StoreRetvalV2;
5125- break ;
5126- case NVPTXISD::StoreRetvalV2:
5127- Opcode = NVPTXISD::StoreRetvalV4;
5128- break ;
51295077 case NVPTXISD::StoreV2:
51305078 MemVT = ST->getMemoryVT ();
51315079 Opcode = NVPTXISD::StoreV4;
51325080 break ;
51335081 case NVPTXISD::StoreV4:
51345082 case NVPTXISD::StoreParamV4:
5135- case NVPTXISD::StoreRetvalV4:
51365083 case NVPTXISD::StoreV8:
51375084 // PTX doesn't support the next doubling of operands
51385085 return SDValue ();
@@ -5201,12 +5148,6 @@ static SDValue PerformStoreParamCombine(SDNode *N,
52015148 return PerformStoreCombineHelper (N, DCI, 3 , 1 );
52025149}
52035150
// NOTE: every line of PerformStoreRetvalCombine below carries the '-' marker,
// i.e. this listing shows the function being removed.
//
// DAG-combine hook for the StoreRetval family of nodes: forwards N and DCI to
// PerformStoreCombineHelper with the constants 2 and 0 and returns its result.
// NOTE(review): the 2 and 0 presumably count the non-value operands at the
// front and back of the operand list - confirm against
// PerformStoreCombineHelper's signature.
5204- static SDValue PerformStoreRetvalCombine (SDNode *N,
5205- TargetLowering::DAGCombinerInfo &DCI) {
5206- // Operands from the 2nd to the last one are the values to be stored
5207- return PerformStoreCombineHelper (N, DCI, 2 , 0 );
5208- }
5209-
52105151// / PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
52115152// /
52125153static SDValue PerformADDCombine (SDNode *N,
@@ -5840,10 +5781,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
58405781 case NVPTXISD::LoadV2:
58415782 case NVPTXISD::LoadV4:
58425783 return combineUnpackingMovIntoLoad (N, DCI);
5843- case NVPTXISD::StoreRetval:
5844- case NVPTXISD::StoreRetvalV2:
5845- case NVPTXISD::StoreRetvalV4:
5846- return PerformStoreRetvalCombine (N, DCI);
58475784 case NVPTXISD::StoreParam:
58485785 case NVPTXISD::StoreParamV2:
58495786 case NVPTXISD::StoreParamV4:
0 commit comments