Skip to content

Commit be7bab5

Browse files
advay168anominos
authored andcommitted
[CIR][CUDA] Support builtin CUDA variables (#1458)
This PR adds support for compiling builtin variables like `threadIdx` down to the appropriate intrinsic. --------- Co-authored-by: Aidan Wong <[email protected]> Co-authored-by: anominos <[email protected]>
1 parent 444da4f commit be7bab5

File tree

7 files changed

+274
-3
lines changed

7 files changed

+274
-3
lines changed

clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2516,7 +2516,12 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
25162516
}
25172517

25182518
if (IntrinsicID != Intrinsic::not_intrinsic) {
2519-
llvm_unreachable("NYI");
2519+
unsigned iceArguments = 0;
2520+
ASTContext::GetBuiltinTypeError error;
2521+
getContext().GetBuiltinType(BuiltinID, error, &iceArguments);
2522+
assert(error == ASTContext::GE_None && "Should not codegen an error");
2523+
if (iceArguments > 0)
2524+
llvm_unreachable("NYI");
25202525
}
25212526

25222527
// Some target-specific builtins can have aggregate return values, e.g.
@@ -2614,7 +2619,7 @@ static mlir::Value emitTargetArchBuiltinExpr(CIRGenFunction *CGF,
26142619
llvm_unreachable("NYI");
26152620
case llvm::Triple::nvptx:
26162621
case llvm::Triple::nvptx64:
2617-
llvm_unreachable("NYI");
2622+
return CGF->emitNVPTXBuiltinExpr(BuiltinID, E);
26182623
case llvm::Triple::wasm32:
26192624
case llvm::Triple::wasm64:
26202625
llvm_unreachable("NYI");
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
//===---- CIRGenBuiltinX86.cpp - Emit CIR for X86 builtins ----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This contains code to emit NVPTX Builtin calls.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "CIRGenCXXABI.h"
14+
#include "CIRGenCall.h"
15+
#include "CIRGenFunction.h"
16+
#include "CIRGenModule.h"
17+
#include "TargetInfo.h"
18+
#include "clang/CIR/MissingFeatures.h"
19+
20+
#include "mlir/Dialect/Func/IR/FuncOps.h"
21+
#include "mlir/IR/Value.h"
22+
#include "clang/AST/GlobalDecl.h"
23+
#include "clang/Basic/Builtins.h"
24+
#include "clang/Basic/TargetBuiltins.h"
25+
#include "clang/CIR/Dialect/IR/CIRDialect.h"
26+
#include "clang/CIR/Dialect/IR/CIRTypes.h"
27+
#include "llvm/Support/ErrorHandling.h"
28+
29+
using namespace clang;
30+
using namespace clang::CIRGen;
31+
using namespace cir;
32+
33+
mlir::Value CIRGenFunction::emitNVPTXBuiltinExpr(unsigned builtinId,
34+
const CallExpr *expr) {
35+
auto getIntrinsic = [&](const char *name) {
36+
mlir::Type intTy = cir::IntType::get(&getMLIRContext(), 32, false);
37+
return builder
38+
.create<cir::LLVMIntrinsicCallOp>(getLoc(expr->getExprLoc()),
39+
builder.getStringAttr(name), intTy)
40+
.getResult();
41+
};
42+
switch (builtinId) {
43+
case NVPTX::BI__nvvm_read_ptx_sreg_tid_x:
44+
return getIntrinsic("nvvm.read.ptx.sreg.tid.x");
45+
case NVPTX::BI__nvvm_read_ptx_sreg_tid_y:
46+
return getIntrinsic("nvvm.read.ptx.sreg.tid.y");
47+
case NVPTX::BI__nvvm_read_ptx_sreg_tid_z:
48+
return getIntrinsic("nvvm.read.ptx.sreg.tid.z");
49+
case NVPTX::BI__nvvm_read_ptx_sreg_tid_w:
50+
return getIntrinsic("nvvm.read.ptx.sreg.tid.w");
51+
52+
case NVPTX::BI__nvvm_read_ptx_sreg_ntid_x:
53+
return getIntrinsic("nvvm.read.ptx.sreg.ntid.x");
54+
case NVPTX::BI__nvvm_read_ptx_sreg_ntid_y:
55+
return getIntrinsic("nvvm.read.ptx.sreg.ntid.y");
56+
case NVPTX::BI__nvvm_read_ptx_sreg_ntid_z:
57+
return getIntrinsic("nvvm.read.ptx.sreg.ntid.z");
58+
case NVPTX::BI__nvvm_read_ptx_sreg_ntid_w:
59+
return getIntrinsic("nvvm.read.ptx.sreg.ntid.w");
60+
61+
case NVPTX::BI__nvvm_read_ptx_sreg_ctaid_x:
62+
return getIntrinsic("nvvm.read.ptx.sreg.ctaid.x");
63+
case NVPTX::BI__nvvm_read_ptx_sreg_ctaid_y:
64+
return getIntrinsic("nvvm.read.ptx.sreg.ctaid.y");
65+
case NVPTX::BI__nvvm_read_ptx_sreg_ctaid_z:
66+
return getIntrinsic("nvvm.read.ptx.sreg.ctaid.z");
67+
case NVPTX::BI__nvvm_read_ptx_sreg_ctaid_w:
68+
return getIntrinsic("nvvm.read.ptx.sreg.ctaid.w");
69+
70+
case NVPTX::BI__nvvm_read_ptx_sreg_nctaid_x:
71+
return getIntrinsic("nvvm.read.ptx.sreg.nctaid.x");
72+
case NVPTX::BI__nvvm_read_ptx_sreg_nctaid_y:
73+
return getIntrinsic("nvvm.read.ptx.sreg.nctaid.y");
74+
case NVPTX::BI__nvvm_read_ptx_sreg_nctaid_z:
75+
return getIntrinsic("nvvm.read.ptx.sreg.nctaid.z");
76+
case NVPTX::BI__nvvm_read_ptx_sreg_nctaid_w:
77+
return getIntrinsic("nvvm.read.ptx.sreg.nctaid.w");
78+
79+
default:
80+
llvm_unreachable("NYI");
81+
}
82+
}

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3313,3 +3313,92 @@ LValue CIRGenFunction::emitPredefinedLValue(const PredefinedExpr *E) {
33133313

33143314
return emitStringLiteralLValue(SL);
33153315
}
3316+
3317+
namespace {
3318+
struct LValueOrRValue {
3319+
LValue lv;
3320+
RValue rv;
3321+
};
3322+
3323+
LValueOrRValue emitPseudoObjectExpr(CIRGenFunction &cgf,
3324+
const PseudoObjectExpr *expr,
3325+
bool forLValue, AggValueSlot slot) {
3326+
SmallVector<CIRGenFunction::OpaqueValueMappingData, 4> opaques;
3327+
3328+
// Find the result expression, if any.
3329+
const Expr *resultExpr = expr->getResultExpr();
3330+
LValueOrRValue result;
3331+
3332+
for (PseudoObjectExpr::const_semantics_iterator i = expr->semantics_begin(),
3333+
e = expr->semantics_end();
3334+
i != e; ++i) {
3335+
const Expr *semantic = *i;
3336+
3337+
// If this semantic expression is an opaque value, bind it
3338+
// to the result of its source expression.
3339+
if (const auto *ov = dyn_cast<OpaqueValueExpr>(semantic)) {
3340+
// Skip unique OVEs.
3341+
if (ov->isUnique()) {
3342+
assert(ov != resultExpr &&
3343+
"A unique OVE cannot be used as the result expression");
3344+
continue;
3345+
}
3346+
3347+
// If this is the result expression, we may need to evaluate
3348+
// directly into the slot.
3349+
using OVMA = CIRGenFunction::OpaqueValueMappingData;
3350+
OVMA opaqueData;
3351+
if (ov == resultExpr && ov->isPRValue() && !forLValue &&
3352+
CIRGenFunction::hasAggregateEvaluationKind(ov->getType())) {
3353+
cgf.emitAggExpr(ov->getSourceExpr(), slot);
3354+
LValue lv = cgf.makeAddrLValue(slot.getAddress(), ov->getType(),
3355+
AlignmentSource::Decl);
3356+
opaqueData = OVMA::bind(cgf, ov, lv);
3357+
result.rv = slot.asRValue();
3358+
3359+
// Otherwise, emit as normal.
3360+
} else {
3361+
opaqueData = OVMA::bind(cgf, ov, ov->getSourceExpr());
3362+
3363+
// If this is the result, also evaluate the result now.
3364+
if (ov == resultExpr) {
3365+
if (forLValue)
3366+
result.lv = cgf.emitLValue(ov);
3367+
else
3368+
result.rv = cgf.emitAnyExpr(ov, slot);
3369+
}
3370+
}
3371+
3372+
opaques.push_back(opaqueData);
3373+
3374+
// Otherwise, if the expression is the result, evaluate it
3375+
// and remember the result.
3376+
} else if (semantic == resultExpr) {
3377+
if (forLValue)
3378+
result.lv = cgf.emitLValue(semantic);
3379+
else
3380+
result.rv = cgf.emitAnyExpr(semantic, slot);
3381+
3382+
// Otherwise, evaluate the expression in an ignored context.
3383+
} else {
3384+
cgf.emitIgnoredExpr(semantic);
3385+
}
3386+
}
3387+
3388+
// Unbind all the opaques now.
3389+
for (auto &opaque : opaques)
3390+
opaque.unbind(cgf);
3391+
3392+
return result;
3393+
}
3394+
3395+
} // namespace
3396+
3397+
RValue CIRGenFunction::emitPseudoObjectRValue(const PseudoObjectExpr *expr,
3398+
AggValueSlot slot) {
3399+
return emitPseudoObjectExpr(*this, expr, false, slot).rv;
3400+
}
3401+
3402+
LValue CIRGenFunction::emitPseudoObjectLValue(const PseudoObjectExpr *expr) {
3403+
return emitPseudoObjectExpr(*this, expr, true, AggValueSlot::ignored()).lv;
3404+
}

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
232232
llvm_unreachable("NYI");
233233
}
234234
mlir::Value VisitPseudoObjectExpr(PseudoObjectExpr *E) {
235-
llvm_unreachable("NYI");
235+
return CGF.emitPseudoObjectRValue(E).getScalarVal();
236236
}
237237
mlir::Value VisitSYCLUniqueStableNameExpr(SYCLUniqueStableNameExpr *E) {
238238
llvm_unreachable("NYI");

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,6 +1282,11 @@ class CIRGenFunction : public CIRGenTypeCache {
12821282
ConstantEmission tryEmitAsConstant(DeclRefExpr *refExpr);
12831283
ConstantEmission tryEmitAsConstant(const MemberExpr *ME);
12841284

1285+
RValue emitPseudoObjectRValue(const PseudoObjectExpr *expr,
1286+
AggValueSlot slot = AggValueSlot::ignored());
1287+
1288+
LValue emitPseudoObjectLValue(const PseudoObjectExpr *expr);
1289+
12851290
/// Emit the computation of the specified expression of scalar type,
12861291
/// ignoring the result.
12871292
mlir::Value emitScalarExpr(const clang::Expr *E);
@@ -1471,6 +1476,7 @@ class CIRGenFunction : public CIRGenTypeCache {
14711476
mlir::Value emitAArch64SVEBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
14721477
mlir::Value emitAArch64SMEBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
14731478
mlir::Value emitX86BuiltinExpr(unsigned BuiltinID, const CallExpr *E);
1479+
mlir::Value emitNVPTXBuiltinExpr(unsigned builtinID, const CallExpr *expr);
14741480

14751481
/// Given an expression with a pointer type, emit the value and compute our
14761482
/// best estimate of the alignment of the pointee.

clang/lib/CIR/CodeGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_clang_library(clangCIR
1313
CIRGenBuiltin.cpp
1414
CIRGenBuiltinAArch64.cpp
1515
CIRGenBuiltinX86.cpp
16+
CIRGenBuiltinNVPTX.cpp
1617
CIRGenCXX.cpp
1718
CIRGenCXXABI.cpp
1819
CIRGenCall.cpp
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
2+
// RUN: -fcuda-is-device -emit-llvm -o - %s \
3+
// RUN: | FileCheck --check-prefix=LLVM %s
4+
5+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
6+
// RUN: -fcuda-is-device -emit-cir -o - %s \
7+
// RUN: | FileCheck --check-prefix=CIR %s
8+
9+
#include "__clang_cuda_builtin_vars.h"
10+
11+
// LLVM: define{{.*}} void @_Z6kernelPi(ptr %0)
12+
__attribute__((global))
13+
void kernel(int *out) {
14+
int i = 0;
15+
16+
out[i++] = threadIdx.x;
17+
// CIR: cir.func linkonce_odr @_ZN26__cuda_builtin_threadIdx_t17__fetch_builtin_xEv()
18+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.tid.x"
19+
// LLVM: call {{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.x()
20+
21+
out[i++] = threadIdx.y;
22+
// CIR: cir.func linkonce_odr @_ZN26__cuda_builtin_threadIdx_t17__fetch_builtin_yEv()
23+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.tid.y"
24+
// LLVM: call {{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.y()
25+
26+
out[i++] = threadIdx.z;
27+
// CIR: cir.func linkonce_odr @_ZN26__cuda_builtin_threadIdx_t17__fetch_builtin_zEv()
28+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.tid.z"
29+
// LLVM: call {{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.z()
30+
31+
32+
out[i++] = blockIdx.x;
33+
// CIR: cir.func linkonce_odr @_ZN25__cuda_builtin_blockIdx_t17__fetch_builtin_xEv()
34+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ctaid.x"
35+
// LLVM: call {{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
36+
37+
out[i++] = blockIdx.y;
38+
// CIR: cir.func linkonce_odr @_ZN25__cuda_builtin_blockIdx_t17__fetch_builtin_yEv()
39+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ctaid.y"
40+
// LLVM: call {{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
41+
42+
out[i++] = blockIdx.z;
43+
// CIR: cir.func linkonce_odr @_ZN25__cuda_builtin_blockIdx_t17__fetch_builtin_zEv()
44+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ctaid.z"
45+
// LLVM: call {{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
46+
47+
48+
out[i++] = blockDim.x;
49+
// CIR: cir.func linkonce_odr @_ZN25__cuda_builtin_blockDim_t17__fetch_builtin_xEv()
50+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ntid.x"
51+
// LLVM: call {{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
52+
53+
out[i++] = blockDim.y;
54+
// CIR: cir.func linkonce_odr @_ZN25__cuda_builtin_blockDim_t17__fetch_builtin_yEv()
55+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ntid.y"
56+
// LLVM: call {{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
57+
58+
out[i++] = blockDim.z;
59+
// CIR: cir.func linkonce_odr @_ZN25__cuda_builtin_blockDim_t17__fetch_builtin_zEv()
60+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ntid.z"
61+
// LLVM: call {{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
62+
63+
64+
out[i++] = gridDim.x;
65+
// CIR: cir.func linkonce_odr @_ZN24__cuda_builtin_gridDim_t17__fetch_builtin_xEv()
66+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.nctaid.x"
67+
// LLVM: call {{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
68+
69+
out[i++] = gridDim.y;
70+
// CIR: cir.func linkonce_odr @_ZN24__cuda_builtin_gridDim_t17__fetch_builtin_yEv()
71+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.nctaid.y"
72+
// LLVM: call {{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
73+
74+
out[i++] = gridDim.z;
75+
// CIR: cir.func linkonce_odr @_ZN24__cuda_builtin_gridDim_t17__fetch_builtin_zEv()
76+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.nctaid.z"
77+
// LLVM: call {{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
78+
79+
80+
out[i++] = warpSize;
81+
// CIR: [[REGISTER:%.*]] = cir.const #cir.int<32>
82+
// CIR: cir.store [[REGISTER]]
83+
// LLVM: store i32 32,
84+
85+
86+
// CIR: cir.return loc
87+
// LLVM: ret void
88+
}

0 commit comments

Comments
 (0)