Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions llvm/include/llvm/IR/Intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class LLVMContext;
class Module;
class AttributeList;
class AttributeSet;
class raw_ostream;
class Constant;

/// This namespace contains an enum with a value for every intrinsic/builtin
/// function known by LLVM. The enum values are returned by
Expand Down Expand Up @@ -81,6 +83,9 @@ namespace Intrinsic {
/// Returns true if the intrinsic can be overloaded.
LLVM_ABI bool isOverloaded(ID id);

/// Returns true if the intrinsic has pretty printed immediate arguments.
LLVM_ABI bool hasPrettyPrintedArgs(ID id);

/// isTargetIntrinsic - Returns true if IID is an intrinsic specific to a
/// certain target. If it is a generic intrinsic false is returned.
LLVM_ABI bool isTargetIntrinsic(ID IID);
Expand Down Expand Up @@ -284,6 +289,10 @@ namespace Intrinsic {
/// N.
LLVM_ABI Intrinsic::ID getDeinterleaveIntrinsicID(unsigned Factor);

/// Print the pretty-printed annotation for argument \p ArgIdx of intrinsic
/// \p IID to \p OS, given the argument's constant value \p ImmArgVal.
///
/// The definition is generated by TableGen from the ArgInfo properties on the
/// intrinsic. For an argument that has an ArgInfo entry the generated code
/// writes "<name>=" and, if an ImmArgPrinter was specified, then calls that
/// printer with \p OS and \p ImmArgVal. Arguments without an ArgInfo entry are
/// expected to leave \p OS untouched; the IR printer treats empty output as
/// "no annotation".
LLVM_ABI void printImmArg(ID IID, unsigned ArgIdx, raw_ostream &OS,
const Constant *ImmArgVal);

} // namespace Intrinsic

} // namespace llvm
Expand Down
19 changes: 19 additions & 0 deletions llvm/include/llvm/IR/Intrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,25 @@ class Range<AttrIndex idx, int lower, int upper> : IntrinsicProperty {
int Upper = upper;
}

// ArgProperty - Base class for the per-argument properties that can be listed
// in an ArgInfo. It carries no fields of its own; it only gives ArgName and
// ImmArgPrinter a common type so they can share one list.
class ArgProperty;

// ArgName - Specifies the name of an argument for pretty-printing. The
// generated printer emits it followed by '=' (e.g. "mode=").
class ArgName<string name> : ArgProperty {
string Name = name;
}

// ImmArgPrinter - Specifies a custom printer function for immediate arguments.
// The generated printer calls it as FuncName(OS, ImmArgVal), so the named
// function must be visible where the generated code is included.
class ImmArgPrinter<string funcname> : ArgProperty {
string FuncName = funcname;
}

// ArgInfo - The argument at index `idx` has the pretty-printing properties
// given by `arg_properties`, a list of ArgProperty objects (ArgName and/or
// ImmArgPrinter).
class ArgInfo<ArgIndex idx, list<ArgProperty> arg_properties> : IntrinsicProperty {
int ArgNo = idx.Value;
list<ArgProperty> Properties = arg_properties;
}

def IntrNoReturn : IntrinsicProperty;

