Skip to content

Commit f64b31b

Browse files
authored
[CIR][CIRGen][CUDA] Add kernel stub retrieval and update built-in variables tests (#1904)
1 parent bfe865b commit f64b31b

File tree

4 files changed

+104
-75
lines changed

4 files changed

+104
-75
lines changed

clang/lib/CIR/CodeGen/CIRGenCUDANV.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ class CIRGenNVCUDARuntime : public CIRGenCUDARuntime {
7575

7676
mlir::Operation *getKernelHandle(cir::FuncOp fn, GlobalDecl GD) override;
7777

78+
/// Returns the stub stored in KernelStubs for the kernel handle \p handle.
///
/// \param handle the kernel handle used as the lookup key in KernelStubs.
/// \returns the operation mapped to \p handle.
///
/// The handle is expected to already have an entry in KernelStubs; a missing
/// entry is a programming error and trips the assertion below.
mlir::Operation *getKernelStub(mlir::Operation *handle) override {
  auto it = KernelStubs.find(handle);
  assert(it != KernelStubs.end() && "no kernel stub registered for handle");
  return it->second;
}
83+
7884
void internalizeDeviceSideVar(const VarDecl *d,
7985
cir::GlobalLinkageKind &linkage) override;
8086
/// Returns function or variable name on device side even if the current

clang/lib/CIR/CodeGen/CIRGenCUDARuntime.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ class CIRGenCUDARuntime {
4545
const CUDAKernelCallExpr *expr,
4646
ReturnValueSlot retValue);
4747
virtual mlir::Operation *getKernelHandle(cir::FuncOp fn, GlobalDecl GD) = 0;
48+
/// Get kernel stub by kernel handle.
49+
virtual mlir::Operation *getKernelStub(mlir::Operation *handle) = 0;
50+
4851
virtual void internalizeDeviceSideVar(const VarDecl *d,
4952
cir::GlobalLinkageKind &linkage) = 0;
5053
/// Returns function or variable name on device side even if the current

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -533,9 +533,11 @@ static CIRGenCallee emitDirectCallee(CIRGenModule &CGM, GlobalDecl GD) {
533533
return CIRGenCallee::forBuiltin(builtinID, FD);
534534
}
535535

536-
auto CalleePtr = emitFunctionDeclPointer(CGM, GD);
536+
mlir::Operation *CalleePtr = emitFunctionDeclPointer(CGM, GD);
537537

538-
assert(!CGM.getLangOpts().CUDA && "NYI");
538+
if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
539+
FD->hasAttr<CUDAGlobalAttr>())
540+
CalleePtr = CGM.getCUDARuntime().getKernelStub(CalleePtr);
539541

540542
return CIRGenCallee::forDirect(CalleePtr, GD);
541543
}

clang/test/CIR/CodeGen/CUDA/cuda-builtin-vars.cu

Lines changed: 91 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -6,84 +6,102 @@
66
// RUN: -fcuda-is-device -emit-cir -o - %s \
77
// RUN: | FileCheck --check-prefix=CIR %s
88

9+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda \
10+
// RUN: -fcuda-is-device -emit-llvm -o - %s \
11+
// RUN: | FileCheck --check-prefix=OGCG %s
12+
913
#include "__clang_cuda_builtin_vars.h"
1014

