Skip to content

Commit d8f3180

Browse files
authored
[CIR][CUDA] FIx CUDA host compilation on kernel launch (#1906)
This PR implements some missing blocks that allow us to effectively allow us to launch kernels from the host. All of the tests stated in this [commit](69f2099) are now resolved. I spent half a day figuring the following: I tried experiementing performing host compilation(`-fcuda-is-device`) with target triple: `nvptx64-nvidia-cuda` but was getting a module verification error that, to keep it simple looked like: `error: 'cir.call' op calling convention mismatch: expected ptx_kernel, but provided c`. I thought that was expected given that we're essentially using the device to compile on the host, which doesn't make a lot of sense. until I tried to replicate the same in OG and didn't really run into any problem in that regard. Are the calling conventions enforced in CIR much more strict as compared to OG? Or is that simply a bug from OG?
1 parent 611ca17 commit d8f3180

File tree

5 files changed

+106
-18
lines changed

5 files changed

+106
-18
lines changed

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1408,7 +1408,9 @@ RValue CIRGenFunction::emitCallExpr(const clang::CallExpr *E,
14081408
if (const auto *CE = dyn_cast<CXXMemberCallExpr>(E))
14091409
return emitCXXMemberCallExpr(CE, ReturnValue);
14101410

1411-
assert(!dyn_cast<CUDAKernelCallExpr>(E) && "CUDA NYI");
1411+
if (const auto *CE = dyn_cast<CUDAKernelCallExpr>(E))
1412+
return emitCUDAKernelCallExpr(CE, ReturnValue);
1413+
14121414
if (const auto *CE = dyn_cast<CXXOperatorCallExpr>(E))
14131415
if (const CXXMethodDecl *MD =
14141416
dyn_cast_or_null<CXXMethodDecl>(CE->getCalleeDecl()))

clang/lib/CIR/CodeGen/CIRGenExprCXX.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,11 @@ CIRGenFunction::emitCXXOperatorMemberCallExpr(const CXXOperatorCallExpr *E,
376376
/*IsArrow=*/false, E->getArg(0));
377377
}
378378

379+
RValue CIRGenFunction::emitCUDAKernelCallExpr(const CUDAKernelCallExpr *E,
380+
ReturnValueSlot ReturnValue) {
381+
return CGM.getCUDARuntime().emitCUDAKernelCallExpr(*this, E, ReturnValue);
382+
}
383+
379384
static void emitNullBaseClassInitialization(CIRGenFunction &CGF,
380385
Address DestPtr,
381386
const CXXRecordDecl *Base) {

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2068,6 +2068,9 @@ class CIRGenFunction : public CIRGenTypeCache {
20682068
const CXXMethodDecl *MD,
20692069
ReturnValueSlot ReturnValue);
20702070

2071+
RValue emitCUDAKernelCallExpr(const CUDAKernelCallExpr *E,
2072+
ReturnValueSlot ReturnValue);
2073+
20712074
RValue emitCXXPseudoDestructorExpr(const CXXPseudoDestructorExpr *expr);
20722075

20732076
void emitCXXTemporary(const CXXTemporary *Temporary, QualType TempType,

clang/test/CIR/CodeGen/CUDA/destructor.cu

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,26 @@
1010
// RUN: %s -o %t.cir
1111
// RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.cir %s
1212

13+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
14+
// RUN: -fcuda-is-device -emit-llvm -target-sdk-version=12.3 \
15+
// RUN: %s -o %t.ll
16+
// RUN: FileCheck --check-prefix=LLVM-DEVICE --input-file=%t.ll %s
17+
18+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \
19+
// RUN: -x cuda -emit-llvm -target-sdk-version=12.3 \
20+
// RUN: %s -o %t.ll
21+
// RUN: FileCheck --check-prefix=LLVM-HOST --input-file=%t.ll %s
22+
23+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda \
24+
// RUN: -fcuda-is-device -emit-llvm -target-sdk-version=12.3 \
25+
// RUN: %s -o %t.ll
26+
// RUN: FileCheck --check-prefix=OGCG-DEVICE --input-file=%t.ll %s
27+
28+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu \
29+
// RUN: -x cuda -emit-llvm -target-sdk-version=12.3 \
30+
// RUN: %s -o %t.ll
31+
// RUN: FileCheck --check-prefix=OGCG-HOST --input-file=%t.ll %s
32+
1333
// Make sure we do emit device-side kernel even if it's only referenced
1434
// by the destructor of a variable not present on device.
1535
template<typename T> __global__ void f(T) {}
@@ -19,11 +39,23 @@ template<typename T> struct A {
1939

2040
// CIR-HOST: module
2141
// CIR-DEVICE: module
22-
// CIR-DEVICE-DISABLED: cir.func dso_local @_Z1fIiEvT_
42+
// CIR-DEVICE: cir.func dso_local @_Z1fIiEvT_
43+
// LLVM-DEVICE: define dso_local ptx_kernel void @_Z1fIiEvT_
44+
// OGCG-DEVICE: define ptx_kernel void @_Z1fIiEvT_
45+
46+
// CIR-HOST: cir.func {{.*}} @_ZN1AIiED2Ev{{.*}} {
47+
// CIR-HOST: cir.call @__cudaPushCallConfiguration
48+
// CIR-HOST: cir.call @_Z16__device_stub__fIiEvT_
49+
// CIR-HOST: }
50+
51+
// LLVM-HOST: define linkonce_odr void @_ZN1AIiED2Ev
52+
// LLVM-HOST: call i32 @__cudaPushCallConfiguration(
53+
// LLVM-HOST: call void @_Z16__device_stub__fIiEvT_
54+
55+
// OGCG-HOST: define linkonce_odr void @_ZN1AIiED2Ev
56+
// OGCG-HOST: call i32 @__cudaPushCallConfiguration(
57+
// OGCG-HOST: call void @_Z16__device_stub__fIiEvT_
58+
2359

24-
// CIR-HOST-DISABLED: cir.func {{.*}} @_ZN1AIiED2Ev{{.*}} {
25-
// CIR-HOST-DISABLED: cir.call @__cudaPushCallConfiguration
26-
// CIR-HOST-DISABLED: cir.call @_Z16__device_stub__fIiEvT_
27-
// CIR-HOST-DISABLED: }
2860

29-
// A<int> a;
61+
A<int> a;

clang/test/CIR/CodeGen/CUDA/simple.cu

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,35 @@
11
#include "../Inputs/cuda.h"
22

3-
// TODO: host build is currently crashing.
4-
// RUN-DISABLE: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \
5-
// RUN-DISABLE: -x cuda -emit-cir -target-sdk-version=12.3 \
6-
// RUN-DISABLE: %s -o %t.cir
7-
// RUN-DISABLE: FileCheck --check-prefix=CIR-HOST --input-file=%t.cir %s
3+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \
4+
// RUN: -x cuda -emit-cir -target-sdk-version=12.3 \
5+
// RUN: %s -o %t.cir
6+
// RUN: FileCheck --check-prefix=CIR-HOST --input-file=%t.cir %s
87

98
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
109
// RUN: -fcuda-is-device -emit-cir -target-sdk-version=12.3 \
1110
// RUN: %s -o %t.cir
1211
// RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.cir %s
1312

13+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \
14+
// RUN: -x cuda -emit-llvm -target-sdk-version=12.3 \
15+
// RUN: %s -o %t.ll
16+
// RUN: FileCheck --check-prefix=LLVM-HOST --input-file=%t.ll %s
17+
18+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
19+
// RUN: -fcuda-is-device -emit-llvm -target-sdk-version=12.3 \
20+
// RUN: %s -o %t.ll
21+
// RUN: FileCheck --check-prefix=LLVM-DEVICE --input-file=%t.ll %s
22+
23+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu \
24+
// RUN: -x cuda -emit-llvm -target-sdk-version=12.3 \
25+
// RUN: %s -o %t.ll
26+
// RUN: FileCheck --check-prefix=OGCG-HOST --input-file=%t.ll %s
27+
28+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda \
29+
// RUN: -fcuda-is-device -emit-llvm -target-sdk-version=12.3 \
30+
// RUN: %s -o %t.ll
31+
// RUN: FileCheck --check-prefix=OGCG-DEVICE --input-file=%t.ll %s
32+
1433
// Attribute for global_fn
1534
// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cu.kernel_name<_Z9global_fni>{{.*}}
1635

@@ -25,6 +44,7 @@ __device__ void device_fn(int* a, double b, float c) {}
2544
__global__ void global_fn(int a) {}
2645
// CIR-DEVICE: @_Z9global_fni({{.*}} cc(ptx_kernel)
2746
// LLVM-DEVICE: define dso_local ptx_kernel void @_Z9global_fni
47+
// OGCG-DEVICE: define dso_local ptx_kernel void @_Z9global_fni
2848

2949
// Check for device stub emission.
3050

@@ -38,10 +58,17 @@ __global__ void global_fn(int a) {}
3858
// LLVM-HOST: void @_Z24__device_stub__global_fni
3959
// LLVM-HOST: %[[#KernelArgs:]] = alloca [1 x ptr], i64 1, align 16
4060
// LLVM-HOST: %[[#GEP1:]] = getelementptr ptr, ptr %[[#KernelArgs]], i32 0
41-
// LLVM-HOST: %[[#GEP2:]] = getelementptr ptr, ptr %[[#GEP1]], i64 0
61+
// LLVM-HOST: %[[#GEP2:]] = getelementptr [1 x ptr], ptr %[[#KernelArgs]], i32 0, i64 0
4262
// LLVM-HOST: call i32 @__cudaPopCallConfiguration
4363
// LLVM-HOST: call i32 @cudaLaunchKernel(ptr @_Z24__device_stub__global_fni
4464

65+
// OGCG-HOST: void @_Z24__device_stub__global_fni
66+
// OGCG-HOST: %kernel_args = alloca ptr, i64 1, align 16
67+
// OGCG-HOST: getelementptr ptr, ptr %kernel_args, i32 0
68+
// OGCG-HOST: call i32 @__cudaPopCallConfiguration
69+
// OGCG-HOST: call noundef i32 @cudaLaunchKernel(ptr noundef @_Z24__device_stub__global_fni
70+
71+
4572
int main() {
4673
global_fn<<<1, 1>>>(1);
4774
}
@@ -63,10 +90,29 @@ int main() {
6390
// LLVM-HOST: alloca %struct.dim3
6491
// LLVM-HOST: call void @_ZN4dim3C1Ejjj
6592
// LLVM-HOST: call void @_ZN4dim3C1Ejjj
66-
// LLVM-HOST: [[LLVMConfigOK:%[0-9]+]] = call i32 @__cudaPushCallConfiguration
67-
// LLVM-HOST: br [[LLVMConfigOK]], label %[[#Good:]], label [[#Bad:]]
93+
// LLVM-HOST: %[[#ConfigOK:]] = call i32 @__cudaPushCallConfiguration
94+
// LLVM-HOST: %[[#ConfigCond:]] = icmp ne i32 %[[#ConfigOK]], 0
95+
// LLVM-HOST: br i1 %[[#ConfigCond]], label %[[#Good:]], label %[[#Bad:]]
6896
// LLVM-HOST: [[#Good]]:
69-
// LLVM-HOST: br label [[#End:]]
97+
// LLVM-HOST: br label %[[#End:]]
7098
// LLVM-HOST: [[#Bad]]:
71-
// LLVM-HOST: call void @_Z24__device_stub__global_fni
72-
// LLVM-HOST: br label [[#End]]
99+
// LLVM-HOST: call void @_Z24__device_stub__global_fni(i32 1)
100+
// LLVM-HOST: br label %[[#End:]]
101+
// LLVM-HOST: [[#End]]:
102+
// LLVM-HOST: %[[#]] = load i32
103+
// LLVM-HOST: ret i32
104+
105+
// OGCG-HOST: define dso_local noundef i32 @main
106+
// OGCG-HOST: alloca %struct.dim3, align 4
107+
// OGCG-HOST: alloca %struct.dim3, align 4
108+
// OGCG-HOST: call void @_ZN4dim3C1Ejjj
109+
// OGCG-HOST: call void @_ZN4dim3C1Ejjj
110+
// OGCG-HOST: %call = call i32 @__cudaPushCallConfiguration
111+
// OGCG-HOST: %tobool = icmp ne i32 %call, 0
112+
// OGCG-HOST: br i1 %tobool, label %kcall.end, label %kcall.configok
113+
// OGCG-HOST: kcall.configok:
114+
// OGCG-HOST: call void @_Z24__device_stub__global_fni(i32 noundef 1)
115+
// OGCG-HOST: br label %kcall.end
116+
// OGCG-HOST: kcall.end:
117+
// OGCG-HOST: %{{[0-9]+}} = load i32, ptr %retval, align 4
118+
// OGCG-HOST: ret i32

0 commit comments

Comments
 (0)