// Applied by default.
Expand Down
16 changes: 15 additions & 1 deletion llvm/include/llvm/IR/IntrinsicsNVVM.td
Original file line number Diff line number Diff line change
Expand Up @@ -2947,7 +2947,14 @@ foreach sp = [0, 1] in {
defvar nargs = !size(args);
defvar scale_d_imm = ArgIndex<!sub(nargs, 1)>;
defvar scale_d_imm_range = [ImmArg<scale_d_imm>, Range<scale_d_imm, 0, 16>];
defvar intrinsic_properties = !listconcat(

// Check if this is the specific llvm.nvvm.tcgen05.mma.tensor intrinsic.
defvar is_target_intrinsic = !and(!eq(sp, 0),
!eq(space, "tensor"),
!eq(scale_d, 0),
!eq(ashift, 0));

defvar base_properties = !listconcat(
mma.common_intr_props,
!if(!eq(scale_d, 1), scale_d_imm_range, []),
[Range<ArgIndex<nargs>, 0, !if(!eq(scale_d, 1), 2, 4)>, // kind
Expand All @@ -2957,6 +2964,13 @@ foreach sp = [0, 1] in {
]
);

defvar intrinsic_properties = !if(is_target_intrinsic,
!listconcat(base_properties,
[ArgInfo<ArgIndex<nargs>, [ArgName<"kind">, ImmArgPrinter<"printTcgen05MMAKind">]>,
ArgInfo<ArgIndex<!add(nargs, 1)>, [ArgName<"cta_group">]>,
ArgInfo<ArgIndex<!add(nargs, 2)>, [ArgName<"collector">, ImmArgPrinter<"printTcgen05CollectorUsageOp">]>]),
base_properties);

def mma.record_name:
DefaultAttrsIntrinsicFlags<[], args, flags, intrinsic_properties,
mma.intr_name>;
Expand Down
48 changes: 48 additions & 0 deletions llvm/include/llvm/IR/NVVMIntrinsicUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
#include <stdint.h>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/Support/raw_ostream.h"

namespace llvm {
namespace nvvm {
Expand Down Expand Up @@ -659,6 +662,51 @@ inline APFloat::roundingMode GetFMARoundingMode(Intrinsic::ID IntrinsicID) {
llvm_unreachable("Invalid FP instrinsic rounding mode for NVVM fma");
}

/// Print the mnemonic of a tcgen05 MMA "kind" immediate to \p OS.
///
/// \param OS         Stream that receives the mnemonic ("f16", "tf32",
///                   "f8f6f4" or "i8").
/// \param ImmArgVal  Constant passed for the kind argument; must be non-null.
///
/// This is reached from the IR printer, which may be asked to print IR that
/// has not been verified. It therefore must not abort on unexpected input: a
/// constant that is not a ConstantInt prints "<invalid>", and an integer that
/// matches no Tcgen05MMAKind enumerator prints "<invalid:N>" with the raw
/// value.
inline void printTcgen05MMAKind(raw_ostream &OS, const Constant *ImmArgVal) {
  const auto *CI = dyn_cast<ConstantInt>(ImmArgVal);
  if (!CI) {
    // e.g. undef/poison or a constant expression in an immarg position.
    OS << "<invalid>";
    return;
  }

  const uint64_t Val = CI->getZExtValue();
  switch (static_cast<Tcgen05MMAKind>(Val)) {
  case Tcgen05MMAKind::F16:
    OS << "f16";
    return;
  case Tcgen05MMAKind::TF32:
    OS << "tf32";
    return;
  case Tcgen05MMAKind::F8F6F4:
    OS << "f8f6f4";
    return;
  case Tcgen05MMAKind::I8:
    OS << "i8";
    return;
  }

  // Out-of-range immediate: keep printing instead of hitting unreachable code.
  OS << "<invalid:" << Val << ">";
}

/// Print the mnemonic of a tcgen05 collector-usage immediate to \p OS.
///
/// \param OS         Stream that receives the mnemonic ("discard", "lastuse",
///                   "fill" or "use").
/// \param ImmArgVal  Constant passed for the collector argument; must be
///                   non-null.
///
/// This is reached from the IR printer, which may be asked to print IR that
/// has not been verified. It therefore must not abort on unexpected input: a
/// constant that is not a ConstantInt prints "<invalid>", and an integer that
/// matches no Tcgen05CollectorUsageOp enumerator prints "<invalid:N>" with the
/// raw value.
inline void printTcgen05CollectorUsageOp(raw_ostream &OS,
                                         const Constant *ImmArgVal) {
  const auto *CI = dyn_cast<ConstantInt>(ImmArgVal);
  if (!CI) {
    // e.g. undef/poison or a constant expression in an immarg position.
    OS << "<invalid>";
    return;
  }

  const uint64_t Val = CI->getZExtValue();
  switch (static_cast<Tcgen05CollectorUsageOp>(Val)) {
  case Tcgen05CollectorUsageOp::DISCARD:
    OS << "discard";
    return;
  case Tcgen05CollectorUsageOp::LASTUSE:
    OS << "lastuse";
    return;
  case Tcgen05CollectorUsageOp::FILL:
    OS << "fill";
    return;
  case Tcgen05CollectorUsageOp::USE:
    OS << "use";
    return;
  }

  // Out-of-range immediate: keep printing instead of hitting unreachable code.
  OS << "<invalid:" << Val << ">";
}

} // namespace nvvm
} // namespace llvm
#endif // LLVM_IR_NVVMINTRINSICUTILS_H
41 changes: 33 additions & 8 deletions llvm/lib/IR/AsmWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
Expand Down Expand Up @@ -4581,12 +4582,38 @@ void AssemblyWriter::printInstruction(const Instruction &I) {
Out << ' ';
writeOperand(Operand, false);
Out << '(';
bool HasPrettyPrintedArgs =
isa<IntrinsicInst>(CI) &&
Intrinsic::hasPrettyPrintedArgs(CI->getIntrinsicID());

ListSeparator LS;
for (unsigned op = 0, Eop = CI->arg_size(); op < Eop; ++op) {
Out << LS;
writeParamOperand(CI->getArgOperand(op), PAL.getParamAttrs(op));
Function *CalledFunc = CI->getCalledFunction();
auto PrintArgComment = [&](unsigned ArgNo) {
const Constant *ConstArg = dyn_cast<Constant>(CI->getArgOperand(ArgNo));
if (!ConstArg)
return;
std::string ArgComment;
raw_string_ostream ArgCommentStream(ArgComment);
Intrinsic::ID IID = CalledFunc->getIntrinsicID();
Intrinsic::printImmArg(IID, ArgNo, ArgCommentStream, ConstArg);
if (ArgComment.empty())
return;
Out << "/* " << ArgComment << " */ ";
};
if (HasPrettyPrintedArgs) {
for (unsigned ArgNo = 0, NumArgs = CI->arg_size(); ArgNo < NumArgs;
++ArgNo) {
Out << LS;
PrintArgComment(ArgNo);
writeParamOperand(CI->getArgOperand(ArgNo), PAL.getParamAttrs(ArgNo));
}
} else {
for (unsigned ArgNo = 0, NumArgs = CI->arg_size(); ArgNo < NumArgs;
++ArgNo) {
Out << LS;
writeParamOperand(CI->getArgOperand(ArgNo), PAL.getParamAttrs(ArgNo));
}
}

// Emit an ellipsis if this is a musttail call in a vararg function. This
// is only to aid readability, musttail calls forward varargs by default.
if (CI->isMustTailCall() && CI->getParent() &&
Expand Down Expand Up @@ -5010,12 +5037,10 @@ void AssemblyWriter::printUseLists(const Function *F) {
//===----------------------------------------------------------------------===//

void Function::print(raw_ostream &ROS, AssemblyAnnotationWriter *AAW,
bool ShouldPreserveUseListOrder,
bool IsForDebug) const {
bool ShouldPreserveUseListOrder, bool IsForDebug) const {
SlotTracker SlotTable(this->getParent());
formatted_raw_ostream OS(ROS);
AssemblyWriter W(OS, SlotTable, this->getParent(), AAW,
IsForDebug,
AssemblyWriter W(OS, SlotTable, this->getParent(), AAW, IsForDebug,
ShouldPreserveUseListOrder);
W.printFunction(this);
}
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/IR/Intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/IntrinsicsXCore.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/NVVMIntrinsicUtils.h"
#include "llvm/IR/Type.h"

using namespace llvm;
Expand Down Expand Up @@ -601,6 +602,12 @@ bool Intrinsic::isOverloaded(ID id) {
#undef GET_INTRINSIC_OVERLOAD_TABLE
}

/// Returns true if intrinsic \p id has pretty-printed arguments.
/// The whole function body comes from the TableGen-generated
/// GET_INTRINSIC_PRETTY_PRINT_TABLE section of IntrinsicImpl.inc, which
/// defines a per-intrinsic lookup table (PPTable).
bool Intrinsic::hasPrettyPrintedArgs(ID id){
#define GET_INTRINSIC_PRETTY_PRINT_TABLE
#include "llvm/IR/IntrinsicImpl.inc"
#undef GET_INTRINSIC_PRETTY_PRINT_TABLE
}

/// Table of per-target intrinsic name tables.
#define GET_INTRINSIC_TARGET_DATA
#include "llvm/IR/IntrinsicImpl.inc"
Expand Down Expand Up @@ -1142,3 +1149,7 @@ Intrinsic::ID Intrinsic::getDeinterleaveIntrinsicID(unsigned Factor) {
assert(Factor >= 2 && Factor <= 8 && "Unexpected factor");
return InterleaveIntrinsics[Factor - 2].Deinterleave;
}

// Pull in the TableGen-generated definition of Intrinsic::printImmArg, which
// switches on the intrinsic ID and argument index to print each ArgInfo
// argument's name and, where one is specified, call its custom printer.
#define GET_INTRINSIC_PRETTY_PRINT_ARGUMENTS
#include "llvm/IR/IntrinsicImpl.inc"
#undef GET_INTRINSIC_PRETTY_PRINT_ARGUMENTS
50 changes: 50 additions & 0 deletions llvm/test/CodeGen/NVPTX/tcgen05-mma-tensor-formatted.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
; NOTE: This sample test demonstrates the pretty print feature for NVPTX intrinsics
; RUN: llvm-as < %s | llvm-dis | FileCheck %s

target triple = "nvptx64-nvidia-cuda"

define void @tcgen05_mma_fp16_cta1(ptr addrspace(6) %dtmem, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d) {
; CHECK-LABEL: define void @tcgen05_mma_fp16_cta1(
; CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) %dtmem, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, /* kind=f16 */ i32 0, /* cta_group= */ i32 1, /* collector=discard */ i32 0)
call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) %dtmem, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i32 0, i32 1, i32 0)

; CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) %dtmem, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, /* kind=f16 */ i32 0, /* cta_group= */ i32 1, /* collector=lastuse */ i32 1)
call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) %dtmem, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i32 0, i32 1, i32 1)

; CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) %dtmem, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, /* kind=f16 */ i32 0, /* cta_group= */ i32 1, /* collector=fill */ i32 2)
call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) %dtmem, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i32 0, i32 1, i32 2)

; CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) %dtmem, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, /* kind=f16 */ i32 0, /* cta_group= */ i32 1, /* collector=use */ i32 3)
call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) %dtmem, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i32 0, i32 1, i32 3)

ret void
}

define void @tcgen05_mma_f8f6f4_cta2(ptr addrspace(6) %dtmem, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d) {
; CHECK-LABEL: define void @tcgen05_mma_f8f6f4_cta2(
; CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) %dtmem, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 2, /* collector=discard */ i32 0)
call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) %dtmem, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i32 2, i32 2, i32 0)

; CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) %dtmem, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 2, /* collector=lastuse */ i32 1)
call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) %dtmem, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i32 2, i32 2, i32 1)

; CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) %dtmem, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 2, /* collector=fill */ i32 2)
call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) %dtmem, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i32 2, i32 2, i32 2)

; CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) %dtmem, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 2, /* collector=use */ i32 3)
call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) %dtmem, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i32 2, i32 2, i32 3)

ret void
}

; This test verifies that printImmArg is safe to call on all constant arguments, but only prints comments for arguments that have pretty printing configured.
define void @test_mixed_constants_edge_case(ptr addrspace(6) %dtmem, ptr addrspace(6) %atensor) {
; CHECK-LABEL: define void @test_mixed_constants_edge_case(
; CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) %dtmem, ptr addrspace(6) %atensor, i64 42, i32 100, i1 true, /* kind=i8 */ i32 3, /* cta_group= */ i32 1, /* collector=discard */ i32 0)
call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) %dtmem, ptr addrspace(6) %atensor, i64 42, i32 100, i1 true, i32 3, i32 1, i32 0)

