Conversation

@cadivus cadivus commented Sep 4, 2025

This enables the CUDA device plugin to load fat binaries directly, without manually extracting the right device image.
If this approach is acceptable, I will create a PR for AMD as well.

The reason for this PR is the discussion in #156259.

github-actions bot commented Sep 4, 2025

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository, in which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by pinging the PR with a comment saying "Ping". The common courtesy ping rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

llvmbot commented Sep 4, 2025

@llvm/pr-subscribers-offload

Author: Jonas Greifenhain (cadivus)

Changes

This enables the CUDA device plugin to load fat binaries directly, without manually extracting the right device image.
If this approach is acceptable, I will create a PR for AMD as well.

The reason for this PR is the discussion in #156259.


Full diff: https://github.com/llvm/llvm-project/pull/156955.diff

1 file affected:

  • (modified) offload/plugins-nextgen/cuda/src/rtl.cpp (+114-1)
diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp
index af3c74636bff3..ebb3bfdb8a724 100644
--- a/offload/plugins-nextgen/cuda/src/rtl.cpp
+++ b/offload/plugins-nextgen/cuda/src/rtl.cpp
@@ -45,6 +45,40 @@ struct CUDAKernelTy;
 struct CUDADeviceTy;
 struct CUDAPluginTy;
 
+/// Definitions for parsing CUDA fatbins
+typedef struct __attribute__((__packed__)) {
+  int Magic;
+  int Version;
+  const char *Data;
+  const char *DataEnd;
+} FatbinWrapperTy;
+
+typedef struct __attribute__((__packed__)) {
+  uint32_t Magic;
+  uint16_t Version;
+  uint16_t HeaderSize;
+  uint64_t FatSize;
+} CudaFatbinHeader;
+
+// Inspired by
+// https://github.com/n-eiling/cuda-fatbin-decompression/blob/master/fatbin-decompress.h
+typedef struct __attribute__((__packed__)) {
+  uint16_t Kind;
+  uint16_t Unknown1;
+  uint32_t HeaderSize;
+  uint64_t Size;
+  uint32_t CompressedSize;
+  uint32_t Unknown2;
+  uint16_t Minor;
+  uint16_t Major;
+  uint32_t Arch;
+  uint32_t ObjNameOffset;
+  uint32_t ObjNameLen;
+  uint64_t Flags;
+  uint64_t Zero;
+  uint64_t DecompressedSize;
+} CudaFatbinTextHeader;
+
 #if (defined(CUDA_VERSION) && (CUDA_VERSION < 11000))
 /// Forward declarations for all Virtual Memory Management
 /// related data structures and functions. This is necessary
@@ -548,6 +582,73 @@ struct CUDADeviceTy : public GenericDeviceTy {
   CUcontext getCUDAContext() const { return Context; }
   CUdevice getCUDADevice() const { return Device; }
 
+  /// Extract the right device image from a CUDA fat binary
+  Expected<__tgt_device_image *>
+  readFatbin(const __tgt_device_image *TgtImage) {
+    const char *ImageStart =
+        reinterpret_cast<const char *>(TgtImage->ImageStart);
+
+    uint32_t Magic = *reinterpret_cast<const uint32_t *>(ImageStart);
+    if (Magic == 0x466243b1) {
+      // The fatbin is wrapped in FatbinWrapperTy, extract it
+      const auto *FW = reinterpret_cast<const FatbinWrapperTy *>(ImageStart);
+      ImageStart = FW->Data;
+    }
+
+    const CudaFatbinHeader *Header =
+        reinterpret_cast<const CudaFatbinHeader *>(ImageStart);
+    size_t HeaderSize = static_cast<size_t>(Header->HeaderSize); // Usually 16
+    size_t FatbinSize = static_cast<size_t>(Header->FatSize);
+
+    const void *ProgramData = nullptr;
+    size_t ProgramSize = 0;
+    uint32_t ProgramArch = 0;
+
+    const char *ReadPosition = ImageStart + HeaderSize;
+    while (ReadPosition < (ImageStart + FatbinSize)) {
+      const CudaFatbinTextHeader *TextHeader =
+          reinterpret_cast<const CudaFatbinTextHeader *>(ReadPosition);
+      size_t TextHeaderSize =
+          static_cast<size_t>(TextHeader->HeaderSize); // Usually 64
+      size_t CubinSize = static_cast<size_t>(TextHeader->Size);
+      const char *CubinData =
+          reinterpret_cast<const char *>(ReadPosition + TextHeaderSize);
+
+      uint32_t Arch = TextHeader->Arch;
+
+      StringRef Image(CubinData, CubinSize);
+      Expected<bool> CompatibilityCheckResult =
+          Plugin.isELFCompatible(DeviceId, Image);
+
+      if (!CompatibilityCheckResult) {
+        return CompatibilityCheckResult.takeError();
+      }
+
+      if (*CompatibilityCheckResult && Arch > ProgramArch) {
+        ProgramData = CubinData;
+        ProgramSize = CubinSize;
+        ProgramArch = Arch;
+      }
+
+      ReadPosition += TextHeaderSize + CubinSize;
+    }
+
+    if (!ProgramData) {
+      return createStringError(std::errc::not_supported,
+                               "Failed to find compatible binary");
+    }
+
+    __tgt_device_image *Result = new __tgt_device_image;
+    const char *Start = reinterpret_cast<const char *>(ProgramData);
+    Result->ImageStart = const_cast<void *>(static_cast<const void *>(Start));
+    Result->ImageEnd =
+        const_cast<void *>(static_cast<const void *>(Start + ProgramSize));
+    Result->EntriesBegin = nullptr;
+    Result->EntriesEnd = nullptr;
+
+    return Result;
+  }
+
   /// Load the binary image into the device and allocate an image object.
   Expected<DeviceImageTy *> loadBinaryImpl(const __tgt_device_image *TgtImage,
                                            int32_t ImageId) override {
@@ -556,7 +657,19 @@ struct CUDADeviceTy : public GenericDeviceTy {
 
     // Allocate and initialize the image object.
     CUDADeviceImageTy *CUDAImage = Plugin.allocate<CUDADeviceImageTy>();
-    new (CUDAImage) CUDADeviceImageTy(ImageId, *this, TgtImage);
+
+    uint32_t Magic = *reinterpret_cast<const uint32_t *>(TgtImage->ImageStart);
+    if (Magic == 0x466243b1 || Magic == 0xba55ed50) {
+      // It's a fatbin or a wrapped fatbin
+      auto CubinOrErr = readFatbin(TgtImage);
+      if (!CubinOrErr) {
+        return CubinOrErr.takeError();
+      }
+      __tgt_device_image *Cubin = *CubinOrErr;
+      new (CUDAImage) CUDADeviceImageTy(ImageId, *this, Cubin);
+    } else {
+      new (CUDAImage) CUDADeviceImageTy(ImageId, *this, TgtImage);
+    }
 
     // Load the CUDA module.
     if (auto Err = CUDAImage->loadModule())

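To make the patch's fatbin walk easier to follow in isolation, here is a minimal, self-contained sketch that iterates a fatbin's entries using the packed header layouts from the patch. It assumes `FatSize` counts the bytes after the header (the convention the reviewer's size check uses); `listFatbinEntries` and the synthetic buffer are illustrative only, not part of the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <utility>
#include <vector>

// Packed layouts copied from the patch above.
struct __attribute__((__packed__)) CudaFatbinHeader {
  uint32_t Magic;
  uint16_t Version;
  uint16_t HeaderSize; // Usually 16
  uint64_t FatSize;    // Bytes following the header (assumed convention)
};

struct __attribute__((__packed__)) CudaFatbinTextHeader {
  uint16_t Kind;
  uint16_t Unknown1;
  uint32_t HeaderSize; // Usually 64
  uint64_t Size;       // Payload (cubin/PTX) size in bytes
  uint32_t CompressedSize;
  uint32_t Unknown2;
  uint16_t Minor;
  uint16_t Major;
  uint32_t Arch;
  uint32_t ObjNameOffset;
  uint32_t ObjNameLen;
  uint64_t Flags;
  uint64_t Zero;
  uint64_t DecompressedSize;
};

// Walk the container and collect (arch, payload) pairs, mirroring the
// loop in readFatbin(). No compatibility check here; callers would
// filter the entries themselves.
std::vector<std::pair<uint32_t, std::vector<char>>>
listFatbinEntries(const char *ImageStart) {
  std::vector<std::pair<uint32_t, std::vector<char>>> Entries;
  const auto *Header = reinterpret_cast<const CudaFatbinHeader *>(ImageStart);
  const char *Pos = ImageStart + Header->HeaderSize;
  const char *End = Pos + Header->FatSize;
  while (Pos < End) {
    const auto *TH = reinterpret_cast<const CudaFatbinTextHeader *>(Pos);
    const char *Data = Pos + TH->HeaderSize;
    Entries.emplace_back(TH->Arch, std::vector<char>(Data, Data + TH->Size));
    Pos += TH->HeaderSize + TH->Size;
  }
  return Entries;
}
```

Feeding this a hand-built single-entry buffer (header, one text header, payload) yields one entry carrying the text header's `Arch` and the raw payload bytes.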
@jhuber6 jhuber6 left a comment

We'll also want to handle this when we check whether an image is compatible, which means you'll need to handle the magic in is_device_compatible and is_plugin_compatible. Likely you'll want a utility function, and we'll want some more strict parsing.

I'm not going to require full LLVM object support, since we just want to treat this as an implementation detail. I'm unsure how we'd handle compression. I don't know what all their struct headers do; we don't need any of them for ELF, since we just get the architecture from the ELF flags. For PTX we'd just JIT compile it anyway.

Here's a suggestion for what parsing this thing would look like, since we more or less just want to treat it as a container of files like an archive.

Error nvptx::extractFromFatbinary(StringRef Binary,
                                  SmallVectorImpl<StringRef> &Entries) {
  struct FatBinaryHeader {
    uint8_t Magic[4];
    uint16_t Version;
    uint16_t HeaderSize;
    uint64_t FatbinarySize;
  };

  struct FatBinaryEntry {
    uint16_t EntryKind;
    uint16_t Reserved;
    uint32_t EntrySize;
    uint64_t DataSize;
  };

  if (identify_magic(Binary) != file_magic::cuda_fatbinary)
    return errorCodeToError(object::object_error::parse_failed);
  if (!isAddrAligned(Align(alignof(FatBinaryHeader)), Binary.bytes_begin()))
    return errorCodeToError(object::object_error::parse_failed);

  const FatBinaryHeader *Header =
      reinterpret_cast<const FatBinaryHeader *>(Binary.bytes_begin());
  if (Binary.size() != Header->HeaderSize + Header->FatbinarySize)
    return errorCodeToError(object::object_error::parse_failed);

  const uint8_t *Buffer = Binary.bytes_begin() + Header->HeaderSize;
  while (Buffer < Binary.bytes_end()) {
    if (Buffer + sizeof(FatBinaryEntry) > Binary.bytes_end())
      return errorCodeToError(object::object_error::parse_failed);

    const FatBinaryEntry *Entry =
        reinterpret_cast<const FatBinaryEntry *>(Buffer);
    if (Buffer + Entry->EntrySize + Entry->DataSize > Binary.bytes_end())
      return errorCodeToError(object::object_error::parse_failed);

    Entries.emplace_back(
        StringRef(reinterpret_cast<const char *>(Buffer) + Entry->EntrySize,
                  Entry->DataSize));

    Buffer += Entry->EntrySize + Entry->DataSize;
  }
  return Error::success();
}

"Failed to find compatible binary");
}

__tgt_device_image *Result = new __tgt_device_image;
jhuber6 (Contributor):

Does this leak? I've been planning on reworking these images to make them owned by the plugin and not a reference as well.

cadivus (Author):

Yes, I think it does. How would you resolve that? I think I could technically just modify the image in place, but I'm not sure if that's too ugly.

jhuber6 (Contributor):

The JIT engine just puts it in an STL container or something, so it gets freed when the container is destroyed.
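One hedged sketch of that ownership pattern (the names `PluginImagePool`, `DeviceImage`, and `createImage` are hypothetical, not the actual plugin API): the plugin keeps extracted images in a container of `unique_ptr`s, so the bare `new` is replaced by storage that is released when the plugin itself is torn down.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Stand-in for __tgt_device_image; counts live instances so cleanup
// can be observed (hypothetical, for illustration only).
struct DeviceImage {
  static inline int Live = 0;
  DeviceImage() { ++Live; }
  ~DeviceImage() { --Live; }
};

// The plugin owns every image it hands out; the vector's destructor
// frees them all when the pool (i.e. the plugin) is destroyed.
struct PluginImagePool {
  std::vector<std::unique_ptr<DeviceImage>> OwnedImages;

  DeviceImage *createImage() {
    OwnedImages.push_back(std::make_unique<DeviceImage>());
    return OwnedImages.back().get();
  }
};
```

Callers still receive plain pointers, so existing code that passes `__tgt_device_image *` around would not need to change.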

cadivus (Author):

What exactly would you change here?

jhuber6 (Contributor):

For some reason, everything in here works through pointers to this thing even though it's a trivial struct; I think it's a leftover relic from when we built the plugins as dynamic libraries. You just need to make sure that this has permanent storage but gets cleaned up when the plugin is unloaded.


uint32_t Magic = *reinterpret_cast<const uint32_t *>(TgtImage->ImageStart);
if (Magic == 0x466243b1 || Magic == 0xba55ed50) {
// It's a fatbin or a wrapped fatbin
jhuber6 (Contributor):

Is a wrapped fatbin in this context the section that CUDA uses for their fatbin? Presumably we'd handle that from outside here, as it's not a file format.

cadivus (Author):

The way I understand it, they wrap it at runtime, and it is unwrapped if you store it as a file.
I think dealing with both here is easier.

jhuber6 (Contributor):

I don't understand that; this magic is for what you get out when you call fatbinary. We then embed that into a special section in the CUDA wrapper code. If there's some mysterious second set of magic bits they use for what fatbinary spits out, then we should add that to Magic.h. I think you're confusing it with

constexpr unsigned CudaFatMagic = 0x466243b1;

which serves a different purpose, unless it's absolutely necessary that we pass a pointer to that section instead of just opening that struct in the hypothetical CUDA runtime we're building.
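For reference, the two constants under discussion can be told apart by peeking at the first four bytes, which is what the patch's loadBinaryImpl() does. A minimal sketch (the enum and function names here are made up for illustration, not from the patch):

```cpp
#include <cstdint>
#include <cstring>

// The two magic values checked in the patch: 0x466243b1 is clang's
// CudaFatMagic (the wrapper struct's tag); 0xba55ed50 is the magic at
// the start of a raw fatbin, read as a little-endian 32-bit value.
constexpr uint32_t CudaFatMagic = 0x466243b1;
constexpr uint32_t RawFatbinMagic = 0xba55ed50;

enum class ImageKind { WrappedFatbin, RawFatbin, Other };

// Classify a device image by its leading four bytes. memcpy avoids
// any alignment assumptions on the input pointer.
ImageKind classifyImage(const void *ImageStart) {
  uint32_t Magic;
  std::memcpy(&Magic, ImageStart, sizeof(Magic));
  if (Magic == CudaFatMagic)
    return ImageKind::WrappedFatbin;
  if (Magic == RawFatbinMagic)
    return ImageKind::RawFatbin;
  return ImageKind::Other; // e.g. a plain ELF cubin (0x7f 'E' 'L' 'F')
}
```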
