diff --git a/mlir/lib/Target/LLVM/CMakeLists.txt b/mlir/lib/Target/LLVM/CMakeLists.txt index 422f7e5fa7cae..543e917b528b3 100644 --- a/mlir/lib/Target/LLVM/CMakeLists.txt +++ b/mlir/lib/Target/LLVM/CMakeLists.txt @@ -88,6 +88,20 @@ if ("NVPTX" IN_LIST LLVM_TARGETS_TO_BUILD) # Link against `nvptxcompiler_static`. TODO: use `CUDA::nvptxcompiler_static`. target_link_libraries(MLIRNVVMTarget PRIVATE MLIR_NVPTXCOMPILER_LIB) target_include_directories(obj.MLIRNVVMTarget PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) + + # Add the `nvfatbin` library. + find_library(MLIR_NVFATBIN_LIB_PATH nvfatbin_static + PATHS ${CUDAToolkit_LIBRARY_DIR} NO_DEFAULT_PATH) + # Fail if `nvfatbin_static` couldn't be found. + if(MLIR_NVFATBIN_LIB_PATH STREQUAL "MLIR_NVFATBIN_LIB_PATH-NOTFOUND") + message(FATAL_ERROR + "Requested using the static `nvptxcompiler` library which requires the \ + 'nvfatbin` library, but it couldn't be found.") + endif() + + add_library(MLIR_NVFATBIN_LIB STATIC IMPORTED GLOBAL) + set_property(TARGET MLIR_NVFATBIN_LIB PROPERTY IMPORTED_LOCATION ${MLIR_NVFATBIN_LIB_PATH}) + target_link_libraries(MLIRNVVMTarget PRIVATE MLIR_NVFATBIN_LIB) endif() else() # Fail if `MLIR_ENABLE_NVPTXCOMPILER` is enabled and the toolkit couldn't be found. diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp index bca26e3a0e84a..3c92359915ded 100644 --- a/mlir/lib/Target/LLVM/NVVM/Target.cpp +++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp @@ -473,6 +473,18 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) { } \ } while (false) +#include "nvFatbin.h" + +#define RETURN_ON_NVFATBIN_ERROR(expr) \ + do { \ + auto result = (expr); \ + if (result != nvFatbinResult::NVFATBIN_SUCCESS) { \ + emitError(loc) << llvm::Twine(#expr).concat(" failed with error: ") \ + << nvFatbinGetErrorString(result); \ + return std::nullopt; \ + } \ + } while (false) + std::optional> NVPTXSerializer::compileToBinaryNVPTX(const std::string &ptxCode) { Location loc = getOperation().getLoc(); @@ -538,6 +550,32 @@ NVPTXSerializer::compileToBinaryNVPTX(const std::string &ptxCode) { }); #undef DEBUG_TYPE RETURN_ON_NVPTXCOMPILER_ERROR(nvPTXCompilerDestroy(&compiler)); + + if (targetOptions.getCompilationTarget() == gpu::CompilationTarget::Fatbin) { + bool useFatbin32 = llvm::any_of(cmdOpts.second, [](const char *option) { + return llvm::StringRef(option) == "-32"; + }); + + const char *cubinOpts[1] = {useFatbin32 ? "-32" : "-64"}; + nvFatbinHandle handle; + + auto chip = getTarget().getChip(); + chip.consume_front("sm_"); + + RETURN_ON_NVFATBIN_ERROR(nvFatbinCreate(&handle, cubinOpts, 1)); + RETURN_ON_NVFATBIN_ERROR(nvFatbinAddCubin( + handle, binary.data(), binary.size(), chip.data(), nullptr)); + RETURN_ON_NVFATBIN_ERROR(nvFatbinAddPTX( + handle, ptxCode.data(), ptxCode.size(), chip.data(), nullptr, nullptr)); + + size_t fatbinSize; + RETURN_ON_NVFATBIN_ERROR(nvFatbinSize(handle, &fatbinSize)); + SmallVector fatbin(fatbinSize, 0); + RETURN_ON_NVFATBIN_ERROR(nvFatbinGet(handle, (void *)fatbin.data())); + RETURN_ON_NVFATBIN_ERROR(nvFatbinDestroy(&handle)); + return fatbin; + } + return binary; } #endif // MLIR_ENABLE_NVPTXCOMPILER