Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions clang/lib/CIR/CodeGen/CIRGenCUDANV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,18 @@ void CIRGenNVCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,

// The default stream is usually stream 0 (the legacy default stream).
// For per-thread default stream, we need a different LaunchKernel function.
std::string kernelLaunchAPI = "LaunchKernel";
if (cgm.getLangOpts().GPUDefaultStream ==
LangOptions::GPUDefaultStreamKind::PerThread)
llvm_unreachable("NYI");
LangOptions::GPUDefaultStreamKind::PerThread) {
if (cgf.getLangOpts().HIP)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add tests that also cover HIP!

kernelLaunchAPI = kernelLaunchAPI + "_spt";
else if (cgf.getLangOpts().CUDA)
kernelLaunchAPI = kernelLaunchAPI + "_ptsz";
}

std::string launchAPI = addPrefixToName("LaunchKernel");
const IdentifierInfo &launchII = cgm.getASTContext().Idents.get(launchAPI);
std::string launchKernelName = addPrefixToName(kernelLaunchAPI);
const IdentifierInfo &launchII =
cgm.getASTContext().Idents.get(launchKernelName);
FunctionDecl *launchFD = nullptr;
for (auto *result : dc->lookup(&launchII)) {
if (FunctionDecl *fd = dyn_cast<FunctionDecl>(result))
Expand All @@ -175,7 +181,7 @@ void CIRGenNVCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,

if (launchFD == nullptr) {
cgm.Error(cgf.CurFuncDecl->getLocation(),
"Can't find declaration for " + launchAPI);
"Can't find declaration for " + launchKernelName);
return;
}

Expand Down Expand Up @@ -256,8 +262,8 @@ void CIRGenNVCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
launchFD->getParamDecl(5)->getType());

mlir::Type launchTy = cgm.getTypes().convertType(launchFD->getType());
mlir::Operation *launchFn =
cgm.createRuntimeFunction(cast<cir::FuncType>(launchTy), launchAPI);
mlir::Operation *launchFn = cgm.createRuntimeFunction(
cast<cir::FuncType>(launchTy), launchKernelName);
const auto &callInfo = cgm.getTypes().arrangeFunctionDeclaration(launchFD);
cgf.emitCall(callInfo, CIRGenCallee::forDirect(launchFn), ReturnValueSlot(),
launchArgs);
Expand Down
23 changes: 23 additions & 0 deletions clang/test/CIR/CodeGen/CUDA/simple.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,26 @@
// RUN: %s -o %t.ll
// RUN: FileCheck --check-prefix=OGCG-DEVICE --input-file=%t.ll %s

// Per-thread default stream test cases:

// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \
// RUN: -x cuda -emit-cir -target-sdk-version=12.3 \
// RUN: -fgpu-default-stream=per-thread -DCUDA_API_PER_THREAD_DEFAULT_STREAM \
// RUN: %s -o %t.cir
// RUN: FileCheck --check-prefixes=CIR-HOST-PTH --input-file=%t.cir %s

// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \
// RUN: -x cuda -emit-llvm -target-sdk-version=12.3 \
// RUN: -fgpu-default-stream=per-thread -DCUDA_API_PER_THREAD_DEFAULT_STREAM \
// RUN: %s -o %t.ll
// RUN: FileCheck --check-prefixes=LLVM-HOST-PTH --input-file=%t.ll %s

// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu \
// RUN: -x cuda -emit-llvm -target-sdk-version=12.3 \
// RUN: -fgpu-default-stream=per-thread -DCUDA_API_PER_THREAD_DEFAULT_STREAM \
// RUN: %s -o %t.ll
// RUN: FileCheck --check-prefixes=OGCG-HOST-PTH --input-file=%t.ll %s

// Attribute for global_fn
// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cu.kernel_name<_Z9global_fni>{{.*}}

Expand All @@ -54,19 +74,22 @@ __global__ void global_fn(int a) {}
// CIR-HOST: cir.call @__cudaPopCallConfiguration
// CIR-HOST: cir.get_global @_Z24__device_stub__global_fni
// CIR-HOST: cir.call @cudaLaunchKernel
// CIR-HOST-PTH: cir.call @cudaLaunchKernel_ptsz

// LLVM-HOST: void @_Z24__device_stub__global_fni
// LLVM-HOST: %[[#KernelArgs:]] = alloca [1 x ptr], i64 1, align 16
// LLVM-HOST: %[[#GEP1:]] = getelementptr ptr, ptr %[[#KernelArgs]], i32 0
// LLVM-HOST: %[[#GEP2:]] = getelementptr [1 x ptr], ptr %[[#KernelArgs]], i32 0, i64 0
// LLVM-HOST: call i32 @__cudaPopCallConfiguration
// LLVM-HOST: call i32 @cudaLaunchKernel(ptr @_Z24__device_stub__global_fni
// LLVM-HOST-PTH: call i32 @cudaLaunchKernel_ptsz(ptr @_Z24__device_stub__global_fni

// OGCG-HOST: void @_Z24__device_stub__global_fni
// OGCG-HOST: %kernel_args = alloca ptr, i64 1, align 16
// OGCG-HOST: getelementptr ptr, ptr %kernel_args, i32 0
// OGCG-HOST: call i32 @__cudaPopCallConfiguration
// OGCG-HOST: call noundef i32 @cudaLaunchKernel(ptr noundef @_Z24__device_stub__global_fni
// OGCG-HOST-PTH: call noundef i32 @cudaLaunchKernel_ptsz(ptr noundef @_Z24__device_stub__global_fni


int main() {
Expand Down