@@ -458,8 +458,7 @@ void printVersion(raw_ostream &OS) {
458458
459459namespace nvptx {
460460Expected<StringRef>
461- fatbinary (ArrayRef<std::pair<StringRef, StringRef>> InputFiles,
462- const ArgList &Args) {
461+ fatbinary (ArrayRef<OffloadingImage> Images, const ArgList &Args) {
463462 llvm::TimeTraceScope TimeScope (" NVPTX fatbinary" );
464463 // NVPTX uses the fatbinary program to bundle the linked images.
465464 Expected<std::string> FatBinaryPath =
@@ -481,9 +480,26 @@ fatbinary(ArrayRef<std::pair<StringRef, StringRef>> InputFiles,
481480 CmdArgs.push_back (Triple.isArch64Bit () ? " -64" : " -32" );
482481 CmdArgs.push_back (" --create" );
483482 CmdArgs.push_back (*TempFileOrErr);
484- for (const auto &[File, Arch] : InputFiles)
483+ for (const OffloadingImage &Image : Images) {
484+ StringRef File = Image.Image ->getBufferIdentifier ();
485+ StringRef Arch = Image.StringData .lookup (" arch" );
486+
487+ // Determine the kind based on image type
488+ const char *Kind = " elf" ;
489+ if (Image.TheImageKind == ImageKind::IMG_PTX)
490+ Kind = " ptx" ;
491+
492+ // Extract numeric SM value from arch
493+ // Arch can be "sm_75", "compute_75", or just "75"
494+ StringRef SMValue = Arch;
495+ if (Arch.starts_with (" sm_" ))
496+ SMValue = Arch.drop_front (3 );
497+ else if (Arch.starts_with (" compute_" ))
498+ SMValue = Arch.drop_front (8 );
499+
485500 CmdArgs.push_back (
486- Args.MakeArgString (" --image=profile=" + Arch + " ,file=" + File));
501+ Args.MakeArgString (" --image3=kind=" + Twine (Kind) + " ,sm=" + SMValue + " ,file=" + File));
502+ }
487503
488504 if (Error Err = executeCommands (*FatBinaryPath, CmdArgs))
489505 return std::move (Err);
@@ -1992,12 +2008,7 @@ bundleSYCL(ArrayRef<OffloadingImage> Images) {
19922008
19932009Expected<SmallVector<std::unique_ptr<MemoryBuffer>>>
19942010bundleCuda (ArrayRef<OffloadingImage> Images, const ArgList &Args) {
1995- SmallVector<std::pair<StringRef, StringRef>, 4 > InputFiles;
1996- for (const OffloadingImage &Image : Images)
1997- InputFiles.emplace_back (std::make_pair (Image.Image ->getBufferIdentifier (),
1998- Image.StringData .lookup (" arch" )));
1999-
2000- auto FileOrErr = nvptx::fatbinary (InputFiles, Args);
2011+ auto FileOrErr = nvptx::fatbinary (Images, Args);
20012012 if (!FileOrErr)
20022013 return FileOrErr.takeError ();
20032014
@@ -2279,7 +2290,7 @@ linkAndWrapDeviceFiles(ArrayRef<SmallVector<OffloadFile>> LinkerInputFiles,
22792290 }
22802291 for (size_t I = 0 , E = SplitModules.size (); I != E; ++I) {
22812292 SmallVector<StringRef> Files = {SplitModules[I].ModuleFilePath };
2282- SmallVector<std::pair<StringRef, StringRef>, 4 > BundlerInputFiles ;
2293+ SmallVector<OffloadingImage, 4 > BundlerImages ;
22832294 auto ClangOutputOrErr =
22842295 linkDevice (Files, LinkerArgs, true /* IsSYCLKind */ ,
22852296 CompileLinkOptionsOrErr->first );
@@ -2292,14 +2303,36 @@ linkAndWrapDeviceFiles(ArrayRef<SmallVector<OffloadFile>> LinkerInputFiles,
22922303 nvptx::ptxas (*ClangOutputOrErr, LinkerArgs, Arch);
22932304 if (!PtxasOutputOrErr)
22942305 return PtxasOutputOrErr.takeError ();
2295- BundlerInputFiles.emplace_back (*ClangOutputOrErr, VirtualArch);
2296- BundlerInputFiles.emplace_back (*PtxasOutputOrErr, Arch);
2306+
2307+ // Create OffloadingImage for PTX output
2308+ OffloadingImage PtxImage;
2309+ PtxImage.TheImageKind = ImageKind::IMG_PTX;
2310+ PtxImage.TheOffloadKind = OffloadKind::OFK_Cuda;
2311+ PtxImage.StringData [" arch" ] = VirtualArch;
2312+ auto PtxBuffer = MemoryBuffer::getFile (*ClangOutputOrErr);
2313+ if (!PtxBuffer)
2314+ return createFileError (*ClangOutputOrErr, PtxBuffer.getError ());
2315+ PtxImage.Image = std::move (*PtxBuffer);
2316+ BundlerImages.push_back (std::move (PtxImage));
2317+
2318+ // Create OffloadingImage for Cubin output
2319+ OffloadingImage CubinImage;
2320+ CubinImage.TheImageKind = ImageKind::IMG_Cubin;
2321+ CubinImage.TheOffloadKind = OffloadKind::OFK_Cuda;
2322+ CubinImage.StringData [" arch" ] = Arch;
2323+ auto CubinBuffer = MemoryBuffer::getFile (*PtxasOutputOrErr);
2324+ if (!CubinBuffer)
2325+ return createFileError (*PtxasOutputOrErr, CubinBuffer.getError ());
2326+ CubinImage.Image = std::move (*CubinBuffer);
2327+ BundlerImages.push_back (std::move (CubinImage));
2328+
22972329 auto BundledFileOrErr =
2298- nvptx::fatbinary (BundlerInputFiles , LinkerArgs);
2330+ nvptx::fatbinary (BundlerImages , LinkerArgs);
22992331 if (!BundledFileOrErr)
23002332 return BundledFileOrErr.takeError ();
23012333 SplitModules[I].ModuleFilePath = *BundledFileOrErr;
23022334 } else if (Triple.isAMDGCN ()) {
2335+ SmallVector<std::pair<StringRef, StringRef>, 4 > BundlerInputFiles;
23032336 BundlerInputFiles.emplace_back (*ClangOutputOrErr, Arch);
23042337 auto BundledFileOrErr =
23052338 amdgcn::fatbinary (BundlerInputFiles, LinkerArgs);
0 commit comments