Skip to content

Commit 9949967

Browse files
committed
[CIR][CUDA] Add Support stream per thread
1 parent 96a0551 commit 9949967

File tree

2 files changed

+35
-6
lines changed

2 files changed

+35
-6
lines changed

clang/lib/CIR/CodeGen/CIRGenCUDANV.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -161,12 +161,18 @@ void CIRGenNVCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
161161

162162
// The default stream is usually stream 0 (the legacy default stream).
163163
// For per-thread default stream, we need a different LaunchKernel function.
164+
std::string kernelLaunchAPI = "LaunchKernel";
164165
if (cgm.getLangOpts().GPUDefaultStream ==
165-
LangOptions::GPUDefaultStreamKind::PerThread)
166-
llvm_unreachable("NYI");
166+
LangOptions::GPUDefaultStreamKind::PerThread) {
167+
if (cgf.getLangOpts().HIP)
168+
kernelLaunchAPI = kernelLaunchAPI + "_spt";
169+
else if (cgf.getLangOpts().CUDA)
170+
kernelLaunchAPI = kernelLaunchAPI + "_ptsz";
171+
}
167172

168-
std::string launchAPI = addPrefixToName("LaunchKernel");
169-
const IdentifierInfo &launchII = cgm.getASTContext().Idents.get(launchAPI);
173+
std::string launchKernelName = addPrefixToName(kernelLaunchAPI);
174+
const IdentifierInfo &launchII =
175+
cgm.getASTContext().Idents.get(launchKernelName);
170176
FunctionDecl *launchFD = nullptr;
171177
for (auto *result : dc->lookup(&launchII)) {
172178
if (FunctionDecl *fd = dyn_cast<FunctionDecl>(result))
@@ -175,7 +181,7 @@ void CIRGenNVCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
175181

176182
if (launchFD == nullptr) {
177183
cgm.Error(cgf.CurFuncDecl->getLocation(),
178-
"Can't find declaration for " + launchAPI);
184+
"Can't find declaration for " + launchKernelName);
179185
return;
180186
}
181187

@@ -257,7 +263,7 @@ void CIRGenNVCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
257263

258264
mlir::Type launchTy = cgm.getTypes().convertType(launchFD->getType());
259265
mlir::Operation *launchFn =
260-
cgm.createRuntimeFunction(cast<cir::FuncType>(launchTy), launchAPI);
266+
cgm.createRuntimeFunction(cast<cir::FuncType>(launchTy), launchKernelName);
261267
const auto &callInfo = cgm.getTypes().arrangeFunctionDeclaration(launchFD);
262268
cgf.emitCall(callInfo, CIRGenCallee::forDirect(launchFn), ReturnValueSlot(),
263269
launchArgs);

clang/test/CIR/CodeGen/CUDA/simple.cu

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,26 @@
3030
// RUN: %s -o %t.ll
3131
// RUN: FileCheck --check-prefix=OGCG-DEVICE --input-file=%t.ll %s
3232

33+
// Per Thread Stream test cases:
34+
35+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \
36+
// RUN: -x cuda -emit-cir -target-sdk-version=12.3 \
37+
// RUN: -fgpu-default-stream=per-thread -DCUDA_API_PER_THREAD_DEFAULT_STREAM \
38+
// RUN: %s -o %t.cir
39+
// RUN: FileCheck --check-prefixes=CIR-HOST-PTH --input-file=%t.cir %s
40+
41+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \
42+
// RUN: -x cuda -emit-llvm -target-sdk-version=12.3 \
43+
// RUN: -fgpu-default-stream=per-thread -DCUDA_API_PER_THREAD_DEFAULT_STREAM \
44+
// RUN: %s -o %t.ll
45+
// RUN: FileCheck --check-prefixes=LLVM-HOST-PTH --input-file=%t.ll %s
46+
47+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu \
48+
// RUN: -x cuda -emit-llvm -target-sdk-version=12.3 \
49+
// RUN: -fgpu-default-stream=per-thread -DCUDA_API_PER_THREAD_DEFAULT_STREAM \
50+
// RUN: %s -o %t.ll
51+
// RUN: FileCheck --check-prefixes=OGCG-HOST-PTH --input-file=%t.ll %s
52+
3353
// Attribute for global_fn
3454
// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cu.kernel_name<_Z9global_fni>{{.*}}
3555

@@ -54,19 +74,22 @@ __global__ void global_fn(int a) {}
5474
// CIR-HOST: cir.call @__cudaPopCallConfiguration
5575
// CIR-HOST: cir.get_global @_Z24__device_stub__global_fni
5676
// CIR-HOST: cir.call @cudaLaunchKernel
77+
// CIR-HOST-PTH: cir.call @cudaLaunchKernel_ptsz
5778

5879
// LLVM-HOST: void @_Z24__device_stub__global_fni
5980
// LLVM-HOST: %[[#KernelArgs:]] = alloca [1 x ptr], i64 1, align 16
6081
// LLVM-HOST: %[[#GEP1:]] = getelementptr ptr, ptr %[[#KernelArgs]], i32 0
6182
// LLVM-HOST: %[[#GEP2:]] = getelementptr [1 x ptr], ptr %[[#KernelArgs]], i32 0, i64 0
6283
// LLVM-HOST: call i32 @__cudaPopCallConfiguration
6384
// LLVM-HOST: call i32 @cudaLaunchKernel(ptr @_Z24__device_stub__global_fni
85+
// LLVM-HOST-PTH: call i32 @cudaLaunchKernel_ptsz(ptr @_Z24__device_stub__global_fni
6486

6587
// OGCG-HOST: void @_Z24__device_stub__global_fni
6688
// OGCG-HOST: %kernel_args = alloca ptr, i64 1, align 16
6789
// OGCG-HOST: getelementptr ptr, ptr %kernel_args, i32 0
6890
// OGCG-HOST: call i32 @__cudaPopCallConfiguration
6991
// OGCG-HOST: call noundef i32 @cudaLaunchKernel(ptr noundef @_Z24__device_stub__global_fni
92+
// OGCG-HOST-PTH: call noundef i32 @cudaLaunchKernel_ptsz(ptr noundef @_Z24__device_stub__global_fni
7093

7194

7295
int main() {

0 commit comments

Comments
 (0)