Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 37 additions & 37 deletions llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,6 @@ MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
return MCOperand::createExpr(Expr);
}

/// Reports whether \p Ty is emitted as a byte-array parameter rather than as a
/// scalar. Aggregates, vectors, i128, half and bfloat all qualify.
static bool ShouldPassAsArray(Type *Ty) {
  if (Ty->isAggregateType() || Ty->isVectorTy())
    return true;
  if (Ty->isIntegerTy(128))
    return true;
  return Ty->isHalfTy() || Ty->isBFloatTy();
}

void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
const DataLayout &DL = getDataLayout();
const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
Expand All @@ -264,26 +259,21 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
return;
O << " (";

if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) &&
!ShouldPassAsArray(Ty)) {
unsigned size = 0;
if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
size = ITy->getBitWidth();
} else {
assert(Ty->isFloatingPointTy() && "Floating point type expected here");
size = Ty->getPrimitiveSizeInBits();
}
size = promoteScalarArgumentSize(size);
O << ".param .b" << size << " func_retval0";
} else if (isa<PointerType>(Ty)) {
O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
<< " func_retval0";
} else if (ShouldPassAsArray(Ty)) {
unsigned totalsz = DL.getTypeAllocSize(Ty);
Align RetAlignment = TLI->getFunctionArgumentAlignment(
auto PrintScalarParam = [&](unsigned Size) {
O << ".param .b" << promoteScalarArgumentSize(Size) << " func_retval0";
};
if (shouldPassAsArray(Ty)) {
const unsigned TotalSize = DL.getTypeAllocSize(Ty);
const Align RetAlignment = TLI->getFunctionArgumentAlignment(
F, Ty, AttributeList::ReturnIndex, DL);
O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
<< totalsz << "]";
<< TotalSize << "]";
} else if (Ty->isFloatingPointTy()) {
PrintScalarParam(Ty->getPrimitiveSizeInBits());
} else if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
PrintScalarParam(ITy->getBitWidth());
} else if (isa<PointerType>(Ty)) {
PrintScalarParam(TLI->getPointerTy(DL).getSizeInBits());
} else
llvm_unreachable("Unknown return type");
O << ") ";
Expand Down Expand Up @@ -975,8 +965,8 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
O << " .align "
<< GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();

if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
(ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
if (ETy->isPointerTy() || ((ETy->isIntegerTy() || ETy->isFloatingPointTy()) &&
ETy->getScalarSizeInBits() <= 64)) {
O << " .";
// Special case: ABI requires that we use .u8 for predicates
if (ETy->isIntegerTy(1))
Expand Down Expand Up @@ -1016,6 +1006,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
// and vectors are lowered into arrays of bytes.
switch (ETy->getTypeID()) {
case Type::IntegerTyID: // Integers larger than 64 bits
case Type::FP128TyID:
case Type::StructTyID:
case Type::ArrayTyID:
case Type::FixedVectorTyID: {
Expand Down Expand Up @@ -1266,8 +1257,8 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
O << " .align "
<< GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();

// Special case for i128
if (ETy->isIntegerTy(128)) {
// Special case for i128/fp128
if (ETy->getScalarSizeInBits() == 128) {
O << " .b8 ";
getSymbol(GVar)->print(O, MAI);
O << "[16]";
Expand Down Expand Up @@ -1383,7 +1374,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
continue;
}

if (ShouldPassAsArray(Ty)) {
if (shouldPassAsArray(Ty)) {
// Just print .param .align <a> .b8 .param[size];
// <a> = optimal alignment for the element type; always multiple of
// PAL.getParamAlignment
Expand Down Expand Up @@ -1682,29 +1673,37 @@ void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
AggBuffer *aggBuffer) {
const DataLayout &DL = getDataLayout();
int Bytes;

// Integers of arbitrary width
if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
APInt Val = CI->getValue();
auto BufferConstant = [&](APInt Val) {
for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
uint8_t Byte = Val.getLoBits(8).getZExtValue();
aggBuffer->addBytes(&Byte, 1, 1);
Val.lshrInPlace(8);
}
};

// Integers of arbitrary width
if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
BufferConstant(CI->getValue());
return;
}

// f128
if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
if (CFP->getType()->isFP128Ty()) {
BufferConstant(CFP->getValueAPF().bitcastToAPInt());
return;
}
}

// Old constants
if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
if (CPV->getNumOperands())
for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
for (const auto &Op : CPV->operands())
bufferLEByte(cast<Constant>(Op), 0, aggBuffer);
return;
}

if (const ConstantDataSequential *CDS =
dyn_cast<ConstantDataSequential>(CPV)) {
if (const auto *CDS = dyn_cast<ConstantDataSequential>(CPV)) {
if (CDS->getNumElements())
for (unsigned i = 0; i < CDS->getNumElements(); ++i)
bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
Expand All @@ -1716,6 +1715,7 @@ void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
if (CPV->getNumOperands()) {
StructType *ST = cast<StructType>(CPV->getType());
for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
int Bytes;
if (i == (e - 1))
Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
DL.getTypeAllocSize(ST) -
Expand Down
28 changes: 10 additions & 18 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,14 +246,11 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
SmallVector<uint64_t, 16> TempOffsets;

// Special case for i128 and fp128 - decompose to (i64, i64)
if (Ty->isIntegerTy(128)) {
ValueVTs.push_back(EVT(MVT::i64));
ValueVTs.push_back(EVT(MVT::i64));
if (Ty->isIntegerTy(128) || Ty->isFP128Ty()) {
ValueVTs.append({MVT::i64, MVT::i64});

if (Offsets) {
Offsets->push_back(StartingOffset + 0);
Offsets->push_back(StartingOffset + 8);
}
if (Offsets)
Offsets->append({StartingOffset + 0, StartingOffset + 8});

return;
}
Expand Down Expand Up @@ -1165,11 +1162,6 @@ NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
return DAG.getNode(NVPTXISD::Wrapper, dl, PtrVT, Op);
}

/// Returns true when \p Ty must be declared as a `.b8` array parameter:
/// aggregate types, vector types, i128, half and bfloat.
static bool IsTypePassedAsArray(const Type *Ty) {
  const bool IsCompound = Ty->isAggregateType() || Ty->isVectorTy();
  if (IsCompound)
    return true;
  if (Ty->isIntegerTy(128))
    return true;
  if (Ty->isHalfTy())
    return true;
  return Ty->isBFloatTy();
}

std::string NVPTXTargetLowering::getPrototype(
const DataLayout &DL, Type *retTy, const ArgListTy &Args,
const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign retAlignment,
Expand All @@ -1186,7 +1178,7 @@ std::string NVPTXTargetLowering::getPrototype(
} else {
O << "(";
if ((retTy->isFloatingPointTy() || retTy->isIntegerTy()) &&
!IsTypePassedAsArray(retTy)) {
!shouldPassAsArray(retTy)) {
unsigned size = 0;
if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
size = ITy->getBitWidth();
Expand All @@ -1203,7 +1195,7 @@ std::string NVPTXTargetLowering::getPrototype(
O << ".param .b" << size << " _";
} else if (isa<PointerType>(retTy)) {
O << ".param .b" << PtrVT.getSizeInBits() << " _";
} else if (IsTypePassedAsArray(retTy)) {
} else if (shouldPassAsArray(retTy)) {
O << ".param .align " << (retAlignment ? retAlignment->value() : 0)
<< " .b8 _[" << DL.getTypeAllocSize(retTy) << "]";
} else {
Expand All @@ -1224,7 +1216,7 @@ std::string NVPTXTargetLowering::getPrototype(
first = false;

if (!Outs[OIdx].Flags.isByVal()) {
if (IsTypePassedAsArray(Ty)) {
if (shouldPassAsArray(Ty)) {
Align ParamAlign =
getArgumentAlignment(&CB, Ty, i + AttributeList::FirstArgIndex, DL);
O << ".param .align " << ParamAlign.value() << " .b8 ";
Expand Down Expand Up @@ -1529,7 +1521,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);

bool NeedAlign; // Does argument declaration specify alignment?
bool PassAsArray = IsByVal || IsTypePassedAsArray(Ty);
const bool PassAsArray = IsByVal || shouldPassAsArray(Ty);
if (IsVAArg) {
if (ParamCount == FirstVAArg) {
SDValue DeclareParamOps[] = {
Expand Down Expand Up @@ -1718,7 +1710,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// .param .align N .b8 retval0[<size-in-bytes>], or
// .param .b<size-in-bits> retval0
unsigned resultsz = DL.getTypeAllocSizeInBits(RetTy);
if (!IsTypePassedAsArray(RetTy)) {
if (!shouldPassAsArray(RetTy)) {
resultsz = promoteScalarArgumentSize(resultsz);
SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
Expand Down Expand Up @@ -3344,7 +3336,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(

if (theArgs[i]->use_empty()) {
// argument is dead
if (IsTypePassedAsArray(Ty) && !Ty->isVectorTy()) {
if (shouldPassAsArray(Ty) && !Ty->isVectorTy()) {
SmallVector<EVT, 16> vtparts;

ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts);
Expand Down
4 changes: 0 additions & 4 deletions llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,8 +372,4 @@ bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM) {
!isKernelFunction(*F);
}

/// Returns true when \p VT is one of the two-element 16-bit vector types:
/// v2f16, v2bf16 or v2i16.
bool Isv2x16VT(EVT VT) {
  if (VT == MVT::v2f16)
    return true;
  if (VT == MVT::v2bf16)
    return true;
  return VT == MVT::v2i16;
}

} // namespace llvm
9 changes: 8 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXUtilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,14 @@ inline unsigned promoteScalarArgumentSize(unsigned size) {

bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM);

bool Isv2x16VT(EVT VT);
/// Returns true when \p VT is one of the two-element 16-bit vector types:
/// v2f16, v2bf16 or v2i16.
inline bool Isv2x16VT(EVT VT) {
  const bool IsFloatPair = VT == MVT::v2f16 || VT == MVT::v2bf16;
  return IsFloatPair || VT == MVT::v2i16;
}

/// Reports whether \p Ty is passed as a byte-array parameter instead of a
/// scalar one. This holds for aggregates, vectors, any type whose scalar size
/// is 128 bits (e.g. i128 or fp128), half and bfloat.
inline bool shouldPassAsArray(Type *Ty) {
  if (Ty->isAggregateType() || Ty->isVectorTy())
    return true;
  if (Ty->getScalarSizeInBits() == 128)
    return true;
  return Ty->isHalfTy() || Ty->isBFloatTy();
}

namespace NVPTX {
inline std::string getValidPTXIdentifier(StringRef Name) {
Expand Down
56 changes: 56 additions & 0 deletions llvm/test/CodeGen/NVPTX/fp128-storage-type.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -mcpu=sm_20 | FileCheck %s
; RUN: %if ptxas %{ llc < %s -mcpu=sm_20 | %ptxas-verify %}

target triple = "nvptx64-unknown-cuda"

; Returns the fp128 argument unchanged. The expected PTX moves the value as two
; 64-bit registers: a v2.u64 parameter load followed by a v2.b64 store to
; func_retval0.
define fp128 @identity(fp128 %x) {
; CHECK-LABEL: identity(
; CHECK: {
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.v2.u64 {%rd1, %rd2}, [identity_param_0];
; CHECK-NEXT: st.param.v2.b64 [func_retval0], {%rd1, %rd2};
; CHECK-NEXT: ret;
ret fp128 %x
}

; Loads an fp128 value from %in and stores it to %out. The expected PTX splits
; the access into two 64-bit loads (offsets 0 and 8) and two 64-bit stores.
define void @load_store(ptr %in, ptr %out) {
; CHECK-LABEL: load_store(
; CHECK: {
; CHECK-NEXT: .reg .b64 %rd<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u64 %rd1, [load_store_param_0];
; CHECK-NEXT: ld.u64 %rd2, [%rd1+8];
; CHECK-NEXT: ld.u64 %rd3, [%rd1];
; CHECK-NEXT: ld.param.u64 %rd4, [load_store_param_1];
; CHECK-NEXT: st.u64 [%rd4], %rd3;
; CHECK-NEXT: st.u64 [%rd4+8], %rd2;
; CHECK-NEXT: ret;
%val = load fp128, ptr %in
store fp128 %val, ptr %out
ret void
}

; Passes its fp128 argument to a call of itself. The expected PTX declares the
; outgoing argument as a 16-byte, 16-aligned .b8 array (param0) and fills it
; with a single v2.b64 store.
define void @call(fp128 %x) {
; CHECK-LABEL: call(
; CHECK: {
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.v2.u64 {%rd1, %rd2}, [call_param_0];
; CHECK-NEXT: { // callseq 0, 0
; CHECK-NEXT: .param .align 16 .b8 param0[16];
; CHECK-NEXT: st.param.v2.b64 [param0], {%rd1, %rd2};
; CHECK-NEXT: call.uni
; CHECK-NEXT: call,
; CHECK-NEXT: (
; CHECK-NEXT: param0
; CHECK-NEXT: );
; CHECK-NEXT: } // callseq 0
; CHECK-NEXT: ret;
call void @call(fp128 %x)
ret void
}
5 changes: 4 additions & 1 deletion llvm/test/CodeGen/NVPTX/global-variable-big.ll
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
target datalayout = "e-i64:64-v16:16-v32:32-n16:32:64"
target triple = "nvptx64-nvidia-cuda"

; Check that we can handle global variables of large integer type.
; Check that we can handle global variables of large integer and fp128 type.

; (lsb) 0x0102'0304'0506...0F10 (msb)
@gv = addrspace(1) externally_initialized global i128 21345817372864405881847059188222722561, align 16
; CHECK: .visible .global .align 16 .b8 gv[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};

@gv_fp128 = addrspace(1) externally_initialized global fp128 0xL33333333333333334004033333333333, align 16
; CHECK: .visible .global .align 16 .b8 gv_fp128[16] = {51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 3, 4, 64};

; Make sure that we do not overflow on large number of elements.
; CHECK: .visible .global .align 1 .b8 large_data[4831838208];
@large_data = global [4831838208 x i8] zeroinitializer
Loading