From 2175a3ffd798b3645311da4111aa159c4883314a Mon Sep 17 00:00:00 2001 From: Zichen Lu Date: Wed, 13 Nov 2024 15:35:46 +0800 Subject: [PATCH] Add callback functions for ModuleToObject --- .../Dialect/GPU/IR/CompilationInterfaces.h | 44 ++++++++- .../include/mlir/Target/LLVM/ModuleToObject.h | 24 ++++- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 41 ++++++++- mlir/lib/Target/LLVM/ModuleToObject.cpp | 22 ++++- mlir/lib/Target/LLVM/NVVM/Target.cpp | 9 +- .../Target/LLVM/SerializeNVVMTarget.cpp | 59 ++++++++++++ .../Target/LLVM/SerializeToLLVMBitcode.cpp | 89 ++++++++++++++++++- 7 files changed, 275 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h index 6d7cb5ca7a7f8..d4b16a1de8edd 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h +++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h @@ -14,6 +14,7 @@ #define MLIR_DIALECT_GPU_IR_COMPILATIONINTERFACES_H #include "mlir/IR/Attributes.h" +#include "llvm/IR/Module.h" namespace llvm { class IRBuilderBase; @@ -52,7 +53,11 @@ class TargetOptions { StringRef toolkitPath = {}, ArrayRef linkFiles = {}, StringRef cmdOptions = {}, CompilationTarget compilationTarget = getDefaultCompilationTarget(), - function_ref getSymbolTableCallback = {}); + function_ref getSymbolTableCallback = {}, + function_ref initialLlvmIRCallback = {}, + function_ref linkedLlvmIRCallback = {}, + function_ref optimizedLlvmIRCallback = {}, + function_ref isaCallback = {}); /// Returns the typeID. TypeID getTypeID() const; @@ -80,6 +85,22 @@ class TargetOptions { /// table. SymbolTable *getSymbolTable() const; + /// Returns the callback invoked with the initial LLVM IR for the device + /// module. + function_ref getInitialLlvmIRCallback() const; + + /// Returns the callback invoked with LLVM IR for the device module + /// after linking the device libraries. + function_ref getLinkedLlvmIRCallback() const; + + /// Returns the callback invoked with LLVM IR for the device module after + /// LLVM optimizations but before codegen. + function_ref getOptimizedLlvmIRCallback() const; + + /// Returns the callback invoked with the target ISA for the device, + /// for example PTX assembly. + function_ref getISACallback() const; + /// Returns the default compilation target: `CompilationTarget::Fatbin`. static CompilationTarget getDefaultCompilationTarget(); @@ -90,7 +111,11 @@ class TargetOptions { TypeID typeID, StringRef toolkitPath = {}, ArrayRef linkFiles = {}, StringRef cmdOptions = {}, CompilationTarget compilationTarget = getDefaultCompilationTarget(), - function_ref getSymbolTableCallback = {}); + function_ref getSymbolTableCallback = {}, + function_ref initialLlvmIRCallback = {}, + function_ref linkedLlvmIRCallback = {}, + function_ref optimizedLlvmIRCallback = {}, + function_ref isaCallback = {}); /// Path to the target toolkit. std::string toolkitPath; @@ -109,6 +134,21 @@ class TargetOptions { /// being serialized. function_ref getSymbolTableCallback; + /// Callback invoked with the initial LLVM IR for the device module. + function_ref initialLlvmIRCallback; + + /// Callback invoked with LLVM IR for the device module after + /// linking the device libraries. + function_ref linkedLlvmIRCallback; + + /// Callback invoked with LLVM IR for the device module after + /// LLVM optimizations but before codegen. + function_ref optimizedLlvmIRCallback; + + /// Callback invoked with the target ISA for the device, + /// for example PTX assembly. + function_ref isaCallback; + private: TypeID typeID; }; diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h index e40d7e9a43dd6..07fc55b41ae9c 100644 --- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h +++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h @@ -29,8 +29,13 @@ class ModuleTranslation; /// operations being transformed must be translatable into LLVM IR. class ModuleToObject { public: - ModuleToObject(Operation &module, StringRef triple, StringRef chip, - StringRef features = {}, int optLevel = 3); + ModuleToObject( + Operation &module, StringRef triple, StringRef chip, + StringRef features = {}, int optLevel = 3, + function_ref initialLlvmIRCallback = {}, + function_ref linkedLlvmIRCallback = {}, + function_ref optimizedLlvmIRCallback = {}, + function_ref isaCallback = {}); virtual ~ModuleToObject(); /// Returns the operation being serialized. @@ -114,6 +119,21 @@ class ModuleToObject { /// Optimization level. int optLevel; + /// Callback invoked with the initial LLVM IR for the device module. + function_ref initialLlvmIRCallback; + + /// Callback invoked with LLVM IR for the device module after + /// linking the device libraries. + function_ref linkedLlvmIRCallback; + + /// Callback invoked with LLVM IR for the device module after + /// LLVM optimizations but before codegen. + function_ref optimizedLlvmIRCallback; + + /// Callback invoked with the target ISA for the device, + /// for example PTX assembly. + function_ref isaCallback; + private: /// The TargetMachine created for the given Triple, if available. /// Accessible through `getOrCreateTargetMachine()`. diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 956877497d933..d62ea72dcea2f 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -2302,17 +2302,31 @@ KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const { TargetOptions::TargetOptions( StringRef toolkitPath, ArrayRef linkFiles, StringRef cmdOptions, CompilationTarget compilationTarget, - function_ref getSymbolTableCallback) + function_ref getSymbolTableCallback, + function_ref initialLlvmIRCallback, + function_ref linkedLlvmIRCallback, + function_ref optimizedLlvmIRCallback, + function_ref isaCallback) : TargetOptions(TypeID::get(), toolkitPath, linkFiles, - cmdOptions, compilationTarget, getSymbolTableCallback) {} + cmdOptions, compilationTarget, getSymbolTableCallback, + initialLlvmIRCallback, linkedLlvmIRCallback, + optimizedLlvmIRCallback, isaCallback) {} TargetOptions::TargetOptions( TypeID typeID, StringRef toolkitPath, ArrayRef linkFiles, StringRef cmdOptions, CompilationTarget compilationTarget, - function_ref getSymbolTableCallback) + function_ref getSymbolTableCallback, + function_ref initialLlvmIRCallback, + function_ref linkedLlvmIRCallback, + function_ref optimizedLlvmIRCallback, + function_ref isaCallback) : toolkitPath(toolkitPath.str()), linkFiles(linkFiles), cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget), - getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {} + getSymbolTableCallback(getSymbolTableCallback), + initialLlvmIRCallback(initialLlvmIRCallback), + linkedLlvmIRCallback(linkedLlvmIRCallback), + optimizedLlvmIRCallback(optimizedLlvmIRCallback), + isaCallback(isaCallback), typeID(typeID) {} TypeID TargetOptions::getTypeID() const { return typeID; } @@ -2326,6 +2340,25 @@ SymbolTable *TargetOptions::getSymbolTable() const { return getSymbolTableCallback ? getSymbolTableCallback() : nullptr; } +function_ref +TargetOptions::getInitialLlvmIRCallback() const { + return initialLlvmIRCallback; +} + +function_ref +TargetOptions::getLinkedLlvmIRCallback() const { + return linkedLlvmIRCallback; +} + +function_ref +TargetOptions::getOptimizedLlvmIRCallback() const { + return optimizedLlvmIRCallback; +} + +function_ref TargetOptions::getISACallback() const { + return isaCallback; +} + CompilationTarget TargetOptions::getCompilationTarget() const { return compilationTarget; } diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp index 77391341adaad..3f5b3d5e31864 100644 --- a/mlir/lib/Target/LLVM/ModuleToObject.cpp +++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp @@ -34,10 +34,17 @@ using namespace mlir; using namespace mlir::LLVM; -ModuleToObject::ModuleToObject(Operation &module, StringRef triple, - StringRef chip, StringRef features, int optLevel) +ModuleToObject::ModuleToObject( + Operation &module, StringRef triple, StringRef chip, StringRef features, + int optLevel, function_ref initialLlvmIRCallback, + function_ref linkedLlvmIRCallback, + function_ref optimizedLlvmIRCallback, + function_ref isaCallback) : module(module), triple(triple), chip(chip), features(features), - optLevel(optLevel) {} + optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback), + linkedLlvmIRCallback(linkedLlvmIRCallback), + optimizedLlvmIRCallback(optimizedLlvmIRCallback), + isaCallback(isaCallback) {} ModuleToObject::~ModuleToObject() = default; @@ -215,6 +222,9 @@ std::optional> ModuleToObject::run() { } setDataLayoutAndTriple(*llvmModule); + if (initialLlvmIRCallback) + initialLlvmIRCallback(*llvmModule); + // Link bitcode files. handleModulePreLink(*llvmModule); { @@ -227,10 +237,16 @@ std::optional> ModuleToObject::run() { handleModulePostLink(*llvmModule); } + if (linkedLlvmIRCallback) + linkedLlvmIRCallback(*llvmModule); + // Optimize the module. if (failed(optimizeModule(*llvmModule, optLevel))) return std::nullopt; + if (optimizedLlvmIRCallback) + optimizedLlvmIRCallback(*llvmModule); + // Return the serialized object. return moduleToObject(*llvmModule); } diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp index 69602af8563aa..bca26e3a0e84a 100644 --- a/mlir/lib/Target/LLVM/NVVM/Target.cpp +++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp @@ -86,7 +86,11 @@ SerializeGPUModuleBase::SerializeGPUModuleBase( Operation &module, NVVMTargetAttr target, const gpu::TargetOptions &targetOptions) : ModuleToObject(module, target.getTriple(), target.getChip(), - target.getFeatures(), target.getO()), + target.getFeatures(), target.getO(), + targetOptions.getInitialLlvmIRCallback(), + targetOptions.getLinkedLlvmIRCallback(), + targetOptions.getOptimizedLlvmIRCallback(), + targetOptions.getISACallback()), target(target), toolkitPath(targetOptions.getToolkitPath()), fileList(targetOptions.getLinkFiles()) { @@ -572,6 +576,9 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) { getOperation().emitError() << "Failed translating the module to ISA."; return std::nullopt; } + if (isaCallback) + isaCallback(serializedISA.value()); + #define DEBUG_TYPE "serialize-to-isa" LLVM_DEBUG({ llvm::dbgs() << "PTX for module: " << getOperation().getNameAttr() << "\n"; diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp index 642aa04517809..eee9bd5f23475 100644 --- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp +++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp @@ -156,3 +156,62 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(SerializeNVVMToBinary)) { ASSERT_TRUE(!object->empty()); } } + +// Test callback functions invoked with LLVM IR and ISA. +TEST_F(MLIRTargetLLVMNVVM, + SKIP_WITHOUT_NVPTX(CallbackInvokedWithLLVMIRAndISA)) { + MLIRContext context(registry); + + OwningOpRef module = + parseSourceString(moduleStr, &context); + ASSERT_TRUE(!!module); + + NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context); + + auto serializer = dyn_cast(target); + ASSERT_TRUE(!!serializer); + + std::string initialLLVMIR; + auto initialCallback = [&initialLLVMIR](llvm::Module &module) { + llvm::raw_string_ostream ros(initialLLVMIR); + module.print(ros, nullptr); + }; + + std::string linkedLLVMIR; + auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) { + llvm::raw_string_ostream ros(linkedLLVMIR); + module.print(ros, nullptr); + }; + + std::string optimizedLLVMIR; + auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) { + llvm::raw_string_ostream ros(optimizedLLVMIR); + module.print(ros, nullptr); + }; + + std::string isaResult; + auto isaCallback = [&isaResult](llvm::StringRef isa) { + isaResult = isa.str(); + }; + + gpu::TargetOptions options({}, {}, {}, gpu::CompilationTarget::Assembly, {}, + initialCallback, linkedCallback, optimizedCallback, + isaCallback); + + for (auto gpuModule : (*module).getBody()->getOps()) { + std::optional> object = + serializer.serializeToObject(gpuModule, options); + + ASSERT_TRUE(object != std::nullopt); + ASSERT_TRUE(!object->empty()); + ASSERT_TRUE(!initialLLVMIR.empty()); + ASSERT_TRUE(!linkedLLVMIR.empty()); + ASSERT_TRUE(!optimizedLLVMIR.empty()); + ASSERT_TRUE(!isaResult.empty()); + + initialLLVMIR.clear(); + linkedLLVMIR.clear(); + optimizedLLVMIR.clear(); + isaResult.clear(); + } +} diff --git a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp index 0d4277ed2fdfd..63d1dbd2519be 100644 --- a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp +++ b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp @@ -105,7 +105,9 @@ TargetAttrImpl::serializeToObject(Attribute attribute, Operation *module, // Set a dummy attr to be retrieved by `createObject`. module->setAttr("serialize_attr", UnitAttr::get(module->getContext())); std::string targetTriple = llvm::sys::getProcessTriple(); - LLVM::ModuleToObject serializer(*module, targetTriple, "", ""); + LLVM::ModuleToObject serializer( + *module, targetTriple, "", "", 3, options.getInitialLlvmIRCallback(), + options.getLinkedLlvmIRCallback(), options.getOptimizedLlvmIRCallback()); return serializer.run(); } @@ -153,3 +155,88 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(TargetAttrAPI)) { // `serializeToObject`. ASSERT_TRUE(properties.contains("serialize_attr")); } + +// Test callback function invoked with initial LLVM IR +TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithInitialLLVMIR)) { + MLIRContext context(registry); + + OwningOpRef module = + parseSourceString(moduleStr, &context); + ASSERT_TRUE(!!module); + Builder builder(&context); + IntegerAttr target = builder.getI32IntegerAttr(0); + auto targetAttr = dyn_cast(target); + + std::string initialLLVMIR; + auto initialCallback = [&initialLLVMIR](llvm::Module &module) { + llvm::raw_string_ostream ros(initialLLVMIR); + module.print(ros, nullptr); + }; + + gpu::TargetOptions opts( + {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), {}, + initialCallback); + std::optional> serializedBinary = + targetAttr.serializeToObject(*module, opts); + + ASSERT_TRUE(serializedBinary != std::nullopt); + ASSERT_TRUE(!serializedBinary->empty()); + ASSERT_TRUE(!initialLLVMIR.empty()); +} + +// Test callback function invoked with linked LLVM IR +TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithLinkedLLVMIR)) { + MLIRContext context(registry); + + OwningOpRef module = + parseSourceString(moduleStr, &context); + ASSERT_TRUE(!!module); + Builder builder(&context); + IntegerAttr target = builder.getI32IntegerAttr(0); + auto targetAttr = dyn_cast(target); + + std::string linkedLLVMIR; + auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) { + llvm::raw_string_ostream ros(linkedLLVMIR); + module.print(ros, nullptr); + }; + + gpu::TargetOptions opts( + {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), {}, + {}, linkedCallback); + std::optional> serializedBinary = + targetAttr.serializeToObject(*module, opts); + + ASSERT_TRUE(serializedBinary != std::nullopt); + ASSERT_TRUE(!serializedBinary->empty()); + ASSERT_TRUE(!linkedLLVMIR.empty()); +} + +// Test callback function invoked with optimized LLVM IR +TEST_F(MLIRTargetLLVM, + SKIP_WITHOUT_NATIVE(CallbackInvokedWithOptimizedLLVMIR)) { + MLIRContext context(registry); + + OwningOpRef module = + parseSourceString(moduleStr, &context); + ASSERT_TRUE(!!module); + Builder builder(&context); + IntegerAttr target = builder.getI32IntegerAttr(0); + auto targetAttr = dyn_cast(target); + + std::string optimizedLLVMIR; + auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) { + llvm::raw_string_ostream ros(optimizedLLVMIR); + module.print(ros, nullptr); + }; + + gpu::TargetOptions opts( + {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), {}, + {}, {}, optimizedCallback); + std::optional> serializedBinary = + targetAttr.serializeToObject(*module, opts); + + ASSERT_TRUE(serializedBinary != std::nullopt); + ASSERT_TRUE(!serializedBinary->empty()); + ASSERT_TRUE(!optimizedLLVMIR.empty()); +} \ No newline at end of file