1115
// LLVM: define{{.*}} void @_Z6kernelPi(ptr %0)
12-
// CIR-LABEL: @_Z6kernelPi
16+
// OGCG: define{{.*}} void @_Z6kernelPi(ptr noundef %out)
1317
__attribute__((global))
1418
void kernel(int *out) {
1519
int i = 0;
1620

17-
// out[i++] = threadIdx.x;
18-
// CIR-DISABLED: cir.func linkonce_odr @_ZN26__cuda_builtin_threadIdx_t17__fetch_builtin_xEv()
19-
// CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.tid.x"
20-
// LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.x()
21-
22-
// out[i++] = threadIdx.y;
23-
// CIR-DISABLED: cir.func linkonce_odr @_ZN26__cuda_builtin_threadIdx_t17__fetch_builtin_yEv()
24-
// CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.tid.y"
25-
// LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.y()
26-
27-
// out[i++] = threadIdx.z;
28-
// CIR-DISABLED: cir.func linkonce_odr @_ZN26__cuda_builtin_threadIdx_t17__fetch_builtin_zEv()
29-
// CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.tid.z"
30-
// LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.z()
31-
32-
33-
// out[i++] = blockIdx.x;
34-
// CIR-DISABLED: cir.func linkonce_odr @_ZN25__cuda_builtin_blockIdx_t17__fetch_builtin_xEv()
35-
// CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ctaid.x"
36-
// LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
37-
38-
// out[i++] = blockIdx.y;
39-
// CIR-DISABLED: cir.func linkonce_odr @_ZN25__cuda_builtin_blockIdx_t17__fetch_builtin_yEv()
40-
// CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ctaid.y"
41-
// LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
42-
43-
// out[i++] = blockIdx.z;
44-
// CIR-DISABLED: cir.func linkonce_odr @_ZN25__cuda_builtin_blockIdx_t17__fetch_builtin_zEv()
45-
// CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ctaid.z"
46-
// LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
47-
48-
49-
// out[i++] = blockDim.x;
50-
// CIR-DISABLED: cir.func linkonce_odr @_ZN25__cuda_builtin_blockDim_t17__fetch_builtin_xEv()
51-
// CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ntid.x"
52-
// LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
53-
54-
// out[i++] = blockDim.y;
55-
// CIR-DISABLED: cir.func linkonce_odr @_ZN25__cuda_builtin_blockDim_t17__fetch_builtin_yEv()
56-
// CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ntid.y"
57-
// LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
58-
59-
// out[i++] = blockDim.z;
60-
// CIR-DISABLED: cir.func linkonce_odr @_ZN25__cuda_builtin_blockDim_t17__fetch_builtin_zEv()
61-
// CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ntid.z"
62-
// LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
63-
64-
65-
// out[i++] = gridDim.x;
66-
// CIR-DISABLED: cir.func linkonce_odr @_ZN24__cuda_builtin_gridDim_t17__fetch_builtin_xEv()
67-
// CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.nctaid.x"
68-
// LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
69-
70-
// out[i++] = gridDim.y;
71-
// CIR-DISABLED: cir.func linkonce_odr @_ZN24__cuda_builtin_gridDim_t17__fetch_builtin_yEv()
72-
// CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.nctaid.y"
73-
// LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
74-
75-
// out[i++] = gridDim.z;
76-
// CIR-DISABLED: cir.func linkonce_odr @_ZN24__cuda_builtin_gridDim_t17__fetch_builtin_zEv()
77-
// CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.nctaid.z"
78-
// LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
79-
80-
81-
// out[i++] = warpSize;
82-
// CIR-DISABLED: [[REGISTER:%.*]] = cir.const #cir.int<32>
83-
// CIR-DISABLED: cir.store{{.*}} [[REGISTER]]
84-
// LLVM-DISABLED: store i32 32,
85-
86-
87-
// CIR-DISABLED: cir.return loc
88-
// LLVM-DISABLED: ret void
21+
out[i++] = threadIdx.x;
22+
// CIR: cir.func linkonce_odr @_ZN26__cuda_builtin_threadIdx_t17__fetch_builtin_xEv()
23+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.tid.x"
24+
// LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.x()
25+
// OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.x()
26+
27+
out[i++] = threadIdx.y;
28+
// CIR: cir.func linkonce_odr @_ZN26__cuda_builtin_threadIdx_t17__fetch_builtin_yEv()
29+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.tid.y"
30+
// LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.y()
31+
// OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.y()
32+
33+
out[i++] = threadIdx.z;
34+
// CIR: cir.func linkonce_odr @_ZN26__cuda_builtin_threadIdx_t17__fetch_builtin_zEv()
35+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.tid.z"
36+
// LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.z()
37+
// OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.z()
38+
39+
40+
out[i++] = blockIdx.x;
41+
// CIR: cir.func linkonce_odr @_ZN25__cuda_builtin_blockIdx_t17__fetch_builtin_xEv()
42+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ctaid.x"
43+
// LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
44+
// OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
45+
46+
out[i++] = blockIdx.y;
47+
// CIR: cir.func linkonce_odr @_ZN25__cuda_builtin_blockIdx_t17__fetch_builtin_yEv()
48+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ctaid.y"
49+
// LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
50+
// OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
51+
52+
out[i++] = blockIdx.z;
53+
// CIR: cir.func linkonce_odr @_ZN25__cuda_builtin_blockIdx_t17__fetch_builtin_zEv()
54+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ctaid.z"
55+
// LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
56+
// OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
57+
58+
59+
out[i++] = blockDim.x;
60+
// CIR: cir.func linkonce_odr @_ZN25__cuda_builtin_blockDim_t17__fetch_builtin_xEv()
61+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ntid.x"
62+
// LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
63+
// OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
64+
65+
out[i++] = blockDim.y;
66+
// CIR: cir.func linkonce_odr @_ZN25__cuda_builtin_blockDim_t17__fetch_builtin_yEv()
67+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ntid.y"
68+
// LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
69+
// OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
70+
71+
out[i++] = blockDim.z;
72+
// CIR: cir.func linkonce_odr @_ZN25__cuda_builtin_blockDim_t17__fetch_builtin_zEv()
73+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ntid.z"
74+
// LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
75+
// OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
76+
77+
78+
out[i++] = gridDim.x;
79+
// CIR: cir.func linkonce_odr @_ZN24__cuda_builtin_gridDim_t17__fetch_builtin_xEv()
80+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.nctaid.x"
81+
// LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
82+
// OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
83+
84+
out[i++] = gridDim.y;
85+
// CIR: cir.func linkonce_odr @_ZN24__cuda_builtin_gridDim_t17__fetch_builtin_yEv()
86+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.nctaid.y"
87+
// LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
88+
// OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
89+
90+
out[i++] = gridDim.z;
91+
// CIR: cir.func linkonce_odr @_ZN24__cuda_builtin_gridDim_t17__fetch_builtin_zEv()
92+
// CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.nctaid.z"
93+
// LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
94+
// OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
95+
96+
97+
out[i++] = warpSize;
98+
// CIR: [[REGISTER:%.*]] = cir.const #cir.int<32>
99+
// CIR: cir.store{{.*}} [[REGISTER]]
100+
// LLVM: store i32 32,
101+
// OGCG: store i32 32,
102+
103+
104+
// CIR: cir.return loc
105+
// LLVM: ret void
106+
// OGCG: ret void
89107
}

0 commit comments

Comments
 (0)