@clementval
Contributor

When a kernel is called with the `<<<*, block>>>` syntax, the grid dimensions are passed to the runtime as `-1, 1, 1`. In that case, query the device to compute the grid.x value.
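The arithmetic behind the computed grid.x can be sketched in plain C++ as follows. This is only an illustration under stated assumptions: `starGridX` is a hypothetical helper that is not part of the patch, and its `smCount` and `blocksPerSM` parameters stand in for the values the runtime obtains from `cudaDeviceGetAttribute` (with `cudaDevAttrMultiProcessorCount`) and `cudaOccupancyMaxActiveBlocksPerMultiprocessor`. Unlike the patch, which leaves grid.x as it was when the block count is not positive, this sketch always clamps the result to at least 1.

```cpp
#include <algorithm>

// Hypothetical helper (not part of the patch): derive grid.x for a
// `<<<*, block>>>` launch from values the runtime queries from the device.
//   smCount     : multiprocessor count of the current device
//   blocksPerSM : max resident blocks per multiprocessor for this kernel
//   gridY/gridZ : grid dimensions already supplied by the caller
int starGridX(int smCount, int blocksPerSM, unsigned gridY, unsigned gridZ) {
  // Number of blocks that can be resident on the whole device at once.
  int maxBlocks = smCount * blocksPerSM;
  // Share that budget across the y and z grid dimensions.
  if (gridY > 0)
    maxBlocks /= static_cast<int>(gridY);
  if (gridZ > 0)
    maxBlocks /= static_cast<int>(gridZ);
  // Always launch at least one block along x.
  return std::max(maxBlocks, 1);
}
```

With illustrative numbers (108 multiprocessors, 4 resident blocks each), `starGridX(108, 4, 1, 1)` gives 432, `starGridX(108, 4, 2, 1)` gives 216, and `starGridX(2, 1, 4, 1)` is clamped to 1.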

@clementval clementval requested a review from wangzpgi November 8, 2024 19:43
@llvmbot llvmbot added the flang:runtime and flang (Flang issues not falling into any other category) labels Nov 8, 2024
@llvmbot
Member

llvmbot commented Nov 8, 2024

@llvm/pr-subscribers-flang-runtime

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

-1, 1, 1 is passed when calling a kernel with the <<<*, block>>> syntax. Query the device to compute the grid.x value.


Full diff: https://github.com/llvm/llvm-project/pull/115538.diff

1 File Affected:

  • (modified) flang/runtime/CUDA/kernel.cpp (+46)
diff --git a/flang/runtime/CUDA/kernel.cpp b/flang/runtime/CUDA/kernel.cpp
index abb7ebb72e5923..8881d8a524aac0 100644
--- a/flang/runtime/CUDA/kernel.cpp
+++ b/flang/runtime/CUDA/kernel.cpp
@@ -25,6 +25,29 @@ void RTDEF(CUFLaunchKernel)(const void *kernel, intptr_t gridX, intptr_t gridY,
   blockDim.x = blockX;
   blockDim.y = blockY;
   blockDim.z = blockZ;
+  bool gridIsStar = (gridX < 0); // <<<*, block>>> syntax was used.
+  if (gridIsStar) {
+    int maxBlocks, nbBlocks, dev, multiProcCount;
+    cudaError_t err1, err2;
+    nbBlocks = blockDim.x * blockDim.y * blockDim.z;
+    cudaGetDevice(&dev);
+    err1 = cudaDeviceGetAttribute(
+        &multiProcCount, cudaDevAttrMultiProcessorCount, dev);
+    err2 = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &maxBlocks, kernel, nbBlocks, smem);
+    if (err1 == cudaSuccess && err2 == cudaSuccess)
+      maxBlocks = multiProcCount * maxBlocks;
+    if (maxBlocks > 0) {
+      if (gridDim.y > 0)
+        maxBlocks = maxBlocks / gridDim.y;
+      if (gridDim.z > 0)
+        maxBlocks = maxBlocks / gridDim.z;
+      if (maxBlocks < 1)
+        maxBlocks = 1;
+      if (gridIsStar)
+        gridDim.x = maxBlocks;
+    }
+  }
   cudaStream_t stream = 0; // TODO stream managment
   CUDA_REPORT_IF_ERROR(
       cudaLaunchKernel(kernel, gridDim, blockDim, params, smem, stream));
@@ -41,6 +64,29 @@ void RTDEF(CUFLaunchClusterKernel)(const void *kernel, intptr_t clusterX,
   config.blockDim.x = blockX;
   config.blockDim.y = blockY;
   config.blockDim.z = blockZ;
+  bool gridIsStar = (gridX < 0); // <<<*, block>>> syntax was used.
+  if (gridIsStar) {
+    int maxBlocks, nbBlocks, dev, multiProcCount;
+    cudaError_t err1, err2;
+    nbBlocks = config.blockDim.x * config.blockDim.y * config.blockDim.z;
+    cudaGetDevice(&dev);
+    err1 = cudaDeviceGetAttribute(
+        &multiProcCount, cudaDevAttrMultiProcessorCount, dev);
+    err2 = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &maxBlocks, kernel, nbBlocks, smem);
+    if (err1 == cudaSuccess && err2 == cudaSuccess)
+      maxBlocks = multiProcCount * maxBlocks;
+    if (maxBlocks > 0) {
+      if (config.gridDim.y > 0)
+        maxBlocks = maxBlocks / config.gridDim.y;
+      if (config.gridDim.z > 0)
+        maxBlocks = maxBlocks / config.gridDim.z;
+      if (maxBlocks < 1)
+        maxBlocks = 1;
+      if (gridIsStar)
+        config.gridDim.x = maxBlocks;
+    }
+  }
   config.dynamicSmemBytes = smem;
   config.stream = 0; // TODO stream managment
   cudaLaunchAttribute launchAttr[1];
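
One point worth noting about the diff: `maxBlocks` is declared without an initializer, and the `maxBlocks > 0` test still runs when either query fails, so the value used for grid.x may then be indeterminate or an unscaled per-multiprocessor count. The same block also appears in both launch functions. A shared helper that makes the failure path explicit might look like the sketch below. It is an assumption-laden illustration, not the author's method: `starGridXChecked` is a hypothetical name, and the two booleans stand in for `err1 == cudaSuccess` and `err2 == cudaSuccess`.

```cpp
#include <optional>

// Hypothetical helper (not in the patch). Returns the grid.x to use, or
// std::nullopt when the device queries did not yield usable data, leaving
// the fallback policy to the caller.
std::optional<unsigned> starGridXChecked(bool smQueryOk, int smCount,
                                         bool occupancyQueryOk, int blocksPerSM,
                                         unsigned gridY, unsigned gridZ) {
  if (!smQueryOk || !occupancyQueryOk)
    return std::nullopt; // never read values a failed query may not have set
  long long maxBlocks = static_cast<long long>(smCount) * blocksPerSM;
  if (maxBlocks <= 0)
    return std::nullopt; // kernel cannot be resident with this block size
  if (gridY > 0)
    maxBlocks /= gridY;
  if (gridZ > 0)
    maxBlocks /= gridZ;
  // Clamp so that at least one block is launched along x.
  return static_cast<unsigned>(maxBlocks < 1 ? 1 : maxBlocks);
}
```

Both `CUFLaunchKernel` and `CUFLaunchClusterKernel` could then call such a helper and decide in one place what to do when it returns no value, for example reporting an error instead of launching.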

@clementval clementval merged commit 6b21cf8 into llvm:main Nov 8, 2024
6 of 7 checks passed
@clementval clementval deleted the cuf_launch_kernel_compute branch November 8, 2024 22:34
Groverkss pushed a commit to iree-org/llvm-project that referenced this pull request Nov 15, 2024
llvm#115538)

`-1, 1, 1` is passed when calling a kernel with the `<<<*, block>>>`
syntax. Query the device to compute the grid.x value.