Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion flang/include/flang/Runtime/CUDA/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,19 @@

extern "C" {

// This function uses intptr_t instead of CUDA's unsigned int to match
// These functions use intptr_t instead of CUDA's unsigned int to match
// the type of MLIR's index type. This avoids the need for casts in the
// generated MLIR code.

// Launch the kernel identified by `kernelName`.
//   gridX/gridY/gridZ    - grid dimensions of the launch.
//   blockX/blockY/blockZ - block dimensions of the launch.
//   smem                 - dynamic shared memory size for the launch
//                          (presumably in bytes, as expected by the CUDA
//                          launch API; confirm against the implementation).
//   params               - kernel parameter array handed to the CUDA launch
//                          call.
//   extra                - NOTE(review): appears to be unused by the current
//                          implementation; confirm before relying on it.
void RTDEF(CUFLaunchKernel)(const void *kernelName, intptr_t gridX,
intptr_t gridY, intptr_t gridZ, intptr_t blockX, intptr_t blockY,
intptr_t blockZ, int32_t smem, void **params, void **extra);

// Launch the kernel identified by `kernelName` with a cluster dimension
// attached to the launch.
//   clusterX/clusterY/clusterZ - cluster dimensions of the launch.
//   gridX/gridY/gridZ          - grid dimensions of the launch.
//   blockX/blockY/blockZ       - block dimensions of the launch.
//   smem                       - dynamic shared memory size for the launch
//                                (presumably in bytes; confirm against the
//                                implementation).
//   params                     - kernel parameter array handed to the CUDA
//                                launch call.
//   extra                      - NOTE(review): appears to be unused by the
//                                current implementation; confirm before
//                                relying on it.
void RTDEF(CUFLaunchClusterKernel)(const void *kernelName, intptr_t clusterX,
intptr_t clusterY, intptr_t clusterZ, intptr_t gridX, intptr_t gridY,
intptr_t gridZ, intptr_t blockX, intptr_t blockY, intptr_t blockZ,
int32_t smem, void **params, void **extra);

} // extern "C"

#endif // FORTRAN_RUNTIME_CUDA_KERNEL_H_
25 changes: 24 additions & 1 deletion flang/runtime/CUDA/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,32 @@ void RTDEF(CUFLaunchKernel)(const void *kernel, intptr_t gridX, intptr_t gridY,
blockDim.x = blockX;
blockDim.y = blockY;
blockDim.z = blockZ;
cudaStream_t stream = 0;
cudaStream_t stream = 0; // TODO stream managment
CUDA_REPORT_IF_ERROR(
cudaLaunchKernel(kernel, gridDim, blockDim, params, smem, stream));
}

// Launch `kernel` through cudaLaunchKernelExC with a cluster dimension
// launch attribute.
//   clusterX/clusterY/clusterZ - stored in the
//                                cudaLaunchAttributeClusterDimension
//                                attribute of the launch.
//   gridX/gridY/gridZ          - grid dimensions of the launch.
//   blockX/blockY/blockZ       - block dimensions of the launch.
//   smem                       - dynamic shared memory size in bytes.
//   params                     - kernel parameter array forwarded to
//                                cudaLaunchKernelExC.
//   extra                      - currently unused.
// The launch is issued on stream 0. Any error returned by
// cudaLaunchKernelExC is handed to CUDA_REPORT_IF_ERROR.
void RTDEF(CUFLaunchClusterKernel)(const void *kernel, intptr_t clusterX,
    intptr_t clusterY, intptr_t clusterZ, intptr_t gridX, intptr_t gridY,
    intptr_t gridZ, intptr_t blockX, intptr_t blockY, intptr_t blockZ,
    int32_t smem, void **params, void **extra) {
  // Value-initialize the configuration so that any member not explicitly
  // assigned below is zero instead of indeterminate stack contents.
  cudaLaunchConfig_t config{};
  config.gridDim.x = gridX;
  config.gridDim.y = gridY;
  config.gridDim.z = gridZ;
  config.blockDim.x = blockX;
  config.blockDim.y = blockY;
  config.blockDim.z = blockZ;
  config.dynamicSmemBytes = smem;
  config.stream = 0; // TODO stream management

  // Same reasoning as above: zero the attribute (including the bytes of the
  // value union that clusterDim does not cover) before filling it in.
  cudaLaunchAttribute launchAttr[1]{};
  launchAttr[0].id = cudaLaunchAttributeClusterDimension;
  launchAttr[0].val.clusterDim.x = clusterX;
  launchAttr[0].val.clusterDim.y = clusterY;
  launchAttr[0].val.clusterDim.z = clusterZ;
  config.numAttrs = 1;
  config.attrs = launchAttr;

  CUDA_REPORT_IF_ERROR(cudaLaunchKernelExC(&config, kernel, params));
}

} // extern "C"
Loading