diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td index 4a9ddafdd177d..313c58715ed51 100644 --- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td @@ -95,7 +95,19 @@ def GpuModuleToBinaryPass Option<"cmdOptions", "opts", "std::string", [{""}], "Command line options to pass to the tools.">, Option<"compilationTarget", "format", "std::string", [{"fatbin"}], - "The target representation of the compilation process."> + "The target representation of the compilation process.">, + Option<"initialLlvmIRCallback", "initialLlvmIRCallback", + "llvm::function_ref", "nullptr", + "Callback invoked with the initial LLVM IR for the device module.">, + Option<"linkedLlvmIRCallback", "linkedLlvmIRCallback", + "llvm::function_ref", "nullptr", + "Callback invoked with LLVM IR for the device module after linking the device libraries.">, + Option<"optimizedLlvmIRCallback", "optimizedLlvmIRCallback", + "llvm::function_ref", "nullptr", + "Callback invoked with LLVM IR for the device module after LLVM optimizations but before codegen.">, + Option<"isaCallback", "isaCallback", + "llvm::function_ref", "nullptr", + "Callback invoked with the target ISA for the device, for example PTX assembly."> ]; } diff --git a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp index 86a3b4780e88c..ce399067c1b04 100644 --- a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp @@ -69,8 +69,10 @@ void GpuModuleToBinaryPass::runOnOperation() { return &parentTable.value(); }; - TargetOptions targetOptions(toolkitPath, linkFiles, cmdOptions, *targetFormat, - lazyTableBuilder); + TargetOptions targetOptions( + toolkitPath, linkFiles, cmdOptions, *targetFormat, lazyTableBuilder, + initialLlvmIRCallback.getValue(), linkedLlvmIRCallback.getValue(), + optimizedLlvmIRCallback.getValue(), isaCallback.getValue()); if (failed(transformGpuModulesToBinaries( getOperation(), OffloadingLLVMTranslationAttrInterface(nullptr), targetOptions)))