Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 21 additions & 16 deletions mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ class OffloadingTranslationAttrTrait
/// ensure type safeness. Targets are free to ignore these options.
class TargetOptions {
public:
using DiagnosticCallback = function_ref<InFlightDiagnostic()>;
using LLVMIRCallback =
function_ref<LogicalResult(llvm::Module &, DiagnosticCallback)>;
using ISACallback =
function_ref<LogicalResult(StringRef, DiagnosticCallback)>;
/// Constructor initializing the toolkit path, the list of files to link to,
/// extra command line options, the compilation target and a callback for
/// obtaining the parent symbol table. The default compilation target is
Expand All @@ -55,10 +60,10 @@ class TargetOptions {
StringRef cmdOptions = {}, StringRef elfSection = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
function_ref<SymbolTable *()> getSymbolTableCallback = {},
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
function_ref<void(StringRef)> isaCallback = {});
LLVMIRCallback initialLlvmIRCallback = {},
LLVMIRCallback linkedLlvmIRCallback = {},
LLVMIRCallback optimizedLlvmIRCallback = {},
ISACallback isaCallback = {});

/// Returns the typeID.
TypeID getTypeID() const;
Expand Down Expand Up @@ -97,19 +102,19 @@ class TargetOptions {

/// Returns the callback invoked with the initial LLVM IR for the device
/// module.
function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const;
LLVMIRCallback getInitialLlvmIRCallback() const;

/// Returns the callback invoked with LLVM IR for the device module
/// after linking the device libraries.
function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const;
LLVMIRCallback getLinkedLlvmIRCallback() const;

/// Returns the callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const;
LLVMIRCallback getOptimizedLlvmIRCallback() const;

/// Returns the callback invoked with the target ISA for the device,
/// for example PTX assembly.
function_ref<void(StringRef)> getISACallback() const;
ISACallback getISACallback() const;

/// Returns the default compilation target: `CompilationTarget::Fatbin`.
static CompilationTarget getDefaultCompilationTarget();
Expand All @@ -127,10 +132,10 @@ class TargetOptions {
StringRef elfSection = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
function_ref<SymbolTable *()> getSymbolTableCallback = {},
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
function_ref<void(StringRef)> isaCallback = {});
LLVMIRCallback initialLlvmIRCallback = {},
LLVMIRCallback linkedLlvmIRCallback = {},
LLVMIRCallback optimizedLlvmIRCallback = {},
ISACallback isaCallback = {});

