@@ -473,6 +473,20 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
473473 } \
474474 } while (false )
475475
#if MLIR_ENABLE_NVPTXCOMPILER_NVFATBIN
#include "nvFatbin.h"

/// Evaluates `expr` once and checks the resulting nvFatbin status code.
///
/// If the status is not `NVFATBIN_SUCCESS`, an error is emitted at `loc`
/// consisting of the stringified expression, the text " failed with error: "
/// and the string returned by `nvFatbinGetErrorString`, after which the
/// enclosing function returns `std::nullopt`.
///
/// Requirements at the expansion site:
///   * a variable named `loc` must be in scope (it is passed to `emitError`);
///   * the enclosing function must be able to return `std::nullopt`.
///
/// NOTE: there must be no whitespace between the macro name and the opening
/// parenthesis of its parameter list, otherwise the preprocessor treats this as
/// an object-like macro. The `do { ... } while (false)` wrapper makes the
/// expansion behave as a single statement.
#define RETURN_ON_NVFATBIN_ERROR(expr)                                         \
  do {                                                                         \
    auto result = (expr);                                                      \
    if (result != nvFatbinResult::NVFATBIN_SUCCESS) {                          \
      emitError(loc) << llvm::Twine(#expr).concat(" failed with error: ")      \
                     << nvFatbinGetErrorString(result);                        \
      return std::nullopt;                                                     \
    }                                                                          \
  } while (false)
#endif // MLIR_ENABLE_NVPTXCOMPILER_NVFATBIN
489+
476490std::optional<SmallVector<char , 0 >>
477491NVPTXSerializer::compileToBinaryNVPTX (const std::string &ptxCode) {
478492 Location loc = getOperation ().getLoc ();
@@ -538,6 +552,34 @@ NVPTXSerializer::compileToBinaryNVPTX(const std::string &ptxCode) {
538552 });
539553#undef DEBUG_TYPE
540554 RETURN_ON_NVPTXCOMPILER_ERROR (nvPTXCompilerDestroy (&compiler));
555+
#if MLIR_ENABLE_NVPTXCOMPILER_NVFATBIN
  // When the requested compilation target is a fatbin, package the compiled
  // cubin (`binary`) together with the PTX it was built from (`ptxCode`) into a
  // fat binary using the nvFatbin library, and return that instead of the raw
  // cubin. Any nvFatbin failure is reported through RETURN_ON_NVFATBIN_ERROR,
  // which emits an error at `loc` and returns std::nullopt.
  if (targetOptions.getCompilationTarget() == gpu::CompilationTarget::Fatbin) {
    // Owns the nvFatbin handle so that it is destroyed on every early return
    // taken by RETURN_ON_NVFATBIN_ERROR; without this, each failure after a
    // successful nvFatbinCreate would leak the handle.
    struct FatbinHandleOwner {
      nvFatbinHandle handle = nullptr;
      // Cleared right before the explicit, error-checked destroy below so the
      // handle is never destroyed twice.
      bool armed = true;
      ~FatbinHandleOwner() {
        if (armed && handle)
          (void)nvFatbinDestroy(&handle);
      }
    } owner;
    // Alias keeps the stringified expressions in the error messages unchanged.
    nvFatbinHandle &handle = owner.handle;

    // Forward the pointer-width choice that was passed to the PTX compiler.
    const bool useFatbin32 =
        llvm::any_of(cmdOpts.second, [](const char *option) {
          return llvm::StringRef(option) == "-32";
        });
    const char *cubinOpts[1] = {useFatbin32 ? "-32" : "-64"};

    // Strip the "sm_" prefix from the chip name before handing it to nvFatbin.
    // The result is copied into a std::string because the nvFatbin entry
    // points take a C string and a StringRef gives no NUL-termination
    // guarantee.
    auto chip = getTarget().getChip();
    chip.consume_front("sm_");
    const std::string arch = chip.str();

    RETURN_ON_NVFATBIN_ERROR(nvFatbinCreate(&handle, cubinOpts, 1));
    RETURN_ON_NVFATBIN_ERROR(nvFatbinAddCubin(
        handle, binary.data(), binary.size(), arch.c_str(), nullptr));
    RETURN_ON_NVFATBIN_ERROR(nvFatbinAddPTX(
        handle, ptxCode.data(), ptxCode.size(), arch.c_str(), nullptr, nullptr));

    // Query the final size first, then let nvFatbin fill a buffer of exactly
    // that size.
    size_t fatbinSize = 0;
    RETURN_ON_NVFATBIN_ERROR(nvFatbinSize(handle, &fatbinSize));
    SmallVector<char, 0> fatbin(fatbinSize, 0);
    RETURN_ON_NVFATBIN_ERROR(
        nvFatbinGet(handle, static_cast<void *>(fatbin.data())));

    // Success path: destroy explicitly so a failure is still reported. Disarm
    // the owner first so its destructor does not destroy the handle again.
    owner.armed = false;
    RETURN_ON_NVFATBIN_ERROR(nvFatbinDestroy(&handle));
    return fatbin;
  }
#endif // MLIR_ENABLE_NVPTXCOMPILER_NVFATBIN
582+
541583 return binary;
542584}
543585#endif // MLIR_ENABLE_NVPTXCOMPILER
0 commit comments