-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[Offload] Support loading CUDA fat binaries #156955
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -45,6 +45,40 @@ struct CUDAKernelTy; | |||
| struct CUDADeviceTy; | ||||
| struct CUDAPluginTy; | ||||
|
|
||||
| /// Definitions for parsing CUDA fatbins | ||||
| typedef struct __attribute__((__packed__)) { | ||||
| int Magic; | ||||
| int Version; | ||||
| const char *Data; | ||||
| const char *DataEnd; | ||||
| } FatbinWrapperTy; | ||||
|
|
||||
| typedef struct __attribute__((__packed__)) { | ||||
| uint32_t Magic; | ||||
| uint16_t Version; | ||||
| uint16_t HeaderSize; | ||||
| uint64_t FatSize; | ||||
| } CudaFatbinHeader; | ||||
|
|
||||
| // Inspired by | ||||
| // https://github.com/n-eiling/cuda-fatbin-decompression/blob/master/fatbin-decompress.h | ||||
| typedef struct __attribute__((__packed__)) { | ||||
| uint16_t Kind; | ||||
| uint16_t Unknown1; | ||||
| uint32_t HeaderSize; | ||||
| uint64_t Size; | ||||
| uint32_t CompressedSize; | ||||
| uint32_t Unknown2; | ||||
| uint16_t Minor; | ||||
| uint16_t Major; | ||||
| uint32_t Arch; | ||||
| uint32_t ObjNameOffset; | ||||
| uint32_t ObjNameLen; | ||||
| uint64_t Flags; | ||||
| uint64_t Zero; | ||||
| uint64_t DecompressedSize; | ||||
| } CudaFatbinTextHeader; | ||||
|
|
||||
| #if (defined(CUDA_VERSION) && (CUDA_VERSION < 11000)) | ||||
| /// Forward declarations for all Virtual Memory Management | ||||
| /// related data structures and functions. This is necessary | ||||
|
|
@@ -548,6 +582,73 @@ struct CUDADeviceTy : public GenericDeviceTy { | |||
| CUcontext getCUDAContext() const { return Context; } | ||||
| CUdevice getCUDADevice() const { return Device; } | ||||
|
|
||||
| /// Extract the right device image from a CUDA fat binary | ||||
| Expected<__tgt_device_image *> | ||||
| readFatbin(const __tgt_device_image *TgtImage) { | ||||
| const char *ImageStart = | ||||
| reinterpret_cast<const char *>(TgtImage->ImageStart); | ||||
|
|
||||
| uint32_t Magic = *reinterpret_cast<const uint32_t *>(ImageStart); | ||||
| if (Magic == 0x466243b1) { | ||||
| // The fatbin is wrapped in FatbinWrapperTy, extract it | ||||
| const auto *FW = reinterpret_cast<const FatbinWrapperTy *>(ImageStart); | ||||
| ImageStart = FW->Data; | ||||
| } | ||||
|
|
||||
| const CudaFatbinHeader *Header = | ||||
| reinterpret_cast<const CudaFatbinHeader *>(ImageStart); | ||||
| size_t HeaderSize = static_cast<size_t>(Header->HeaderSize); // Usually 16 | ||||
| size_t FatbinSize = static_cast<size_t>(Header->FatSize); | ||||
|
|
||||
| const void *ProgramData = nullptr; | ||||
| size_t ProgramSize = 0; | ||||
| uint32_t ProgramArch = 0; | ||||
|
|
||||
| const char *ReadPosition = ImageStart + HeaderSize; | ||||
| while (ReadPosition < (ImageStart + FatbinSize)) { | ||||
| const CudaFatbinTextHeader *TextHeader = | ||||
| reinterpret_cast<const CudaFatbinTextHeader *>(ReadPosition); | ||||
| size_t TextHeaderSize = | ||||
| static_cast<size_t>(TextHeader->HeaderSize); // Usually 64 | ||||
| size_t CubinSize = static_cast<size_t>(TextHeader->Size); | ||||
| const char *CubinData = | ||||
| reinterpret_cast<const char *>(ReadPosition + TextHeaderSize); | ||||
|
|
||||
| uint32_t Arch = TextHeader->Arch; | ||||
|
|
||||
| StringRef Image(CubinData, CubinSize); | ||||
| Expected<bool> CompatibilityCheckResult = | ||||
| Plugin.isELFCompatible(DeviceId, Image); | ||||
|
|
||||
| if (!CompatibilityCheckResult) { | ||||
| return CompatibilityCheckResult.takeError(); | ||||
| } | ||||
|
|
||||
| if (*CompatibilityCheckResult && Arch > ProgramArch) { | ||||
| ProgramData = CubinData; | ||||
| ProgramSize = CubinSize; | ||||
| ProgramArch = Arch; | ||||
| } | ||||
|
|
||||
| ReadPosition += TextHeaderSize + CubinSize; | ||||
| } | ||||
|
|
||||
| if (!ProgramData) { | ||||
| return createStringError(std::errc::not_supported, | ||||
| "Failed to find compatible binary"); | ||||
| } | ||||
|
|
||||
| __tgt_device_image *Result = new __tgt_device_image; | ||||
| const char *Start = reinterpret_cast<const char *>(ProgramData); | ||||
| Result->ImageStart = const_cast<void *>(static_cast<const void *>(Start)); | ||||
| Result->ImageEnd = | ||||
| const_cast<void *>(static_cast<const void *>(Start + ProgramSize)); | ||||
| Result->EntriesBegin = nullptr; | ||||
| Result->EntriesEnd = nullptr; | ||||
|
|
||||
| return Result; | ||||
| } | ||||
|
|
||||
| /// Load the binary image into the device and allocate an image object. | ||||
| Expected<DeviceImageTy *> loadBinaryImpl(const __tgt_device_image *TgtImage, | ||||
| int32_t ImageId) override { | ||||
|
|
@@ -556,7 +657,19 @@ struct CUDADeviceTy : public GenericDeviceTy { | |||
|
|
||||
| // Allocate and initialize the image object. | ||||
| CUDADeviceImageTy *CUDAImage = Plugin.allocate<CUDADeviceImageTy>(); | ||||
| new (CUDAImage) CUDADeviceImageTy(ImageId, *this, TgtImage); | ||||
|
|
||||
| uint32_t Magic = *reinterpret_cast<const uint32_t *>(TgtImage->ImageStart); | ||||
| if (Magic == 0x466243b1 || Magic == 0xba55ed50) { | ||||
| // It's a fatbin or a wrapped fatbin | ||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is a wrapped fatbin in this context the section that CUDA uses for their fatbin? Presumably we'd handle that from outside here, as it's not a file format.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The way I got it, they wrap it at runtime and have it unwrapped if you store it as a file.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't understand that, this magic is for what you get out when you call
|
||||
| auto CubinOrErr = readFatbin(TgtImage); | ||||
| if (!CubinOrErr) { | ||||
| return CubinOrErr.takeError(); | ||||
| } | ||||
| __tgt_device_image *Cubin = *CubinOrErr; | ||||
| new (CUDAImage) CUDADeviceImageTy(ImageId, *this, Cubin); | ||||
| } else { | ||||
| new (CUDAImage) CUDADeviceImageTy(ImageId, *this, TgtImage); | ||||
| } | ||||
|
|
||||
| // Load the CUDA module. | ||||
| if (auto Err = CUDAImage->loadModule()) | ||||
|
|
||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this leak? I've been planning on reworking these images to make them owned by the plugin and not a reference as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I think it does. How would you resolve that? I think I could technically just modify image, but I'm not sure if that's too ugly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The JIT engine just puts it in an STL container or something so that when it's deconstructed it frees.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What exactly would you change here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For some reason, everything in here works through pointers to this thing even though it's a trivial struct, I think it's a leftover relic of when we built the plugins as dynamic libraries. You just need to make sure that this has permanent storage but gets cleaned up when the plugin is unloaded.