/// Path to the target toolkit.
std::string toolkitPath;
Expand All @@ -153,19 +158,19 @@ class TargetOptions {
function_ref<SymbolTable *()> getSymbolTableCallback;

/// Callback invoked with the initial LLVM IR for the device module.
function_ref<void(llvm::Module &)> initialLlvmIRCallback;
LLVMIRCallback initialLlvmIRCallback;

/// Callback invoked with LLVM IR for the device module after
/// linking the device libraries.
function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
LLVMIRCallback linkedLlvmIRCallback;

/// Callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
LLVMIRCallback optimizedLlvmIRCallback;

/// Callback invoked with the target ISA for the device,
/// for example PTX assembly.
function_ref<void(StringRef)> isaCallback;
ISACallback isaCallback;

private:
TypeID typeID;
Expand Down
26 changes: 15 additions & 11 deletions mlir/include/mlir/Target/LLVM/ModuleToObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,17 @@ class ModuleTranslation;
/// operations being transformed must be translatable into LLVM IR.
class ModuleToObject {
public:
ModuleToObject(
Operation &module, StringRef triple, StringRef chip,
StringRef features = {}, int optLevel = 3,
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
function_ref<void(StringRef)> isaCallback = {});
using DiagnosticCallback = function_ref<InFlightDiagnostic()>;
using LLVMIRCallback =
function_ref<LogicalResult(llvm::Module &, DiagnosticCallback)>;
using ISACallback =
function_ref<LogicalResult(StringRef, DiagnosticCallback)>;
ModuleToObject(Operation &module, StringRef triple, StringRef chip,
StringRef features = {}, int optLevel = 3,
LLVMIRCallback initialLlvmIRCallback = {},
LLVMIRCallback linkedLlvmIRCallback = {},
LLVMIRCallback optimizedLlvmIRCallback = {},
ISACallback isaCallback = {});
virtual ~ModuleToObject();

/// Returns the operation being serialized.
Expand Down Expand Up @@ -120,19 +124,19 @@ class ModuleToObject {
int optLevel;

/// Callback invoked with the initial LLVM IR for the device module.
function_ref<void(llvm::Module &)> initialLlvmIRCallback;
LLVMIRCallback initialLlvmIRCallback;

/// Callback invoked with LLVM IR for the device module after
/// linking the device libraries.
function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
LLVMIRCallback linkedLlvmIRCallback;

/// Callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
LLVMIRCallback optimizedLlvmIRCallback;

/// Callback invoked with the target ISA for the device,
/// for example PTX assembly.
function_ref<void(StringRef)> isaCallback;
ISACallback isaCallback;

private:
/// The TargetMachine created for the given Triple, if available.
Expand Down
22 changes: 8 additions & 14 deletions mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2652,10 +2652,8 @@ TargetOptions::TargetOptions(
StringRef cmdOptions, StringRef elfSection,
CompilationTarget compilationTarget,
function_ref<SymbolTable *()> getSymbolTableCallback,
function_ref<void(llvm::Module &)> initialLlvmIRCallback,
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
function_ref<void(StringRef)> isaCallback)
LLVMIRCallback initialLlvmIRCallback, LLVMIRCallback linkedLlvmIRCallback,
LLVMIRCallback optimizedLlvmIRCallback, ISACallback isaCallback)
: TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
cmdOptions, elfSection, compilationTarget,
getSymbolTableCallback, initialLlvmIRCallback,
Expand All @@ -2667,10 +2665,8 @@ TargetOptions::TargetOptions(
StringRef cmdOptions, StringRef elfSection,
CompilationTarget compilationTarget,
function_ref<SymbolTable *()> getSymbolTableCallback,
function_ref<void(llvm::Module &)> initialLlvmIRCallback,
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
function_ref<void(StringRef)> isaCallback)
LLVMIRCallback initialLlvmIRCallback, LLVMIRCallback linkedLlvmIRCallback,
LLVMIRCallback optimizedLlvmIRCallback, ISACallback isaCallback)
: toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
compilationTarget(compilationTarget),
Expand All @@ -2696,22 +2692,20 @@ SymbolTable *TargetOptions::getSymbolTable() const {
return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
}

function_ref<void(llvm::Module &)>
TargetOptions::getInitialLlvmIRCallback() const {
TargetOptions::LLVMIRCallback TargetOptions::getInitialLlvmIRCallback() const {
return initialLlvmIRCallback;
}

function_ref<void(llvm::Module &)>
TargetOptions::getLinkedLlvmIRCallback() const {
TargetOptions::LLVMIRCallback TargetOptions::getLinkedLlvmIRCallback() const {
return linkedLlvmIRCallback;
}

function_ref<void(llvm::Module &)>
TargetOptions::LLVMIRCallback
TargetOptions::getOptimizedLlvmIRCallback() const {
return optimizedLlvmIRCallback;
}

function_ref<void(StringRef)> TargetOptions::getISACallback() const {
TargetOptions::ISACallback TargetOptions::getISACallback() const {
return isaCallback;
}

Expand Down
25 changes: 16 additions & 9 deletions mlir/lib/Target/LLVM/ModuleToObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@
using namespace mlir;
using namespace mlir::LLVM;

ModuleToObject::ModuleToObject(
Operation &module, StringRef triple, StringRef chip, StringRef features,
int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback,
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
function_ref<void(StringRef)> isaCallback)
ModuleToObject::ModuleToObject(Operation &module, StringRef triple,
StringRef chip, StringRef features, int optLevel,
LLVMIRCallback initialLlvmIRCallback,
LLVMIRCallback linkedLlvmIRCallback,
LLVMIRCallback optimizedLlvmIRCallback,
ISACallback isaCallback)
: module(module), triple(triple), chip(chip), features(features),
optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
linkedLlvmIRCallback(linkedLlvmIRCallback),
Expand Down Expand Up @@ -254,8 +254,13 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
}
setDataLayoutAndTriple(*llvmModule);

auto diagnosticCallback = [&]() -> InFlightDiagnostic {
return getOperation().emitError();
};

if (initialLlvmIRCallback)
initialLlvmIRCallback(*llvmModule);
if (failed(initialLlvmIRCallback(*llvmModule, diagnosticCallback)))
return std::nullopt;

// Link bitcode files.
handleModulePreLink(*llvmModule);
Expand All @@ -270,14 +275,16 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
}

if (linkedLlvmIRCallback)
linkedLlvmIRCallback(*llvmModule);
if (failed(linkedLlvmIRCallback(*llvmModule, diagnosticCallback)))
return std::nullopt;

// Optimize the module.
if (failed(optimizeModule(*llvmModule, optLevel)))
return std::nullopt;

if (optimizedLlvmIRCallback)
optimizedLlvmIRCallback(*llvmModule);
if (failed(optimizedLlvmIRCallback(*llvmModule, diagnosticCallback)))
return std::nullopt;

// Return the serialized object.
return moduleToObject(*llvmModule);
Expand Down
7 changes: 6 additions & 1 deletion mlir/lib/Target/LLVM/NVVM/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -707,8 +707,13 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
return std::nullopt;
}

auto diagnosticCallback = [&]() -> InFlightDiagnostic {
return getOperation().emitError();
};

if (isaCallback)
isaCallback(serializedISA.value());
if (failed(isaCallback(serializedISA.value(), diagnosticCallback)))
return std::nullopt;

#define DEBUG_TYPE "serialize-to-isa"
LDBG() << "PTX for module: " << getOperation().getNameAttr() << "\n"
Expand Down
58 changes: 53 additions & 5 deletions mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,26 +176,40 @@ TEST_F(MLIRTargetLLVMNVVM,
ASSERT_TRUE(!!serializer);

std::string initialLLVMIR;
auto initialCallback = [&initialLLVMIR](llvm::Module &module) {
auto initialCallback =
[&initialLLVMIR](
llvm::Module &module,
gpu::TargetOptions::DiagnosticCallback) -> LogicalResult {
llvm::raw_string_ostream ros(initialLLVMIR);
module.print(ros, nullptr);
return success();
};

std::string linkedLLVMIR;
auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
auto linkedCallback =
[&linkedLLVMIR](llvm::Module &module,
gpu::TargetOptions::DiagnosticCallback) -> LogicalResult {
llvm::raw_string_ostream ros(linkedLLVMIR);
module.print(ros, nullptr);
return success();
};

std::string optimizedLLVMIR;
auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) {
auto optimizedCallback =
[&optimizedLLVMIR](
llvm::Module &module,
gpu::TargetOptions::DiagnosticCallback) -> LogicalResult {
llvm::raw_string_ostream ros(optimizedLLVMIR);
module.print(ros, nullptr);
return success();
};

std::string isaResult;
auto isaCallback = [&isaResult](llvm::StringRef isa) {
auto isaCallback =
[&isaResult](llvm::StringRef isa,
gpu::TargetOptions::DiagnosticCallback) -> LogicalResult {
isaResult = isa.str();
return success();
};

gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly,
Expand All @@ -220,6 +234,36 @@ TEST_F(MLIRTargetLLVMNVVM,
}
}

// Test callback functions failure with ISA.
TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(CallbackFailedWithISA)) {
MLIRContext context(registry);

OwningOpRef<ModuleOp> module =
parseSourceString<ModuleOp>(moduleStr, &context);
ASSERT_TRUE(!!module);

NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context);

auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
ASSERT_TRUE(!!serializer);

auto isaCallback =
[](llvm::StringRef /*isa*/,
gpu::TargetOptions::DiagnosticCallback diag) -> LogicalResult {
return diag() << "test";
};

gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly,
{}, {}, {}, {}, isaCallback);

for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) {
std::optional<SmallVector<char, 0>> object =
serializer.serializeToObject(gpuModule, options);

ASSERT_TRUE(object == std::nullopt);
}
}

// Test linking LLVM IR from a resource attribute.
TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {
MLIRContext context(registry);
Expand Down Expand Up @@ -261,9 +305,13 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {

// Hook to intercept the LLVM IR after linking external libs.
std::string linkedLLVMIR;
auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
auto linkedCallback =
[&linkedLLVMIR](
llvm::Module &module,
gpu::TargetOptions::DiagnosticCallback /*diag*/) -> LogicalResult {
llvm::raw_string_ostream ros(linkedLLVMIR);
module.print(ros, nullptr);
return success();
};

// Store the bitcode as a DenseI8ArrayAttr.
Expand Down
Loading