Skip to content

Commit 473482c

Browse files
authored
Load functions along with module.
This allows concurrent execution of different kernels (different function or different module). See https://discourse.llvm.org/t/how-to-lower-the-combination-of-async-gpu-ops-in-gpu-dialect/72796/17.
1 parent 1e66a9b commit 473482c

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
#include "mlir/ExecutionEngine/CRunnerUtils.h"
1616

17-
#include <stdio.h>
17+
#include <cstdio>
18+
#include <vector>
1819

1920
#include "cuda.h"
2021
#include "cuda_bf16.h"
@@ -56,14 +57,10 @@
5657

5758
thread_local static int32_t defaultDevice = 0;
5859

59-
const char *kDebugEnvironmentVariable = "MLIR_CUDA_DEBUG";
60-
6160
/// Helper method that checks environment value for debugging.
6261
bool isDebugEnabled() {
63-
static bool isInitialized = false;
64-
static bool isEnabled = false;
65-
if (!isInitialized)
66-
isEnabled = getenv(kDebugEnvironmentVariable) != nullptr;
62+
const char *kDebugEnvironmentVariable = "MLIR_CUDA_DEBUG";
63+
static bool isEnabled = getenv(kDebugEnvironmentVariable) != nullptr;
6764
return isEnabled;
6865
}
6966

@@ -125,6 +122,16 @@ mgpuModuleLoad(void *data, size_t /*gpuBlobSize*/) {
125122
ScopedContext scopedContext;
126123
CUmodule module = nullptr;
127124
CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
125+
// Preload functions in the module so that the first call to
126+
// cuModuleGetFunction below doesn't synchronize context.
127+
unsigned numFunctions = 0;
128+
CUDA_REPORT_IF_ERROR(cuModuleGetFunctionCount(&numFunctions, module));
129+
std::vector<CUfunction> functions(numFunctions);
130+
CUDA_REPORT_IF_ERROR(
131+
cuModuleEnumerateFunctions(functions.data(), numFunctions, module));
132+
for (CUfunction function : functions) {
133+
CUDA_REPORT_IF_ERROR(cuFuncLoad(function));
134+
}
128135
return module;
129136
}
130137

0 commit comments

Comments
 (0)