Skip to content

Commit 98b04ee

Browse files
koparasylanza
authored andcommitted
[CIR][HIP] Support call of HIP Kernels (#1952)
This patch extends `emitDirectCallee` to resolve HIP host launches to the correct kernel stub (`__device_stub__...`), matching CUDA semantics
1 parent 5c40d7f commit 98b04ee

File tree

3 files changed

+29
-17
lines changed

3 files changed

+29
-17
lines changed

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -555,9 +555,16 @@ static CIRGenCallee emitDirectCallee(CIRGenModule &CGM, GlobalDecl GD) {
555555

556556
mlir::Operation *CalleePtr = emitFunctionDeclPointer(CGM, GD);
557557

558-
if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
559-
FD->hasAttr<CUDAGlobalAttr>())
560-
CalleePtr = CGM.getCUDARuntime().getKernelStub(CalleePtr);
558+
if ((CGM.getLangOpts().HIP || CGM.getLangOpts().CUDA) &&
559+
!CGM.getLangOpts().CUDAIsDevice && FD->hasAttr<CUDAGlobalAttr>()) {
560+
561+
// Ensure the handle is created and use it as the lookup key.
562+
auto *Handle = CGM.getCUDARuntime().getKernelHandle(
563+
llvm::cast<cir::FuncOp>(CalleePtr), GD);
564+
565+
// Now look up the stub via the handle
566+
CalleePtr = CGM.getCUDARuntime().getKernelStub(Handle);
567+
}
561568

562569
return CIRGenCallee::forDirect(CalleePtr, GD);
563570
}
@@ -1579,8 +1586,6 @@ RValue CIRGenFunction::emitCall(clang::QualType CalleeType,
15791586
Callee.setFunctionPointer(Fn);
15801587
}
15811588

1582-
assert(!CGM.getLangOpts().HIP && "HIP NYI");
1583-
15841589
assert(!MustTailCall && "Must tail NYI");
15851590
cir::CIRCallOpInterface callOP;
15861591
RValue Call = emitCall(FnInfo, Callee, ReturnValue, Args, &callOP,

clang/lib/CIR/CodeGen/CIRGenModule.cpp

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2555,20 +2555,9 @@ cir::FuncOp CIRGenModule::GetAddrOfFunction(clang::GlobalDecl gd, mlir::Type ty,
25552555
auto f = GetOrCreateCIRFunction(mangledName, ty, gd, forVTable, dontDefer,
25562556
/*IsThunk=*/false, isForDefinition);
25572557

2558-
// As __global__ functions (kernels) always reside on device,
2559-
// when we access them from host, we must refer to the kernel handle.
2560-
// For HIP, we should never directly access the host device addr, but
2561-
// instead the Global Variable of that stub. For CUDA, it's just the device
2562-
// stub. For HIP, it's something different.
25632558
if ((langOpts.HIP || langOpts.CUDA) && !langOpts.CUDAIsDevice &&
2564-
cast<FunctionDecl>(gd.getDecl())->hasAttr<CUDAGlobalAttr>()) {
2559+
cast<FunctionDecl>(gd.getDecl())->hasAttr<CUDAGlobalAttr>())
25652560
(void)getCUDARuntime().getKernelHandle(f, gd);
2566-
if (isForDefinition)
2567-
return f;
2568-
2569-
if (langOpts.HIP)
2570-
llvm_unreachable("NYI");
2571-
}
25722561

25732562
return f;
25742563
}

clang/test/CIR/CodeGen/HIP/simple.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,21 @@ __global__ void global_fn(int a) {}
3232
// The stub has the mangled name of the function
3333
// CIR-HOST: cir.get_global @_Z9global_fni
3434
// CIR-HOST: cir.call @hipLaunchKernel
35+
36+
int main() {
37+
global_fn<<<1, 1>>>(1);
38+
}
39+
// CIR-DEVICE-NOT: cir.func dso_local @main()
40+
41+
// CIR-HOST: cir.func dso_local @main()
42+
// CIR-HOST: cir.call @_ZN4dim3C1Ejjj
43+
// CIR-HOST: cir.call @_ZN4dim3C1Ejjj
44+
// CIR-HOST: [[Push:%[0-9]+]] = cir.call @__hipPushCallConfiguration
45+
// CIR-HOST: [[ConfigOK:%[0-9]+]] = cir.cast int_to_bool [[Push]]
46+
// CIR-HOST: cir.if [[ConfigOK]] {
47+
// CIR-HOST: } else {
48+
// CIR-HOST: [[Arg:%[0-9]+]] = cir.const #cir.int<1>
49+
// CIR-HOST: cir.call @_Z24__device_stub__global_fni([[Arg]])
50+
// CIR-HOST: }
51+
52+

0 commit comments

Comments
 (0)