@@ -1389,26 +1389,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
13891389 return DAG.getConstant (I, dl, MVT::i32 );
13901390 };
13911391
1392- // Variadic arguments.
1393- //
1394- // Normally, for each argument, we declare a param scalar or a param
1395- // byte array in the .param space, and store the argument value to that
1396- // param scalar or array starting at offset 0.
1397- //
1398- // In the case of the first variadic argument, we declare a vararg byte array
1399- // with size 0. The exact size of this array isn't known at this point, so
1400- // it'll be patched later. All the variadic arguments will be stored to this
1401- // array at a certain offset (which gets tracked by 'VAOffset'). The offset is
1402- // initially set to 0, so it can be used for non-variadic arguments (which use
1403- // 0 offset) to simplify the code.
1404- //
1405- // After all vararg is processed, 'VAOffset' holds the size of the
1406- // vararg byte array.
1407-
1408- SDValue VADeclareParam = SDValue (); // vararg byte array
1409- const unsigned FirstVAArg = CLI.NumFixedArgs ; // position of first variadic
1410- unsigned VAOffset = 0 ; // current offset in the param array
1411-
14121392 const unsigned UniqueCallSite = GlobalUniqueCallSite++;
14131393 const SDValue CallChain = CLI.Chain ;
14141394 const SDValue StartChain =
@@ -1417,7 +1397,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
14171397
14181398 SmallVector<SDValue, 16 > CallPrereqs{StartChain};
14191399
1420- const auto DeclareScalarParam = [&](SDValue Symbol, unsigned Size) {
1400+ const auto MakeDeclareScalarParam = [&](SDValue Symbol, unsigned Size) {
14211401 // PTX ABI requires integral types to be at least 32 bits in size. FP16 is
14221402 // loaded/stored using i16, so it's handled here as well.
14231403 const unsigned SizeBits = promoteScalarArgumentSize (Size * 8 );
@@ -1429,8 +1409,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
14291409 return Declare;
14301410 };
14311411
1432- const auto DeclareArrayParam = [&](SDValue Symbol, Align Align,
1433- unsigned Size) {
1412+ const auto MakeDeclareArrayParam = [&](SDValue Symbol, Align Align,
1413+ unsigned Size) {
14341414 SDValue Declare = DAG.getNode (
14351415 NVPTXISD::DeclareArrayParam, dl, {MVT::Other, MVT::Glue},
14361416 {StartChain, Symbol, GetI32 (Align.value ()), GetI32 (Size), DeclareGlue});
@@ -1439,6 +1419,33 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
14391419 return Declare;
14401420 };
14411421
1422+ // Variadic arguments.
1423+ //
1424+ // Normally, for each argument, we declare a param scalar or a param
1425+ // byte array in the .param space, and store the argument value to that
1426+ // param scalar or array starting at offset 0.
1427+ //
1428+ // In the case of the first variadic argument, we declare a vararg byte array
1429+ // with size 0. The exact size of this array isn't known at this point, so
1430+ // it'll be patched later. All the variadic arguments will be stored to this
1431+ // array at a certain offset (which gets tracked by 'VAOffset'). The offset is
1432+ // initially set to 0, so it can be used for non-variadic arguments (which use
1433+ // 0 offset) to simplify the code.
1434+ //
1435+ // After all vararg is processed, 'VAOffset' holds the size of the
1436+ // vararg byte array.
1437+ assert ((CLI.IsVarArg || CLI.Args .size () == CLI.NumFixedArgs ) &&
1438+ " Non-VarArg function with extra arguments" );
1439+
1440+ const unsigned FirstVAArg = CLI.NumFixedArgs ; // position of first variadic
1441+ unsigned VAOffset = 0 ; // current offset in the param array
1442+
1443+ const SDValue VADeclareParam =
1444+ CLI.Args .size () > FirstVAArg
1445+ ? MakeDeclareArrayParam (getCallParamSymbol (DAG, FirstVAArg, MVT::i32 ),
1446+ Align (STI.getMaxRequiredAlignment ()), 0 )
1447+ : SDValue ();
1448+
14421449 // Args.size() and Outs.size() need not match.
14431450 // Outs.size() will be larger
14441451 // * if there is an aggregate argument with multiple fields (each field
@@ -1499,21 +1506,17 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
14991506 " type size mismatch" );
15001507
15011508 const SDValue ArgDeclare = [&]() {
1502- if (IsVAArg) {
1503- if (ArgI == FirstVAArg)
1504- VADeclareParam = DeclareArrayParam (
1505- ParamSymbol, Align (STI.getMaxRequiredAlignment ()), 0 );
1509+ if (IsVAArg)
15061510 return VADeclareParam;
1507- }
15081511
15091512 if (IsByVal || shouldPassAsArray (Arg.Ty ))
1510- return DeclareArrayParam (ParamSymbol, ArgAlign, TypeSize);
1513+ return MakeDeclareArrayParam (ParamSymbol, ArgAlign, TypeSize);
15111514
15121515 assert (ArgOuts.size () == 1 && " We must pass only one value as non-array" );
15131516 assert ((ArgOuts[0 ].VT .isInteger () || ArgOuts[0 ].VT .isFloatingPoint ()) &&
15141517 " Only int and float types are supported as non-array arguments" );
15151518
1516- return DeclareScalarParam (ParamSymbol, TypeSize);
1519+ return MakeDeclareScalarParam (ParamSymbol, TypeSize);
15171520 }();
15181521
15191522 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
@@ -1573,7 +1576,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
15731576 if (NumElts == 1 ) {
15741577 Val = GetStoredValue (J, EltVT, CurrentAlign);
15751578 } else {
1576- SmallVector<SDValue, 6 > StoreVals;
1579+ SmallVector<SDValue, 8 > StoreVals;
15771580 for (const unsigned K : llvm::seq (NumElts)) {
15781581 SDValue ValJ = GetStoredValue (J + K, EltVT, CurrentAlign);
15791582 if (ValJ.getValueType ().isVector ())
@@ -1614,9 +1617,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
16141617 const unsigned ResultSize = DL.getTypeAllocSize (RetTy);
16151618 if (shouldPassAsArray (RetTy)) {
16161619 const Align RetAlign = getArgumentAlignment (CB, RetTy, 0 , DL);
1617- DeclareArrayParam (RetSymbol, RetAlign, ResultSize);
1620+ MakeDeclareArrayParam (RetSymbol, RetAlign, ResultSize);
16181621 } else {
1619- DeclareScalarParam (RetSymbol, ResultSize);
1622+ MakeDeclareScalarParam (RetSymbol, ResultSize);
16201623 }
16211624 }
16221625
@@ -1740,17 +1743,16 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
17401743
17411744 LoadChains.push_back (R.getValue (1 ));
17421745
1743- if (NumElts == 1 ) {
1746+ if (NumElts == 1 )
17441747 ProxyRegOps.push_back (R);
1745- } else {
1748+ else
17461749 for (const unsigned J : llvm::seq (NumElts)) {
17471750 SDValue Elt = DAG.getNode (
17481751 LoadVT.isVector () ? ISD::EXTRACT_SUBVECTOR
17491752 : ISD::EXTRACT_VECTOR_ELT,
17501753 dl, LoadVT, R, DAG.getVectorIdxConstant (J * PackingAmt, dl));
17511754 ProxyRegOps.push_back (Elt);
17521755 }
1753- }
17541756 I += NumElts;
17551757 }
17561758 }
@@ -5770,7 +5772,7 @@ static SDValue sinkProxyReg(SDValue R, SDValue Chain,
57705772 {Chain, R});
57715773 }
57725774 case ISD::BUILD_VECTOR: {
5773- if (DCI.isAfterLegalizeDAG ())
5775+ if (DCI.isBeforeLegalize ())
57745776 return SDValue ();
57755777
57765778 SmallVector<SDValue, 16 > Ops;
@@ -5782,6 +5784,15 @@ static SDValue sinkProxyReg(SDValue R, SDValue Chain,
57825784 }
57835785 return DCI.DAG .getNode (ISD::BUILD_VECTOR, SDLoc (R), R.getValueType (), Ops);
57845786 }
5787+ case ISD::EXTRACT_VECTOR_ELT: {
5788+ if (DCI.isBeforeLegalize ())
5789+ return SDValue ();
5790+
5791+ if (SDValue V = sinkProxyReg (R.getOperand (0 ), Chain, DCI))
5792+ return DCI.DAG .getNode (ISD::EXTRACT_VECTOR_ELT, SDLoc (R), R.getValueType (),
5793+ V, R.getOperand (1 ));
5794+ return SDValue ();
5795+ }
57855796 default :
57865797 return SDValue ();
57875798 }
0 commit comments