ret void
}

declare void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6), ptr addrspace(6), i64, i32, i1, i32, i32, i32)
71 changes: 71 additions & 0 deletions llvm/test/TableGen/intrinsic-arginfo.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// RUN: llvm-tblgen -gen-intrinsic-impl -I %p/../../include %s | FileCheck %s

// Test ArgInfo property for pretty-printing intrinsic arguments.
// This test verifies that TableGen generates the correct pretty-printing code
// for intrinsics that use the ArgInfo property.

include "llvm/IR/Intrinsics.td"

// Simple intrinsic with two arguments that have ArgInfo.
def int_dummy_foo_bar : DefaultAttrsIntrinsic<
[llvm_i32_ty],
[llvm_i32_ty, // data
llvm_i32_ty, // mode
llvm_i32_ty], // stride
[IntrNoMem,
ImmArg<ArgIndex<1>>,
ArgInfo<ArgIndex<1>, [ArgName<"mode">, ImmArgPrinter<"printDummyMode">]>,
ArgInfo<ArgIndex<2>, [ArgName<"stride">]>]>;

// A custom floating point add with rounding and sat mode.
def int_my_fadd_f32 : DefaultAttrsIntrinsic<
[llvm_float_ty],
[llvm_float_ty, // a
llvm_float_ty, // b
llvm_i32_ty, // rounding_mode
llvm_i1_ty], // saturation_mode
[IntrNoMem,
ImmArg<ArgIndex<2>>,
ImmArg<ArgIndex<3>>,
ArgInfo<ArgIndex<2>, [ArgName<"rounding_mode">, ImmArgPrinter<"printRoundingMode">]>,
ArgInfo<ArgIndex<3>, [ArgName<"saturation_mode">]>]>;

