Commit 90a5b61

[CIR][CUDA] Initial support for host compilation (#1309)
Adds support for `__host__` and `__device__` functions when compiling for CUDA host. The conditions checked against are taken from OG.
1 parent 5373f42 commit 90a5b61
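
For illustration, here is a minimal CUDA translation unit showing the host-compilation behavior this change enables. This is a sketch, not part of the commit; the function names are hypothetical, while the committed test clang/test/CIR/CodeGen/CUDA/simple.cu exercises the same rule.

    #include "../Inputs/cuda.h" // the test stub added in this commit

    // Emitted when compiling for host: an ordinary host function, handled like normal C++.
    __host__ void on_host(int *p) {}

    // Skipped when compiling for host: it carries CUDADeviceAttr but not CUDAHostAttr,
    // so CIRGenModule::emitGlobal now returns early instead of emitting it.
    __device__ void on_device(int *p) {}

    // Still emitted when compiling for host: CUDAHostAttr is present, so the
    // device-only skip condition does not apply.
    __host__ __device__ int on_both(int x) { return x + 1; }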

4 files changed, +167 -5 lines changed


clang/lib/CIR/CodeGen/CIRGenModule.cpp

Lines changed: 21 additions & 5 deletions
@@ -514,7 +514,19 @@ void CIRGenModule::emitGlobal(GlobalDecl GD) {
 
   assert(!Global->hasAttr<IFuncAttr>() && "NYI");
   assert(!Global->hasAttr<CPUDispatchAttr>() && "NYI");
-  assert(!langOpts.CUDA && "NYI");
+
+  if (langOpts.CUDA) {
+    if (langOpts.CUDAIsDevice)
+      llvm_unreachable("NYI");
+
+    if (dyn_cast<VarDecl>(Global))
+      llvm_unreachable("NYI");
+
+    // We must skip __device__ functions when compiling for host.
+    if (!Global->hasAttr<CUDAHostAttr>() && Global->hasAttr<CUDADeviceAttr>()) {
+      return;
+    }
+  }
 
   if (langOpts.OpenMP) {
     // If this is OpenMP, check if it is legal to emit this global normally.
@@ -557,6 +569,7 @@ void CIRGenModule::emitGlobal(GlobalDecl GD) {
       return;
     }
   } else {
+    assert(!langOpts.CUDA && "NYI");
     const auto *VD = cast<VarDecl>(Global);
     assert(VD->isFileVarDecl() && "Cannot emit local var decl as global.");
     if (VD->isThisDeclarationADefinition() != VarDecl::Definition &&
@@ -2322,7 +2335,13 @@ cir::FuncOp CIRGenModule::GetAddrOfFunction(clang::GlobalDecl GD, mlir::Type Ty,
   auto F = GetOrCreateCIRFunction(MangledName, Ty, GD, ForVTable, DontDefer,
                                   /*IsThunk=*/false, IsForDefinition);
 
-  assert(!langOpts.CUDA && "NYI");
+  // As __global__ functions always reside on device,
+  // we need special care when accessing them from host;
+  // otherwise, CUDA functions behave as normal functions
+  if (langOpts.CUDA && !langOpts.CUDAIsDevice &&
+      cast<FunctionDecl>(GD.getDecl())->hasAttr<CUDAGlobalAttr>()) {
+    llvm_unreachable("NYI");
+  }
 
   return F;
 }
@@ -3164,9 +3183,6 @@ void CIRGenModule::Release() {
   assert(!MissingFeatures::registerGlobalDtorsWithAtExit());
   assert(!MissingFeatures::emitCXXThreadLocalInitFunc());
   assert(!MissingFeatures::objCRuntime());
-  if (astContext.getLangOpts().CUDA) {
-    llvm_unreachable("NYI");
-  }
   assert(!MissingFeatures::openMPRuntime());
   assert(!MissingFeatures::pgoReader());
   assert(!MissingFeatures::emitCtorList()); // GlobalCtors, GlobalDtors

clang/lib/CIR/CodeGen/TargetInfo.cpp

Lines changed: 56 additions & 0 deletions
@@ -305,6 +305,30 @@ class SPIRVTargetCIRGenInfo : public CommonSPIRTargetCIRGenInfo {
 
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// NVPTX ABI Implementation
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+class NVPTXABIInfo : public ABIInfo {
+public:
+  NVPTXABIInfo(CIRGenTypes &cgt) : ABIInfo(cgt) {}
+
+  cir::ABIArgInfo classifyReturnType(QualType retTy) const;
+  cir::ABIArgInfo classifyArgumentType(QualType ty) const;
+
+  void computeInfo(CIRGenFunctionInfo &fnInfo) const override;
+};
+
+class NVPTXTargetCIRGenInfo : public TargetCIRGenInfo {
+public:
+  NVPTXTargetCIRGenInfo(CIRGenTypes &cgt)
+      : TargetCIRGenInfo(std::make_unique<NVPTXABIInfo>(cgt)) {}
+};
+
+} // namespace
+
 // TODO(cir): remove the attribute once this gets used.
 LLVM_ATTRIBUTE_UNUSED
 static bool classifyReturnType(const CIRGenCXXABI &CXXABI,
@@ -443,6 +467,34 @@ cir::ABIArgInfo X86_64ABIInfo::classifyArgumentType(QualType Ty,
   return cir::ABIArgInfo::getDirect(ResType);
 }
 
+// Skeleton only. Implement when used in TargetLower stage.
+cir::ABIArgInfo NVPTXABIInfo::classifyReturnType(QualType retTy) const {
+  llvm_unreachable("not yet implemented");
+}
+
+cir::ABIArgInfo NVPTXABIInfo::classifyArgumentType(QualType ty) const {
+  llvm_unreachable("not yet implemented");
+}
+
+void NVPTXABIInfo::computeInfo(CIRGenFunctionInfo &fnInfo) const {
+  // Top level CIR has unlimited arguments and return types. Lowering for ABI
+  // specific concerns should happen during a lowering phase. Assume everything
+  // is direct for now.
+  for (CIRGenFunctionInfo::arg_iterator it = fnInfo.arg_begin(),
+                                        ie = fnInfo.arg_end();
+       it != ie; ++it) {
+    if (testIfIsVoidTy(it->type))
+      it->info = cir::ABIArgInfo::getIgnore();
+    else
+      it->info = cir::ABIArgInfo::getDirect(CGT.convertType(it->type));
+  }
+  auto retTy = fnInfo.getReturnType();
+  if (testIfIsVoidTy(retTy))
+    fnInfo.getReturnInfo() = cir::ABIArgInfo::getIgnore();
+  else
+    fnInfo.getReturnInfo() = cir::ABIArgInfo::getDirect(CGT.convertType(retTy));
+}
+
 ABIInfo::~ABIInfo() {}
 
 bool ABIInfo::isPromotableIntegerTypeForABI(QualType Ty) const {
@@ -634,5 +686,9 @@ const TargetCIRGenInfo &CIRGenModule::getTargetCIRGenInfo() {
   case llvm::Triple::spirv64: {
    return SetCIRGenInfo(new SPIRVTargetCIRGenInfo(genTypes));
  }
+
+  case llvm::Triple::nvptx64: {
+    return SetCIRGenInfo(new NVPTXTargetCIRGenInfo(genTypes));
+  }
  }
}

clang/test/CIR/CodeGen/CUDA/simple.cu

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+#include "../Inputs/cuda.h"
+
+// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
+// RUN:            -emit-cir %s -o %t.cir
+// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
+
+
+// This should emit as a normal C++ function.
+__host__ void host_fn(int *a, int *b, int *c) {}
+
+// CIR: cir.func @_Z7host_fnPiS_S_
+
+// This shouldn't emit.
+__device__ void device_fn(int* a, double b, float c) {}
+
+// CHECK-NOT: cir.func @_Z9device_fnPidf
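
A hypothetical extension of this test (not part of the commit) would cover the remaining branch of the skip condition: a function marked both __host__ and __device__ keeps CUDAHostAttr, so emitGlobal does not skip it during host compilation and a cir.func should still be produced.

    // Hypothetical addition, following the same FileCheck conventions as above.
    __host__ __device__ void both_fn(int *a) {}

    // CIR: cir.func @_Z7both_fnPi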

clang/test/CIR/CodeGen/Inputs/cuda.h

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+/* Minimal declarations for CUDA support. Testing purposes only. */
+/* From test/CodeGenCUDA/Inputs/cuda.h. */
+#include <stddef.h>
+
+#if __HIP__ || __CUDA__
+#define __constant__ __attribute__((constant))
+#define __device__ __attribute__((device))
+#define __global__ __attribute__((global))
+#define __host__ __attribute__((host))
+#define __shared__ __attribute__((shared))
+#if __HIP__
+#define __managed__ __attribute__((managed))
+#endif
+#define __launch_bounds__(...) __attribute__((launch_bounds(__VA_ARGS__)))
+#define __grid_constant__ __attribute__((grid_constant))
+#else
+#define __constant__
+#define __device__
+#define __global__
+#define __host__
+#define __shared__
+#define __managed__
+#define __launch_bounds__(...)
+#define __grid_constant__
+#endif
+
+struct dim3 {
+  unsigned x, y, z;
+  __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1) : x(x), y(y), z(z) {}
+};
+
+#if __HIP__ || HIP_PLATFORM
+typedef struct hipStream *hipStream_t;
+typedef enum hipError {} hipError_t;
+int hipConfigureCall(dim3 gridSize, dim3 blockSize, size_t sharedSize = 0,
+                     hipStream_t stream = 0);
+extern "C" hipError_t __hipPushCallConfiguration(dim3 gridSize, dim3 blockSize,
+                                                 size_t sharedSize = 0,
+                                                 hipStream_t stream = 0);
+#ifndef __HIP_API_PER_THREAD_DEFAULT_STREAM__
+extern "C" hipError_t hipLaunchKernel(const void *func, dim3 gridDim,
+                                      dim3 blockDim, void **args,
+                                      size_t sharedMem,
+                                      hipStream_t stream);
+#else
+extern "C" hipError_t hipLaunchKernel_spt(const void *func, dim3 gridDim,
+                                          dim3 blockDim, void **args,
+                                          size_t sharedMem,
+                                          hipStream_t stream);
+#endif // __HIP_API_PER_THREAD_DEFAULT_STREAM__
+#elif __OFFLOAD_VIA_LLVM__
+extern "C" unsigned __llvmPushCallConfiguration(dim3 gridDim, dim3 blockDim,
+                                                size_t sharedMem = 0, void *stream = 0);
+extern "C" unsigned llvmLaunchKernel(const void *func, dim3 gridDim, dim3 blockDim,
+                                     void **args, size_t sharedMem = 0, void *stream = 0);
+#else
+typedef struct cudaStream *cudaStream_t;
+typedef enum cudaError {} cudaError_t;
+extern "C" int cudaConfigureCall(dim3 gridSize, dim3 blockSize,
+                                 size_t sharedSize = 0,
+                                 cudaStream_t stream = 0);
+extern "C" int __cudaPushCallConfiguration(dim3 gridSize, dim3 blockSize,
+                                           size_t sharedSize = 0,
+                                           cudaStream_t stream = 0);
+extern "C" cudaError_t cudaLaunchKernel(const void *func, dim3 gridDim,
+                                        dim3 blockDim, void **args,
+                                        size_t sharedMem, cudaStream_t stream);
+extern "C" cudaError_t cudaLaunchKernel_ptsz(const void *func, dim3 gridDim,
+                                             dim3 blockDim, void **args,
+                                             size_t sharedMem, cudaStream_t stream);
+
+#endif
+
+extern "C" __device__ int printf(const char*, ...);
