Skip to content

Commit d6c59d4

Browse files
committed
[mlir][Target] Support Fatbin target for static nvptxcompiler
1 parent 79f59af commit d6c59d4

File tree

3 files changed

+61
-0
lines changed

3 files changed

+61
-0
lines changed

mlir/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ set(MLIR_ENABLE_VULKAN_RUNNER 0 CACHE BOOL "Enable building the MLIR Vulkan runn
135135
set(MLIR_ENABLE_NVPTXCOMPILER 0 CACHE BOOL
136136
"Statically link the nvptxlibrary instead of calling ptxas as a subprocess \
137137
for compiling PTX to cubin")
138+
set(MLIR_ENABLE_NVPTXCOMPILER_NVFATBIN 0 CACHE BOOL
139+
"Statically link the nvfatbin library instead of calling fatbinary as a subprocess \
140+
for compiling PTX to fatbin")
138141

139142
set(MLIR_ENABLE_PDL_IN_PATTERNMATCH 1 CACHE BOOL "Enable PDL in PatternMatch")
140143

mlir/lib/Target/LLVM/CMakeLists.txt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,22 @@ if ("NVPTX" IN_LIST LLVM_TARGETS_TO_BUILD)
8888
# Link against `nvptxcompiler_static`. TODO: use `CUDA::nvptxcompiler_static`.
8989
target_link_libraries(MLIRNVVMTarget PRIVATE MLIR_NVPTXCOMPILER_LIB)
9090
target_include_directories(obj.MLIRNVVMTarget PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
91+
92+
# Add the `nvfatbin` library.
93+
if(MLIR_ENABLE_NVPTXCOMPILER_NVFATBIN)
94+
find_library(MLIR_NVFATBIN_LIB_PATH nvfatbin_static
95+
PATHS ${CUDAToolkit_LIBRARY_DIR} NO_DEFAULT_PATH)
96+
# Fail if `nvfatbin_static` couldn't be found.
97+
if(MLIR_NVFATBIN_LIB_PATH STREQUAL "MLIR_NVFATBIN_LIB_PATH-NOTFOUND")
98+
message(FATAL_ERROR
99+
"Requested using the static `nvptxcompiler` library which requires the \
100+
'nvfatbin` library, but it couldn't be found.")
101+
endif()
102+
103+
add_library(MLIR_NVFATBIN_LIB STATIC IMPORTED GLOBAL)
104+
set_property(TARGET MLIR_NVFATBIN_LIB PROPERTY IMPORTED_LOCATION ${MLIR_NVFATBIN_LIB_PATH})
105+
target_link_libraries(MLIRNVVMTarget PRIVATE MLIR_NVFATBIN_LIB)
106+
endif()
91107
endif()
92108
else()
93109
# Fail if `MLIR_ENABLE_NVPTXCOMPILER` is enabled and the toolkit couldn't be found.

mlir/lib/Target/LLVM/NVVM/Target.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,20 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
473473
} \
474474
} while (false)
475475

476+
#if MLIR_ENABLE_NVPTXCOMPILER_NVFATBIN
477+
#include "nvFatbin.h"
478+
479+
#define RETURN_ON_NVFATBIN_ERROR(expr) \
480+
do { \
481+
auto result = (expr); \
482+
if (result != nvFatbinResult::NVFATBIN_SUCCESS) { \
483+
emitError(loc) << llvm::Twine(#expr).concat(" failed with error: ") \
484+
<< nvFatbinGetErrorString(result); \
485+
return std::nullopt; \
486+
} \
487+
} while (false)
488+
#endif // MLIR_ENABLE_NVPTXCOMPILER_NVFATBIN
489+
476490
std::optional<SmallVector<char, 0>>
477491
NVPTXSerializer::compileToBinaryNVPTX(const std::string &ptxCode) {
478492
Location loc = getOperation().getLoc();
@@ -538,6 +552,34 @@ NVPTXSerializer::compileToBinaryNVPTX(const std::string &ptxCode) {
538552
});
539553
#undef DEBUG_TYPE
540554
RETURN_ON_NVPTXCOMPILER_ERROR(nvPTXCompilerDestroy(&compiler));
555+
556+
#if MLIR_ENABLE_NVPTXCOMPILER_NVFATBIN
557+
bool useFatbin32 = llvm::any_of(cmdOpts.second, [](const char *option) {
558+
return llvm::StringRef(option) == "-32";
559+
});
560+
561+
if (targetOptions.getCompilationTarget() == gpu::CompilationTarget::Fatbin) {
562+
const char *cubinOpts[1] = {useFatbin32 ? "-32" : "-64"};
563+
nvFatbinHandle handle;
564+
565+
auto chip = getTarget().getChip();
566+
chip.consume_front("sm_");
567+
568+
RETURN_ON_NVFATBIN_ERROR(nvFatbinCreate(&handle, cubinOpts, 1));
569+
RETURN_ON_NVFATBIN_ERROR(nvFatbinAddCubin(
570+
handle, binary.data(), binary.size(), chip.data(), nullptr));
571+
RETURN_ON_NVFATBIN_ERROR(nvFatbinAddPTX(
572+
handle, ptxCode.data(), ptxCode.size(), chip.data(), nullptr, nullptr));
573+
574+
size_t fatbinSize;
575+
RETURN_ON_NVFATBIN_ERROR(nvFatbinSize(handle, &fatbinSize));
576+
SmallVector<char, 0> fatbin(fatbinSize, 0);
577+
RETURN_ON_NVFATBIN_ERROR(nvFatbinGet(handle, (void *)fatbin.data()));
578+
RETURN_ON_NVFATBIN_ERROR(nvFatbinDestroy(&handle));
579+
return fatbin;
580+
}
581+
#endif // MLIR_ENABLE_NVPTXCOMPILER_NVFATBIN
582+
541583
return binary;
542584
}
543585
#endif // MLIR_ENABLE_NVPTXCOMPILER

0 commit comments

Comments
 (0)