// CHECK: #ifdef GET_INTRINSIC_PRETTY_PRINT_TABLE
// CHECK-NEXT: static constexpr uint8_t PPTable[] = {

// CHECK: #endif // GET_INTRINSIC_PRETTY_PRINT_TABLE

// CHECK: #ifdef GET_INTRINSIC_PRETTY_PRINT_ARGUMENTS
// CHECK: void Intrinsic::printImmArg(ID IID, unsigned ArgIdx, raw_ostream &OS, const Constant *ImmArgVal) {

// CHECK: case dummy_foo_bar:
// CHECK-NEXT: switch (ArgIdx) {

// CHECK-NEXT: case 1:
// CHECK-NEXT: OS << "mode=";
// CHECK-NEXT: printDummyMode(OS, ImmArgVal);
// CHECK-NEXT: return;

// CHECK-NEXT: case 2:
// CHECK-NEXT: OS << "stride=";
// CHECK-NEXT: return;

// CHECK-NEXT: }
// CHECK-NEXT: break;

// CHECK: case my_fadd_f32:
// CHECK-NEXT: switch (ArgIdx) {

// CHECK-NEXT: case 2:
// CHECK-NEXT: OS << "rounding_mode=";
// CHECK-NEXT: printRoundingMode(OS, ImmArgVal);
// CHECK-NEXT: return;

// CHECK-NEXT: case 3:
// CHECK-NEXT: OS << "saturation_mode=";
// CHECK-NEXT: return;

// CHECK-NEXT: }
// CHECK-NEXT: break;

// CHECK: #endif // GET_INTRINSIC_PRETTY_PRINT_ARGUMENTS
Loading
Loading