Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ class TargetOptions {
StringRef cmdOptions = {}, StringRef elfSection = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
function_ref<SymbolTable *()> getSymbolTableCallback = {},
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
function_ref<void(StringRef)> isaCallback = {});
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
function_ref<LogicalResult(StringRef)> isaCallback = {});

/// Returns the typeID.
TypeID getTypeID() const;
Expand Down Expand Up @@ -97,19 +97,20 @@ class TargetOptions {

/// Returns the callback invoked with the initial LLVM IR for the device
/// module.
function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const;
function_ref<LogicalResult(llvm::Module &)> getInitialLlvmIRCallback() const;

/// Returns the callback invoked with LLVM IR for the device module
/// after linking the device libraries.
function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const;
function_ref<LogicalResult(llvm::Module &)> getLinkedLlvmIRCallback() const;

/// Returns the callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const;
function_ref<LogicalResult(llvm::Module &)>
getOptimizedLlvmIRCallback() const;

/// Returns the callback invoked with the target ISA for the device,
/// for example PTX assembly.
function_ref<void(StringRef)> getISACallback() const;
function_ref<LogicalResult(StringRef)> getISACallback() const;

/// Returns the default compilation target: `CompilationTarget::Fatbin`.
static CompilationTarget getDefaultCompilationTarget();
Expand All @@ -127,10 +128,10 @@ class TargetOptions {
StringRef elfSection = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
function_ref<SymbolTable *()> getSymbolTableCallback = {},
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
function_ref<void(StringRef)> isaCallback = {});
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
function_ref<LogicalResult(StringRef)> isaCallback = {});

/// Path to the target toolkit.
std::string toolkitPath;
Expand All @@ -153,19 +154,19 @@ class TargetOptions {
function_ref<SymbolTable *()> getSymbolTableCallback;

/// Callback invoked with the initial LLVM IR for the device module.
function_ref<void(llvm::Module &)> initialLlvmIRCallback;
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback;

/// Callback invoked with LLVM IR for the device module after
/// linking the device libraries.
function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback;

/// Callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback;

/// Callback invoked with the target ISA for the device,
/// for example PTX assembly.
function_ref<void(StringRef)> isaCallback;
function_ref<LogicalResult(StringRef)> isaCallback;

private:
TypeID typeID;
Expand Down
16 changes: 8 additions & 8 deletions mlir/include/mlir/Target/LLVM/ModuleToObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ class ModuleToObject {
ModuleToObject(
Operation &module, StringRef triple, StringRef chip,
StringRef features = {}, int optLevel = 3,
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
function_ref<void(StringRef)> isaCallback = {});
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
function_ref<LogicalResult(StringRef)> isaCallback = {});
virtual ~ModuleToObject();

/// Returns the operation being serialized.
Expand Down Expand Up @@ -120,19 +120,19 @@ class ModuleToObject {
int optLevel;

/// Callback invoked with the initial LLVM IR for the device module.
function_ref<void(llvm::Module &)> initialLlvmIRCallback;
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback;

/// Callback invoked with LLVM IR for the device module after
/// linking the device libraries.
function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback;

/// Callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback;

/// Callback invoked with the target ISA for the device,
/// for example PTX assembly.
function_ref<void(StringRef)> isaCallback;
function_ref<LogicalResult(StringRef)> isaCallback;

private:
/// The TargetMachine created for the given Triple, if available.
Expand Down
24 changes: 12 additions & 12 deletions mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2652,10 +2652,10 @@ TargetOptions::TargetOptions(
StringRef cmdOptions, StringRef elfSection,
CompilationTarget compilationTarget,
function_ref<SymbolTable *()> getSymbolTableCallback,
function_ref<void(llvm::Module &)> initialLlvmIRCallback,
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
function_ref<void(StringRef)> isaCallback)
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
function_ref<LogicalResult(StringRef)> isaCallback)
: TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
cmdOptions, elfSection, compilationTarget,
getSymbolTableCallback, initialLlvmIRCallback,
Expand All @@ -2667,10 +2667,10 @@ TargetOptions::TargetOptions(
StringRef cmdOptions, StringRef elfSection,
CompilationTarget compilationTarget,
function_ref<SymbolTable *()> getSymbolTableCallback,
function_ref<void(llvm::Module &)> initialLlvmIRCallback,
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
function_ref<void(StringRef)> isaCallback)
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
function_ref<LogicalResult(StringRef)> isaCallback)
: toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
compilationTarget(compilationTarget),
Expand All @@ -2696,22 +2696,22 @@ SymbolTable *TargetOptions::getSymbolTable() const {
return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
}

function_ref<void(llvm::Module &)>
function_ref<LogicalResult(llvm::Module &)>
TargetOptions::getInitialLlvmIRCallback() const {
return initialLlvmIRCallback;
}

function_ref<void(llvm::Module &)>
function_ref<LogicalResult(llvm::Module &)>
TargetOptions::getLinkedLlvmIRCallback() const {
return linkedLlvmIRCallback;
}

