@@ -249,11 +249,6 @@ MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
249249 return MCOperand::createExpr (Expr);
250250}
251251
252- static bool ShouldPassAsArray (Type *Ty) {
253- return Ty->isAggregateType () || Ty->isVectorTy () || Ty->isIntegerTy (128 ) ||
254- Ty->isHalfTy () || Ty->isBFloatTy ();
255- }
256-
257252void NVPTXAsmPrinter::printReturnValStr (const Function *F, raw_ostream &O) {
258253 const DataLayout &DL = getDataLayout ();
259254 const NVPTXSubtarget &STI = TM.getSubtarget <NVPTXSubtarget>(*F);
@@ -264,26 +259,21 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
264259 return ;
265260 O << " (" ;
266261
267- if ((Ty->isFloatingPointTy () || Ty->isIntegerTy ()) &&
268- !ShouldPassAsArray (Ty)) {
269- unsigned size = 0 ;
270- if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
271- size = ITy->getBitWidth ();
272- } else {
273- assert (Ty->isFloatingPointTy () && " Floating point type expected here" );
274- size = Ty->getPrimitiveSizeInBits ();
275- }
276- size = promoteScalarArgumentSize (size);
277- O << " .param .b" << size << " func_retval0" ;
278- } else if (isa<PointerType>(Ty)) {
279- O << " .param .b" << TLI->getPointerTy (DL).getSizeInBits ()
280- << " func_retval0" ;
281- } else if (ShouldPassAsArray (Ty)) {
282- unsigned totalsz = DL.getTypeAllocSize (Ty);
283- Align RetAlignment = TLI->getFunctionArgumentAlignment (
262+ auto PrintScalarRetVal = [&](unsigned Size) {
263+ O << " .param .b" << promoteScalarArgumentSize (Size) << " func_retval0" ;
264+ };
265+ if (shouldPassAsArray (Ty)) {
266+ const unsigned TotalSize = DL.getTypeAllocSize (Ty);
267+ const Align RetAlignment = TLI->getFunctionArgumentAlignment (
284268 F, Ty, AttributeList::ReturnIndex, DL);
285269 O << " .param .align " << RetAlignment.value () << " .b8 func_retval0["
286- << totalsz << " ]" ;
270+ << TotalSize << " ]" ;
271+ } else if (Ty->isFloatingPointTy ()) {
272+ PrintScalarRetVal (Ty->getPrimitiveSizeInBits ());
273+ } else if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
274+ PrintScalarRetVal (ITy->getBitWidth ());
275+ } else if (isa<PointerType>(Ty)) {
276+ PrintScalarRetVal (TLI->getPointerTy (DL).getSizeInBits ());
287277 } else
288278 llvm_unreachable (" Unknown return type" );
289279 O << " ) " ;
@@ -975,8 +965,8 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
975965 O << " .align "
976966 << GVar->getAlign ().value_or (DL.getPrefTypeAlign (ETy)).value ();
977967
978- if (ETy->isFloatingPointTy () || ETy->isPointerTy () ||
979- (ETy-> isIntegerTy () && ETy->getScalarSizeInBits () <= 64 )) {
968+ if (ETy->isPointerTy () || (( ETy->isIntegerTy () || ETy-> isFloatingPointTy ()) &&
969+ ETy->getScalarSizeInBits () <= 64 )) {
980970 O << " ." ;
981971 // Special case: ABI requires that we use .u8 for predicates
982972 if (ETy->isIntegerTy (1 ))
@@ -1016,6 +1006,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
10161006 // and vectors are lowered into arrays of bytes.
10171007 switch (ETy->getTypeID ()) {
10181008 case Type::IntegerTyID: // Integers larger than 64 bits
1009+ case Type::FP128TyID:
10191010 case Type::StructTyID:
10201011 case Type::ArrayTyID:
10211012 case Type::FixedVectorTyID: {
@@ -1266,8 +1257,8 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
12661257 O << " .align "
12671258 << GVar->getAlign ().value_or (DL.getPrefTypeAlign (ETy)).value ();
12681259
1269- // Special case for i128
1270- if (ETy->isIntegerTy ( 128 ) ) {
1260+ // Special case for i128/fp128
1261+ if (ETy->getScalarSizeInBits () == 128 ) {
12711262 O << " .b8 " ;
12721263 getSymbol (GVar)->print (O, MAI);
12731264 O << " [16]" ;
@@ -1383,7 +1374,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
13831374 continue ;
13841375 }
13851376
1386- if (ShouldPassAsArray (Ty)) {
1377+ if (shouldPassAsArray (Ty)) {
13871378 // Just print .param .align <a> .b8 .param[size];
13881379 // <a> = optimal alignment for the element type; always multiple of
13891380 // PAL.getParamAlignment
@@ -1682,48 +1673,49 @@ void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
16821673void NVPTXAsmPrinter::bufferAggregateConstant (const Constant *CPV,
16831674 AggBuffer *aggBuffer) {
16841675 const DataLayout &DL = getDataLayout ();
1685- int Bytes;
1676+
1677+ auto ExtendBuffer = [](APInt Val, AggBuffer *Buffer) {
1678+ for (unsigned I : llvm::seq (Val.getBitWidth () / 8 ))
1679+ Buffer->addByte (Val.extractBitsAsZExtValue (8 , I * 8 ));
1680+ };
16861681
16871682 // Integers of arbitrary width
16881683 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1689- APInt Val = CI->getValue ();
1690- for (unsigned I = 0 , E = DL.getTypeAllocSize (CPV->getType ()); I < E; ++I) {
1691- uint8_t Byte = Val.getLoBits (8 ).getZExtValue ();
1692- aggBuffer->addBytes (&Byte, 1 , 1 );
1693- Val.lshrInPlace (8 );
1694- }
1684+ ExtendBuffer (CI->getValue (), aggBuffer);
16951685 return ;
16961686 }
16971687
1688+ // f128
1689+ if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
1690+ if (CFP->getType ()->isFP128Ty ()) {
1691+ ExtendBuffer (CFP->getValueAPF ().bitcastToAPInt (), aggBuffer);
1692+ return ;
1693+ }
1694+ }
1695+
16981696 // Old constants
16991697 if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
1700- if (CPV->getNumOperands ())
1701- for (unsigned i = 0 , e = CPV->getNumOperands (); i != e; ++i)
1702- bufferLEByte (cast<Constant>(CPV->getOperand (i)), 0 , aggBuffer);
1698+ for (const auto &Op : CPV->operands ())
1699+ bufferLEByte (cast<Constant>(Op), 0 , aggBuffer);
17031700 return ;
17041701 }
17051702
1706- if (const ConstantDataSequential *CDS =
1707- dyn_cast<ConstantDataSequential>(CPV)) {
1708- if (CDS->getNumElements ())
1709- for (unsigned i = 0 ; i < CDS->getNumElements (); ++i)
1710- bufferLEByte (cast<Constant>(CDS->getElementAsConstant (i)), 0 ,
1711- aggBuffer);
1703+ if (const auto *CDS = dyn_cast<ConstantDataSequential>(CPV)) {
1704+ for (unsigned I : llvm::seq (CDS->getNumElements ()))
1705+ bufferLEByte (cast<Constant>(CDS->getElementAsConstant (I)), 0 , aggBuffer);
17121706 return ;
17131707 }
17141708
17151709 if (isa<ConstantStruct>(CPV)) {
17161710 if (CPV->getNumOperands ()) {
17171711 StructType *ST = cast<StructType>(CPV->getType ());
1718- for (unsigned i = 0 , e = CPV->getNumOperands (); i != e; ++i) {
1719- if (i == (e - 1 ))
1720- Bytes = DL.getStructLayout (ST)->getElementOffset (0 ) +
1721- DL.getTypeAllocSize (ST) -
1722- DL.getStructLayout (ST)->getElementOffset (i);
1723- else
1724- Bytes = DL.getStructLayout (ST)->getElementOffset (i + 1 ) -
1725- DL.getStructLayout (ST)->getElementOffset (i);
1726- bufferLEByte (cast<Constant>(CPV->getOperand (i)), Bytes, aggBuffer);
1712+ for (unsigned I : llvm::seq (CPV->getNumOperands ())) {
1713+ int EndOffset = (I + 1 == CPV->getNumOperands ())
1714+ ? DL.getStructLayout (ST)->getElementOffset (0 ) +
1715+ DL.getTypeAllocSize (ST)
1716+ : DL.getStructLayout (ST)->getElementOffset (I + 1 );
1717+ int Bytes = EndOffset - DL.getStructLayout (ST)->getElementOffset (I);
1718+ bufferLEByte (cast<Constant>(CPV->getOperand (I)), Bytes, aggBuffer);
17271719 }
17281720 }
17291721 return ;
0 commit comments