-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[Offload] Support loading CUDA fat binaries #156955
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using `@` followed by their GitHub username. If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-offload Author: Jonas Greifenhain (cadivus). Changes: This enables the CUDA device plugin to load fat binaries directly without manually extracting the right device image. The reason for this PR is the discussion in #156259. Full diff: https://github.com/llvm/llvm-project/pull/156955.diff 1 file affected:
diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp
index af3c74636bff3..ebb3bfdb8a724 100644
--- a/offload/plugins-nextgen/cuda/src/rtl.cpp
+++ b/offload/plugins-nextgen/cuda/src/rtl.cpp
@@ -45,6 +45,40 @@ struct CUDAKernelTy;
struct CUDADeviceTy;
struct CUDAPluginTy;
+ /// Definitions for parsing CUDA fatbins
+ typedef struct __attribute__((__packed__)) { // Wrapper record: readFatbin views an image whose first 32-bit word is 0x466243b1 as this struct
+ int Magic; // Identifies the wrapper; the image's first word is compared against 0x466243b1
+ int Version; // Not read by the code visible in this patch
+ const char *Data; // Address of the wrapped fatbin; readFatbin continues parsing from here
+ const char *DataEnd; // Not read by the code visible in this patch
+ } FatbinWrapperTy;
+
+ typedef struct __attribute__((__packed__)) { // Header at the start of a fatbin; 16 bytes when packed
+ uint32_t Magic; // loadBinaryImpl recognizes 0xba55ed50 at the image start as an unwrapped fatbin
+ uint16_t Version; // Not read by the code visible in this patch
+ uint16_t HeaderSize; // Offset from the header start to the first entry
+ uint64_t FatSize; // Upper bound, relative to the header start, of the entry walk in readFatbin
+ } CudaFatbinHeader;
+
+ // Inspired by
+ // https://github.com/n-eiling/cuda-fatbin-decompression/blob/master/fatbin-decompress.h
+ typedef struct __attribute__((__packed__)) { // Per-entry header; readFatbin views each entry start as this (64 bytes when packed)
+ uint16_t Kind; // NOTE(review): readFatbin does not inspect this; confirm whether non-ELF entries must be skipped
+ uint16_t Unknown1;
+ uint32_t HeaderSize; // Bytes from the entry start to its payload
+ uint64_t Size; // Payload size in bytes; the next entry starts HeaderSize + Size bytes after this one
+ uint32_t CompressedSize; // NOTE(review): presumably non-zero for compressed payloads, which readFatbin does not decompress - confirm
+ uint32_t Unknown2;
+ uint16_t Minor;
+ uint16_t Major;
+ uint32_t Arch; // Among compatible entries readFatbin keeps the one with the largest value
+ uint32_t ObjNameOffset;
+ uint32_t ObjNameLen;
+ uint64_t Flags;
+ uint64_t Zero;
+ uint64_t DecompressedSize;
+ } CudaFatbinTextHeader;
+
#if (defined(CUDA_VERSION) && (CUDA_VERSION < 11000))
/// Forward declarations for all Virtual Memory Management
/// related data structures and functions. This is necessary
@@ -548,6 +582,73 @@ struct CUDADeviceTy : public GenericDeviceTy {
CUcontext getCUDAContext() const { return Context; }
CUdevice getCUDADevice() const { return Device; }
+  /// Extract the best matching device image from a CUDA fat binary.
+  ///
+  /// \p TgtImage either starts with the wrapper magic (0x466243b1), in which
+  /// case the fatbin is reached through FatbinWrapperTy::Data, or with the
+  /// fatbin header itself. Every entry is passed to Plugin.isELFCompatible();
+  /// among the compatible entries the one with the largest Arch wins.
+  ///
+  /// The container is validated while it is walked: the fatbin magic must
+  /// match, both header sizes must cover their structs, and every entry must
+  /// lie inside the image. For a wrapped fatbin only the size declared in its
+  /// own header is known, so that size is used as the bound.
+  ///
+  /// \returns a heap-allocated __tgt_device_image whose ImageStart/ImageEnd
+  /// point into the fatbin (nothing is copied), or an error if the container
+  /// is malformed, a compatibility check fails, or no entry is compatible.
+  ///
+  /// NOTE(review): the result is created with `new` and is not released in
+  /// this function; the caller must keep it alive while the image is in use
+  /// and free it afterwards.
+  Expected<__tgt_device_image *>
+  readFatbin(const __tgt_device_image *TgtImage) {
+    constexpr uint32_t WrapperMagic = 0x466243b1;
+    constexpr uint32_t FatbinMagic = 0xba55ed50;
+    constexpr size_t MaxSize = static_cast<size_t>(-1);
+
+    const char *ImageStart =
+        reinterpret_cast<const char *>(TgtImage->ImageStart);
+    const char *ImageEnd = reinterpret_cast<const char *>(TgtImage->ImageEnd);
+    const size_t Available =
+        ImageEnd > ImageStart ? static_cast<size_t>(ImageEnd - ImageStart) : 0;
+    if (Available < sizeof(uint32_t))
+      return createStringError(std::errc::invalid_argument,
+                               "Image is too small to be a CUDA fat binary");
+
+    const bool IsWrapped =
+        *reinterpret_cast<const uint32_t *>(ImageStart) == WrapperMagic;
+    if (IsWrapped) {
+      if (Available < sizeof(FatbinWrapperTy))
+        return createStringError(std::errc::invalid_argument,
+                                 "Truncated CUDA fat binary wrapper");
+      // The fatbin is wrapped in FatbinWrapperTy, extract it.
+      const auto *FW = reinterpret_cast<const FatbinWrapperTy *>(ImageStart);
+      ImageStart = FW->Data;
+      if (!ImageStart)
+        return createStringError(std::errc::invalid_argument,
+                                 "CUDA fat binary wrapper has no data");
+    } else if (Available < sizeof(CudaFatbinHeader)) {
+      return createStringError(std::errc::invalid_argument,
+                               "Truncated CUDA fat binary header");
+    }
+
+    const auto *Header = reinterpret_cast<const CudaFatbinHeader *>(ImageStart);
+    if (Header->Magic != FatbinMagic)
+      return createStringError(std::errc::invalid_argument,
+                               "Invalid CUDA fat binary magic");
+
+    const size_t HeaderSize = Header->HeaderSize; // Usually 16
+    const uint64_t DeclaredSize = Header->FatSize;
+    if (HeaderSize < sizeof(CudaFatbinHeader) ||
+        DeclaredSize > MaxSize - HeaderSize)
+      return createStringError(std::errc::invalid_argument,
+                               "Invalid CUDA fat binary header");
+    const size_t FatbinSize = static_cast<size_t>(DeclaredSize);
+
+    // Number of bytes that may be read starting at the fatbin header. The
+    // extent of an unwrapped image is known exactly; a wrapper only stores a
+    // pointer, so the size declared by the header is the best bound there.
+    const size_t Readable = IsWrapped ? HeaderSize + FatbinSize : Available;
+    if (HeaderSize > Readable)
+      return createStringError(std::errc::invalid_argument,
+                               "Truncated CUDA fat binary header");
+
+    const char *ProgramData = nullptr;
+    size_t ProgramSize = 0;
+    uint32_t ProgramArch = 0;
+
+    // Offsets are tracked instead of pointers so that no size taken from the
+    // file can move a pointer outside of the image.
+    size_t Offset = HeaderSize;
+    while (Offset < FatbinSize) {
+      const size_t Remaining = Readable - Offset;
+      if (Remaining < sizeof(CudaFatbinTextHeader))
+        return createStringError(std::errc::invalid_argument,
+                                 "Truncated CUDA fat binary entry");
+
+      const auto *TextHeader =
+          reinterpret_cast<const CudaFatbinTextHeader *>(ImageStart + Offset);
+      const size_t TextHeaderSize = TextHeader->HeaderSize; // Usually 64
+      const uint64_t PayloadSize = TextHeader->Size;
+      // Requiring a full entry header also guarantees forward progress, so a
+      // corrupt entry can not make this loop spin forever.
+      if (TextHeaderSize < sizeof(CudaFatbinTextHeader) ||
+          TextHeaderSize > Remaining ||
+          PayloadSize > Remaining - TextHeaderSize)
+        return createStringError(std::errc::invalid_argument,
+                                 "Invalid CUDA fat binary entry");
+
+      const size_t CubinSize = static_cast<size_t>(PayloadSize);
+      const char *CubinData = ImageStart + Offset + TextHeaderSize;
+
+      Expected<bool> CompatibleOrErr =
+          Plugin.isELFCompatible(DeviceId, StringRef(CubinData, CubinSize));
+      if (!CompatibleOrErr)
+        return CompatibleOrErr.takeError();
+
+      // Prefer the compatible entry built for the newest architecture.
+      if (*CompatibleOrErr && TextHeader->Arch > ProgramArch) {
+        ProgramData = CubinData;
+        ProgramSize = CubinSize;
+        ProgramArch = TextHeader->Arch;
+      }
+
+      Offset += TextHeaderSize + CubinSize;
+    }
+
+    if (!ProgramData)
+      return createStringError(std::errc::not_supported,
+                               "Failed to find compatible binary");
+
+    __tgt_device_image *Result = new __tgt_device_image;
+    Result->ImageStart =
+        const_cast<void *>(static_cast<const void *>(ProgramData));
+    Result->ImageEnd = const_cast<void *>(
+        static_cast<const void *>(ProgramData + ProgramSize));
+    Result->EntriesBegin = nullptr;
+    Result->EntriesEnd = nullptr;
+
+    return Result;
+  }
+
/// Load the binary image into the device and allocate an image object.
Expected<DeviceImageTy *> loadBinaryImpl(const __tgt_device_image *TgtImage,
int32_t ImageId) override {
@@ -556,7 +657,19 @@ struct CUDADeviceTy : public GenericDeviceTy {
// Allocate and initialize the image object.
CUDADeviceImageTy *CUDAImage = Plugin.allocate<CUDADeviceImageTy>();
- new (CUDAImage) CUDADeviceImageTy(ImageId, *this, TgtImage);
+
+ uint32_t Magic = *reinterpret_cast<const uint32_t *>(TgtImage->ImageStart);
+ if (Magic == 0x466243b1 || Magic == 0xba55ed50) {
+ // It's a fatbin or a wrapped fatbin
+ auto CubinOrErr = readFatbin(TgtImage);
+ if (!CubinOrErr) {
+ return CubinOrErr.takeError();
+ }
+ __tgt_device_image *Cubin = *CubinOrErr;
+ new (CUDAImage) CUDADeviceImageTy(ImageId, *this, Cubin);
+ } else {
+ new (CUDAImage) CUDADeviceImageTy(ImageId, *this, TgtImage);
+ }
// Load the CUDA module.
if (auto Err = CUDAImage->loadModule())
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We'll also want to handle this when we check if an image is compatible, which means you'll need to handle the magic in `is_device_compatible` and `is_plugin_compatible`. Likely you'll want a utility function, and we'll want some more strict parsing.
I'm not going to require full LLVM object support, since we just want to treat this as an implementation detail. I'm unsure how we'd handle compression. I don't know what all their struct headers do; we don't need any of them for the ELF, since we just get it from the ELF flags. For PTX we'd just JIT compile it anyway.
Here's a suggestion for what parsing this thing would look like, since we more or less just want to treat it as a container of files like an archive.
/// Split a CUDA fat binary into the raw payloads of its entries.
///
/// \p Binary must be exactly one fat binary: it has to be identified as
/// file_magic::cuda_fatbinary and its size has to equal the header size plus
/// the payload size declared in the header. For every entry a StringRef
/// covering the entry's data (the bytes following the entry header) is
/// appended to \p Entries; the StringRefs alias \p Binary, nothing is copied.
///
/// \returns success, or object_error::parse_failed if the container is
/// malformed. Entries appended before a failure was detected stay in
/// \p Entries.
Error nvptx::extractFromFatbinary(StringRef Binary,
                                  SmallVectorImpl<StringRef> &Entries) {
  struct FatBinaryHeader {
    uint8_t Magic[4];
    uint16_t Version;
    uint16_t HeaderSize;
    uint64_t FatbinarySize;
  };

  struct FatBinaryEntry {
    uint16_t EntryKind;
    uint16_t Reserved;
    uint32_t EntrySize;
    uint64_t DataSize;
  };

  const auto ParseFailed = [] {
    return errorCodeToError(object::object_error::parse_failed);
  };

  if (identify_magic(Binary) != file_magic::cuda_fatbinary)
    return ParseFailed();

  // The header is dereferenced below, so it has to be completely present.
  if (Binary.size() < sizeof(FatBinaryHeader))
    return ParseFailed();

  if (!isAddrAligned(Align(alignof(FatBinaryHeader)), Binary.bytes_begin()))
    return ParseFailed();

  const auto *Header =
      reinterpret_cast<const FatBinaryHeader *>(Binary.bytes_begin());
  if (Header->HeaderSize < sizeof(FatBinaryHeader))
    return ParseFailed();

  // Same condition as `size == HeaderSize + FatbinarySize`, written so that
  // a huge FatbinarySize can not wrap the sum around.
  if (Binary.size() - Header->HeaderSize != Header->FatbinarySize)
    return ParseFailed();

  // Work with offsets instead of pointers: sizes read from the file are
  // untrusted and must never move a pointer past the end of the buffer.
  size_t Offset = Header->HeaderSize;
  while (Offset < Binary.size()) {
    const size_t Remaining = Binary.size() - Offset;
    if (Remaining < sizeof(FatBinaryEntry))
      return ParseFailed();

    const auto *Entry = reinterpret_cast<const FatBinaryEntry *>(
        Binary.bytes_begin() + Offset);

    // An entry may end exactly at the end of the binary (the last one always
    // does), so only entries reaching *past* the end are rejected. Requiring
    // a full entry header also guarantees that the loop makes progress.
    if (Entry->EntrySize < sizeof(FatBinaryEntry) ||
        Entry->EntrySize > Remaining ||
        Entry->DataSize > Remaining - Entry->EntrySize)
      return ParseFailed();

    const size_t DataSize = static_cast<size_t>(Entry->DataSize);
    Entries.emplace_back(
        StringRef(Binary.data() + Offset + Entry->EntrySize, DataSize));
    Offset += Entry->EntrySize + DataSize;
  }

  return Error::success();
}
"Failed to find compatible binary"); | ||
} | ||
|
||
__tgt_device_image *Result = new __tgt_device_image; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this leak? I've been planning on reworking these images to make them owned by the plugin and not a reference as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I think it does. How would you resolve that? I think I could technically just modify image, but I'm not sure if that's too ugly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The JIT engine just puts it in an STL container or something so that when it's deconstructed it frees.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What exactly would you change here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For some reason, everything in here works through pointers to this thing even though it's a trivial struct, I think it's a leftover relic of when we built the plugins as dynamic libraries. You just need to make sure that this has permanent storage but gets cleaned up when the plugin is unloaded.
|
||
uint32_t Magic = *reinterpret_cast<const uint32_t *>(TgtImage->ImageStart); | ||
if (Magic == 0x466243b1 || Magic == 0xba55ed50) { | ||
// It's a fatbin or a wrapped fatbin |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is a wrapped fatbin in this context the section that CUDA uses for their fatbin? Presumably we'd handle that from outside here, as it's not a file format.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As I understand it, they wrap it at runtime, and it is unwrapped if you store it as a file.
I think dealing with both here is easier.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand that; this magic is for what you get out when you call `fatbinary`. We then embed that into a special section in the CUDA wrapper code. If there's some mysterious second set of magic bits they use for what `fatbinary` spits out, then we should add that to `Magic.h`. I think you're confusing it with
constexpr unsigned CudaFatMagic = 0x466243b1; |
This enables the cuda device plugin to load fat binaries directly without manually extracting the right device image.
If this approach is okay, I will create a PR for AMD as well.
The reason for this PR is the discussion in #156259