function_ref<void(llvm::Module &)>
function_ref<LogicalResult(llvm::Module &)>
TargetOptions::getOptimizedLlvmIRCallback() const {
return optimizedLlvmIRCallback;
}

function_ref<void(StringRef)> TargetOptions::getISACallback() const {
function_ref<LogicalResult(StringRef)> TargetOptions::getISACallback() const {
return isaCallback;
}

Expand Down
33 changes: 23 additions & 10 deletions mlir/lib/Target/LLVM/ModuleToObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ using namespace mlir::LLVM;

ModuleToObject::ModuleToObject(
Operation &module, StringRef triple, StringRef chip, StringRef features,
int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback,
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
function_ref<void(StringRef)> isaCallback)
int optLevel,
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
function_ref<LogicalResult(StringRef)> isaCallback)
: module(module), triple(triple), chip(chip), features(features),
optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
linkedLlvmIRCallback(linkedLlvmIRCallback),
Expand Down Expand Up @@ -254,8 +255,12 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
}
setDataLayoutAndTriple(*llvmModule);

if (initialLlvmIRCallback)
initialLlvmIRCallback(*llvmModule);
if (initialLlvmIRCallback) {
if (failed(initialLlvmIRCallback(*llvmModule))) {
getOperation().emitError() << "InitialLLVMIRCallback failed.";
return std::nullopt;
}
}

// Link bitcode files.
handleModulePreLink(*llvmModule);
Expand All @@ -269,15 +274,23 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
handleModulePostLink(*llvmModule);
}

if (linkedLlvmIRCallback)
linkedLlvmIRCallback(*llvmModule);
if (linkedLlvmIRCallback) {
if (failed(linkedLlvmIRCallback(*llvmModule))) {
getOperation().emitError() << "LinkedLLVMIRCallback failed.";
return std::nullopt;
}
}

// Optimize the module.
if (failed(optimizeModule(*llvmModule, optLevel)))
return std::nullopt;

if (optimizedLlvmIRCallback)
optimizedLlvmIRCallback(*llvmModule);
if (optimizedLlvmIRCallback) {
if (failed(optimizedLlvmIRCallback(*llvmModule))) {
getOperation().emitError() << "OptimizedLLVMIRCallback failed.";
return std::nullopt;
}
}

// Return the serialized object.
return moduleToObject(*llvmModule);
Expand Down
8 changes: 6 additions & 2 deletions mlir/lib/Target/LLVM/NVVM/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -707,8 +707,12 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
return std::nullopt;
}

if (isaCallback)
isaCallback(serializedISA.value());
if (isaCallback) {
if (failed(isaCallback(serializedISA.value()))) {
getOperation().emitError() << "ISACallback failed.";
return std::nullopt;
}
}

#define DEBUG_TYPE "serialize-to-isa"
LDBG() << "PTX for module: " << getOperation().getNameAttr() << "\n"
Expand Down
45 changes: 40 additions & 5 deletions mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,26 +176,32 @@ TEST_F(MLIRTargetLLVMNVVM,
ASSERT_TRUE(!!serializer);

std::string initialLLVMIR;
auto initialCallback = [&initialLLVMIR](llvm::Module &module) {
auto initialCallback =
[&initialLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(initialLLVMIR);
module.print(ros, nullptr);
return success();
};

std::string linkedLLVMIR;
auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(linkedLLVMIR);
module.print(ros, nullptr);
return success();
};

std::string optimizedLLVMIR;
auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) {
auto optimizedCallback =
[&optimizedLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(optimizedLLVMIR);
module.print(ros, nullptr);
return success();
};

std::string isaResult;
auto isaCallback = [&isaResult](llvm::StringRef isa) {
auto isaCallback = [&isaResult](llvm::StringRef isa) -> LogicalResult {
isaResult = isa.str();
return success();
};

gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly,
Expand All @@ -220,6 +226,34 @@ TEST_F(MLIRTargetLLVMNVVM,
}
}

// Test callback functions failure with ISA.
TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(CallbackFailedWithISA)) {
MLIRContext context(registry);

OwningOpRef<ModuleOp> module =
parseSourceString<ModuleOp>(moduleStr, &context);
ASSERT_TRUE(!!module);

NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context);

auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
ASSERT_TRUE(!!serializer);

auto isaCallback = [](llvm::StringRef /*isa*/) -> LogicalResult {
return failure();
};

gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly,
{}, {}, {}, {}, isaCallback);

for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) {
std::optional<SmallVector<char, 0>> object =
serializer.serializeToObject(gpuModule, options);

ASSERT_TRUE(object == std::nullopt);
}
}

// Test linking LLVM IR from a resource attribute.
TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {
MLIRContext context(registry);
Expand Down Expand Up @@ -261,9 +295,10 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {

// Hook to intercept the LLVM IR after linking external libs.
std::string linkedLLVMIR;
auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(linkedLLVMIR);
module.print(ros, nullptr);
return success();
};

// Store the bitcode as a DenseI8ArrayAttr.
Expand Down
Loading