@@ -473,6 +473,18 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
473473 } \
474474 } while (false )
475475
476+ #include < nvFatbin.h>
477+
478+ #define RETURN_ON_NVFATBIN_ERROR (expr ) \
479+ do { \
480+ auto result = (expr); \
481+ if (result != nvFatbinResult::NVFATBIN_SUCCESS) { \
482+ emitError (loc) << llvm::Twine (#expr).concat (" failed with error: " ) \
483+ << nvFatbinGetErrorString (result); \
484+ return std::nullopt ; \
485+ } \
486+ } while (false )
487+
476488std::optional<SmallVector<char , 0 >>
477489NVPTXSerializer::compileToBinaryNVPTX (const std::string &ptxCode) {
478490 Location loc = getOperation ().getLoc ();
@@ -486,6 +498,11 @@ NVPTXSerializer::compileToBinaryNVPTX(const std::string &ptxCode) {
486498 targetOptions.tokenizeCmdOptions ();
487499 cmdOpts.second .append (
488500 {" -arch" , getTarget ().getChip ().data (), " --opt-level" , optLevel.c_str ()});
501+ bool useFatbin32 = false ;
502+ for (const char *option : cmdOpts.second ) {
503+ if (StringRef (option) == " -32" )
504+ useFatbin32 = true ;
505+ }
489506
490507 // Create the compiler handle.
491508 RETURN_ON_NVPTXCOMPILER_ERROR (
@@ -538,6 +555,30 @@ NVPTXSerializer::compileToBinaryNVPTX(const std::string &ptxCode) {
538555 });
539556#undef DEBUG_TYPE
540557 RETURN_ON_NVPTXCOMPILER_ERROR (nvPTXCompilerDestroy (&compiler));
558+
559+ if (targetOptions.getCompilationTarget () == gpu::CompilationTarget::Fatbin) {
560+ const char *cubinOpts[1 ] = {" -64" };
561+ if (useFatbin32) {
562+ cubinOpts[0 ] = {" -32" };
563+ }
564+ nvFatbinHandle handle;
565+
566+ auto chip = getTarget ().getChip ();
567+ chip.consume_front (" sm_" );
568+
569+ RETURN_ON_NVFATBIN_ERROR (nvFatbinCreate (&handle, cubinOpts, 1 ));
570+ RETURN_ON_NVFATBIN_ERROR (nvFatbinAddCubin (
571+ handle, binary.data (), binary.size (), chip.data (), nullptr ));
572+ RETURN_ON_NVFATBIN_ERROR (nvFatbinAddPTX (
573+ handle, ptxCode.data (), ptxCode.size (), chip.data (), nullptr , nullptr ));
574+
575+ size_t fatbinSize;
576+ RETURN_ON_NVFATBIN_ERROR (nvFatbinSize (handle, &fatbinSize));
577+ SmallVector<char , 0 > fatbin (fatbinSize, 0 );
578+ RETURN_ON_NVFATBIN_ERROR (nvFatbinGet (handle, (void *)fatbin.data ()));
579+ RETURN_ON_NVFATBIN_ERROR (nvFatbinDestroy (&handle));
580+ return fatbin;
581+ }
541582 return binary;
542583}
543584#endif // MLIR_ENABLE_NVPTXCOMPILER
0 commit comments