Skip to content

Commit 3af8b3f

Browse files
AdUhTkJm authored and lanza committed
[CIR][CUDA] Handle clang builtin functions (#1496)
Clang relies on `llvm::Intrinsic::getOrInsertDeclaration` to handle functions marked as `ClangBuiltin` in TableGen. That function takes an `llvm::Module*`, which CIR does not have at this stage, so CIR can't use it. We need to re-implement the parts of it that compute the intrinsic's type.
1 parent 06cee96 commit 3af8b3f

File tree

3 files changed

+147
-43
lines changed

3 files changed

+147
-43
lines changed

clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp

Lines changed: 79 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,46 @@ static bool isMemBuiltinOutOfBoundPossible(const clang::Expr *sizeArg,
467467
return size.ugt(dstSize);
468468
}
469469

470+
/// Pops the leading descriptor off `infos` and converts it to the matching
/// CIR type.
///
/// `infos` is taken by reference and is advanced past the consumed entry, so
/// repeated calls walk the descriptor list from front to back.
///
/// Only Void, Integer, Float and Double descriptors are handled, and integers
/// are always built as signed. Every other descriptor kind ends up in
/// llvm_unreachable("NYI").
static mlir::Type
decodeFixedType(ArrayRef<llvm::Intrinsic::IITDescriptor> &infos,
                mlir::MLIRContext *context) {
  using llvm::Intrinsic::IITDescriptor;

  // Copy the head entry out before shrinking the caller's view.
  const IITDescriptor head = infos.front();
  infos = infos.drop_front();

  switch (head.Kind) {
  case IITDescriptor::Double:
    return DoubleType::get(context);
  case IITDescriptor::Float:
    return SingleType::get(context);
  case IITDescriptor::Integer: {
    const unsigned width = head.Integer_Width;
    return IntType::get(context, width, /*signed=*/true);
  }
  case IITDescriptor::Void:
    return VoidType::get(context);
  default:
    llvm_unreachable("NYI");
  }
}
491+
492+
// llvm::Intrinsics accepts only LLVMContext. We need to reimplement it here.
493+
static cir::FuncType getIntrinsicType(mlir::MLIRContext *context,
494+
llvm::Intrinsic::ID id) {
495+
using namespace llvm::Intrinsic;
496+
497+
SmallVector<IITDescriptor, 8> table;
498+
getIntrinsicInfoTableEntries(id, table);
499+
500+
ArrayRef<IITDescriptor> tableRef = table;
501+
mlir::Type resultTy = decodeFixedType(tableRef, context);
502+
503+
SmallVector<mlir::Type, 8> argTypes;
504+
while (!tableRef.empty())
505+
argTypes.push_back(decodeFixedType(tableRef, context));
506+
507+
return FuncType::get(argTypes, resultTy);
508+
}
509+
470510
RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
471511
const CallExpr *E,
472512
ReturnValueSlot ReturnValue) {
@@ -2526,25 +2566,58 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
25262566

25272567
// See if we have a target specific intrinsic.
25282568
std::string Name = getContext().BuiltinInfo.getName(BuiltinID);
2529-
Intrinsic::ID IntrinsicID = Intrinsic::not_intrinsic;
2569+
Intrinsic::ID intrinsicID = Intrinsic::not_intrinsic;
25302570
StringRef Prefix =
25312571
llvm::Triple::getArchTypePrefix(getTarget().getTriple().getArch());
25322572
if (!Prefix.empty()) {
2533-
IntrinsicID = Intrinsic::getIntrinsicForClangBuiltin(Prefix.data(), Name);
2573+
intrinsicID = Intrinsic::getIntrinsicForClangBuiltin(Prefix.data(), Name);
25342574
// NOTE we don't need to perform a compatibility flag check here since the
25352575
// intrinsics are declared in Builtins*.def via LANGBUILTIN which filter the
25362576
// MS builtins via ALL_MS_LANGUAGES and are filtered earlier.
2537-
if (IntrinsicID == Intrinsic::not_intrinsic)
2538-
IntrinsicID = Intrinsic::getIntrinsicForMSBuiltin(Prefix.data(), Name);
2577+
if (intrinsicID == Intrinsic::not_intrinsic)
2578+
intrinsicID = Intrinsic::getIntrinsicForMSBuiltin(Prefix.data(), Name);
25392579
}
25402580

2541-
if (IntrinsicID != Intrinsic::not_intrinsic) {
2581+
if (intrinsicID != Intrinsic::not_intrinsic) {
25422582
unsigned iceArguments = 0;
25432583
ASTContext::GetBuiltinTypeError error;
25442584
getContext().GetBuiltinType(BuiltinID, error, &iceArguments);
25452585
assert(error == ASTContext::GE_None && "Should not codegen an error");
2546-
if (iceArguments > 0)
2586+
2587+
llvm::StringRef name = llvm::Intrinsic::getName(intrinsicID);
2588+
// cir::LLVMIntrinsicCallOp expects intrinsic name to not have prefix
2589+
// "llvm." For example, `llvm.nvvm.barrier0` should be passed as
2590+
// `nvvm.barrier0`.
2591+
if (!name.consume_front("llvm."))
2592+
assert(false && "bad intrinsic name!");
2593+
2594+
cir::FuncType intrinsicType =
2595+
getIntrinsicType(&getMLIRContext(), intrinsicID);
2596+
2597+
SmallVector<mlir::Value> args;
2598+
for (unsigned i = 0; i < E->getNumArgs(); i++) {
2599+
mlir::Value arg = emitScalarOrConstFoldImmArg(iceArguments, i, E);
2600+
mlir::Type argType = arg.getType();
2601+
if (argType != intrinsicType.getInput(i))
2602+
llvm_unreachable("NYI");
2603+
2604+
args.push_back(arg);
2605+
}
2606+
2607+
auto intrinsicCall = builder.create<cir::LLVMIntrinsicCallOp>(
2608+
getLoc(E->getExprLoc()), builder.getStringAttr(name),
2609+
intrinsicType.getReturnType(), args);
2610+
2611+
mlir::Type builtinReturnType = intrinsicCall.getResult().getType();
2612+
mlir::Type retTy = intrinsicType.getReturnType();
2613+
2614+
if (builtinReturnType != retTy)
25472615
llvm_unreachable("NYI");
2616+
2617+
if (isa<cir::VoidType>(retTy))
2618+
return RValue::get(nullptr);
2619+
2620+
return RValue::get(intrinsicCall.getResult());
25482621
}
25492622

25502623
// Some target-specific builtins can have aggregate return values, e.g.

clang/lib/CIR/CodeGen/CIRGenBuiltinNVPTX.cpp

Lines changed: 4 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -40,44 +40,11 @@ mlir::Value CIRGenFunction::emitNVPTXBuiltinExpr(unsigned builtinId,
4040
.getResult();
4141
};
4242
switch (builtinId) {
43-
case NVPTX::BI__nvvm_read_ptx_sreg_tid_x:
44-
return getIntrinsic("nvvm.read.ptx.sreg.tid.x");
45-
case NVPTX::BI__nvvm_read_ptx_sreg_tid_y:
46-
return getIntrinsic("nvvm.read.ptx.sreg.tid.y");
47-
case NVPTX::BI__nvvm_read_ptx_sreg_tid_z:
48-
return getIntrinsic("nvvm.read.ptx.sreg.tid.z");
49-
case NVPTX::BI__nvvm_read_ptx_sreg_tid_w:
50-
return getIntrinsic("nvvm.read.ptx.sreg.tid.w");
51-
52-
case NVPTX::BI__nvvm_read_ptx_sreg_ntid_x:
53-
return getIntrinsic("nvvm.read.ptx.sreg.ntid.x");
54-
case NVPTX::BI__nvvm_read_ptx_sreg_ntid_y:
55-
return getIntrinsic("nvvm.read.ptx.sreg.ntid.y");
56-
case NVPTX::BI__nvvm_read_ptx_sreg_ntid_z:
57-
return getIntrinsic("nvvm.read.ptx.sreg.ntid.z");
58-
case NVPTX::BI__nvvm_read_ptx_sreg_ntid_w:
59-
return getIntrinsic("nvvm.read.ptx.sreg.ntid.w");
60-
61-
case NVPTX::BI__nvvm_read_ptx_sreg_ctaid_x:
62-
return getIntrinsic("nvvm.read.ptx.sreg.ctaid.x");
63-
case NVPTX::BI__nvvm_read_ptx_sreg_ctaid_y:
64-
return getIntrinsic("nvvm.read.ptx.sreg.ctaid.y");
65-
case NVPTX::BI__nvvm_read_ptx_sreg_ctaid_z:
66-
return getIntrinsic("nvvm.read.ptx.sreg.ctaid.z");
67-
case NVPTX::BI__nvvm_read_ptx_sreg_ctaid_w:
68-
return getIntrinsic("nvvm.read.ptx.sreg.ctaid.w");
69-
70-
case NVPTX::BI__nvvm_read_ptx_sreg_nctaid_x:
71-
return getIntrinsic("nvvm.read.ptx.sreg.nctaid.x");
72-
case NVPTX::BI__nvvm_read_ptx_sreg_nctaid_y:
73-
return getIntrinsic("nvvm.read.ptx.sreg.nctaid.y");
74-
case NVPTX::BI__nvvm_read_ptx_sreg_nctaid_z:
75-
return getIntrinsic("nvvm.read.ptx.sreg.nctaid.z");
76-
case NVPTX::BI__nvvm_read_ptx_sreg_nctaid_w:
77-
return getIntrinsic("nvvm.read.ptx.sreg.nctaid.w");
78-
7943
default:
80-
llvm_unreachable("NYI");
44+
// Returning nullptr means the intrinsic is not implemented.
45+
// This will be checked in `emitBuiltinExpr`, and will cause clang to output
46+
// "unsupported builtin" diagnostics.
47+
return nullptr;
8148
}
8249
}
8350

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#include "../Inputs/cuda.h"
2+
3+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
4+
// RUN: -fcuda-is-device -emit-cir -target-sdk-version=12.3 \
5+
// RUN: %s -o %t.cir
6+
// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
7+
8+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
9+
// RUN: -fcuda-is-device -emit-llvm -target-sdk-version=12.3 \
10+
// RUN: %s -o %t.ll
11+
// RUN: FileCheck --check-prefix=LLVM --input-file=%t.ll %s
12+
13+
// Device function that calls a selection of NVVM builtins. The FileCheck
// directives placed before each call pin the intrinsic name and signature
// expected in the emitted CIR and in the emitted LLVM IR.
__device__ void builtins() {
  float lhsF;
  float rhsF;
  double lhsD;
  double rhsD;

  // Single-precision builtins.

  // CIR: cir.llvm.intrinsic "nvvm.fmax.f" {{.*}} : (!cir.float, !cir.float) -> !cir.float
  // LLVM: call float @llvm.nvvm.fmax.f(float {{.*}}, float {{.*}})
  float fmaxF = __nvvm_fmax_f(lhsF, rhsF);
  // CIR: cir.llvm.intrinsic "nvvm.fmin.f" {{.*}} : (!cir.float, !cir.float) -> !cir.float
  // LLVM: call float @llvm.nvvm.fmin.f(float {{.*}}, float {{.*}})
  float fminF = __nvvm_fmin_f(lhsF, rhsF);
  // CIR: cir.llvm.intrinsic "nvvm.sqrt.rn.f" {{.*}} : (!cir.float) -> !cir.float
  // LLVM: call float @llvm.nvvm.sqrt.rn.f(float {{.*}})
  float sqrtF = __nvvm_sqrt_rn_f(lhsF);
  // CIR: cir.llvm.intrinsic "nvvm.rcp.rn.f" {{.*}} : (!cir.float) -> !cir.float
  // LLVM: call float @llvm.nvvm.rcp.rn.f(float {{.*}})
  float rcpF = __nvvm_rcp_rn_f(rhsF);
  // CIR: cir.llvm.intrinsic "nvvm.add.rn.f" {{.*}} : (!cir.float, !cir.float) -> !cir.float
  // LLVM: call float @llvm.nvvm.add.rn.f(float {{.*}}, float {{.*}})
  float addF = __nvvm_add_rn_f(lhsF, rhsF);

  // Double-precision builtins.

  // CIR: cir.llvm.intrinsic "nvvm.fmax.d" {{.*}} : (!cir.double, !cir.double) -> !cir.double
  // LLVM: call double @llvm.nvvm.fmax.d(double {{.*}}, double {{.*}})
  double fmaxD = __nvvm_fmax_d(lhsD, rhsD);
  // CIR: cir.llvm.intrinsic "nvvm.fmin.d" {{.*}} : (!cir.double, !cir.double) -> !cir.double
  // LLVM: call double @llvm.nvvm.fmin.d(double {{.*}}, double {{.*}})
  double fminD = __nvvm_fmin_d(lhsD, rhsD);
  // CIR: cir.llvm.intrinsic "nvvm.sqrt.rn.d" {{.*}} : (!cir.double) -> !cir.double
  // LLVM: call double @llvm.nvvm.sqrt.rn.d(double {{.*}})
  double sqrtD = __nvvm_sqrt_rn_d(lhsD);
  // CIR: cir.llvm.intrinsic "nvvm.rcp.rn.d" {{.*}} : (!cir.double) -> !cir.double
  // LLVM: call double @llvm.nvvm.rcp.rn.d(double {{.*}})
  double rcpD = __nvvm_rcp_rn_d(rhsD);

  // Integer builtin.
  int lhsI;
  int rhsI;

  // CIR: cir.llvm.intrinsic "nvvm.mulhi.i" {{.*}} : (!s32i, !s32i) -> !s32i
  // LLVM: call i32 @llvm.nvvm.mulhi.i(i32 {{.*}}, i32 {{.*}})
  int mulhiI = __nvvm_mulhi_i(lhsI, rhsI);

  // Builtins that take no arguments and return nothing.

  // CIR: cir.llvm.intrinsic "nvvm.membar.cta"
  // LLVM: call void @llvm.nvvm.membar.cta()
  __nvvm_membar_cta();
  // CIR: cir.llvm.intrinsic "nvvm.membar.gl"
  // LLVM: call void @llvm.nvvm.membar.gl()
  __nvvm_membar_gl();
  // CIR: cir.llvm.intrinsic "nvvm.membar.sys"
  // LLVM: call void @llvm.nvvm.membar.sys()
  __nvvm_membar_sys();
  // CIR: cir.llvm.intrinsic "nvvm.barrier0"
  // LLVM: call void @llvm.nvvm.barrier0()
  __syncthreads();
}

0 commit comments

Comments
 (0)