Skip to content

Commit 909d916

Browse files
committed
[mlir][gpu] Propagate errors from ModuleToObject callbacks
1 parent 24b87b8 commit 909d916

File tree

7 files changed

+186
-56
lines changed

7 files changed

+186
-56
lines changed

mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,10 @@ class TargetOptions {
5555
StringRef cmdOptions = {}, StringRef elfSection = {},
5656
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
5757
function_ref<SymbolTable *()> getSymbolTableCallback = {},
58-
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
59-
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
60-
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
61-
function_ref<void(StringRef)> isaCallback = {});
58+
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
59+
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
60+
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
61+
function_ref<LogicalResult(StringRef)> isaCallback = {});
6262

6363
/// Returns the typeID.
6464
TypeID getTypeID() const;
@@ -97,19 +97,20 @@ class TargetOptions {
9797

9898
/// Returns the callback invoked with the initial LLVM IR for the device
9999
/// module.
100-
function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const;
100+
function_ref<LogicalResult(llvm::Module &)> getInitialLlvmIRCallback() const;
101101

102102
/// Returns the callback invoked with LLVM IR for the device module
103103
/// after linking the device libraries.
104-
function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const;
104+
function_ref<LogicalResult(llvm::Module &)> getLinkedLlvmIRCallback() const;
105105

106106
/// Returns the callback invoked with LLVM IR for the device module after
107107
/// LLVM optimizations but before codegen.
108-
function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const;
108+
function_ref<LogicalResult(llvm::Module &)>
109+
getOptimizedLlvmIRCallback() const;
109110

110111
/// Returns the callback invoked with the target ISA for the device,
111112
/// for example PTX assembly.
112-
function_ref<void(StringRef)> getISACallback() const;
113+
function_ref<LogicalResult(StringRef)> getISACallback() const;
113114

114115
/// Returns the default compilation target: `CompilationTarget::Fatbin`.
115116
static CompilationTarget getDefaultCompilationTarget();
@@ -127,10 +128,10 @@ class TargetOptions {
127128
StringRef elfSection = {},
128129
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
129130
function_ref<SymbolTable *()> getSymbolTableCallback = {},
130-
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
131-
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
132-
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
133-
function_ref<void(StringRef)> isaCallback = {});
131+
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
132+
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
133+
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
134+
function_ref<LogicalResult(StringRef)> isaCallback = {});
134135

135136
/// Path to the target toolkit.
136137
std::string toolkitPath;
@@ -153,19 +154,19 @@ class TargetOptions {
153154
function_ref<SymbolTable *()> getSymbolTableCallback;
154155

155156
/// Callback invoked with the initial LLVM IR for the device module.
156-
function_ref<void(llvm::Module &)> initialLlvmIRCallback;
157+
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback;
157158

158159
/// Callback invoked with LLVM IR for the device module after
159160
/// linking the device libraries.
160-
function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
161+
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback;
161162

162163
/// Callback invoked with LLVM IR for the device module after
163164
/// LLVM optimizations but before codegen.
164-
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
165+
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback;
165166

166167
/// Callback invoked with the target ISA for the device,
167168
/// for example PTX assembly.
168-
function_ref<void(StringRef)> isaCallback;
169+
function_ref<LogicalResult(StringRef)> isaCallback;
169170

170171
private:
171172
TypeID typeID;

mlir/include/mlir/Target/LLVM/ModuleToObject.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ class ModuleToObject {
3232
ModuleToObject(
3333
Operation &module, StringRef triple, StringRef chip,
3434
StringRef features = {}, int optLevel = 3,
35-
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
36-
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
37-
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
38-
function_ref<void(StringRef)> isaCallback = {});
35+
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
36+
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
37+
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
38+
function_ref<LogicalResult(StringRef)> isaCallback = {});
3939
virtual ~ModuleToObject();
4040

4141
/// Returns the operation being serialized.
@@ -120,19 +120,19 @@ class ModuleToObject {
120120
int optLevel;
121121

122122
/// Callback invoked with the initial LLVM IR for the device module.
123-
function_ref<void(llvm::Module &)> initialLlvmIRCallback;
123+
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback;
124124

125125
/// Callback invoked with LLVM IR for the device module after
126126
/// linking the device libraries.
127-
function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
127+
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback;
128128

129129
/// Callback invoked with LLVM IR for the device module after
130130
/// LLVM optimizations but before codegen.
131-
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
131+
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback;
132132

133133
/// Callback invoked with the target ISA for the device,
134134
/// for example PTX assembly.
135-
function_ref<void(StringRef)> isaCallback;
135+
function_ref<LogicalResult(StringRef)> isaCallback;
136136

137137
private:
138138
/// The TargetMachine created for the given Triple, if available.

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2652,10 +2652,10 @@ TargetOptions::TargetOptions(
26522652
StringRef cmdOptions, StringRef elfSection,
26532653
CompilationTarget compilationTarget,
26542654
function_ref<SymbolTable *()> getSymbolTableCallback,
2655-
function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2656-
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2657-
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2658-
function_ref<void(StringRef)> isaCallback)
2655+
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
2656+
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
2657+
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
2658+
function_ref<LogicalResult(StringRef)> isaCallback)
26592659
: TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
26602660
cmdOptions, elfSection, compilationTarget,
26612661
getSymbolTableCallback, initialLlvmIRCallback,
@@ -2667,10 +2667,10 @@ TargetOptions::TargetOptions(
26672667
StringRef cmdOptions, StringRef elfSection,
26682668
CompilationTarget compilationTarget,
26692669
function_ref<SymbolTable *()> getSymbolTableCallback,
2670-
function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2671-
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2672-
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2673-
function_ref<void(StringRef)> isaCallback)
2670+
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
2671+
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
2672+
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
2673+
function_ref<LogicalResult(StringRef)> isaCallback)
26742674
: toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
26752675
cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
26762676
compilationTarget(compilationTarget),
@@ -2696,22 +2696,22 @@ SymbolTable *TargetOptions::getSymbolTable() const {
26962696
return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
26972697
}
26982698

2699-
function_ref<void(llvm::Module &)>
2699+
function_ref<LogicalResult(llvm::Module &)>
27002700
TargetOptions::getInitialLlvmIRCallback() const {
27012701
return initialLlvmIRCallback;
27022702
}
27032703

2704-
function_ref<void(llvm::Module &)>
2704+
function_ref<LogicalResult(llvm::Module &)>
27052705
TargetOptions::getLinkedLlvmIRCallback() const {
27062706
return linkedLlvmIRCallback;
27072707
}
27082708

2709-
function_ref<void(llvm::Module &)>
2709+
function_ref<LogicalResult(llvm::Module &)>
27102710
TargetOptions::getOptimizedLlvmIRCallback() const {
27112711
return optimizedLlvmIRCallback;
27122712
}
27132713

2714-
function_ref<void(StringRef)> TargetOptions::getISACallback() const {
2714+
function_ref<LogicalResult(StringRef)> TargetOptions::getISACallback() const {
27152715
return isaCallback;
27162716
}
27172717

mlir/lib/Target/LLVM/ModuleToObject.cpp

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,11 @@ using namespace mlir::LLVM;
3636

3737
ModuleToObject::ModuleToObject(
3838
Operation &module, StringRef triple, StringRef chip, StringRef features,
39-
int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback,
40-
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
41-
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
42-
function_ref<void(StringRef)> isaCallback)
39+
int optLevel,
40+
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
41+
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
42+
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
43+
function_ref<LogicalResult(StringRef)> isaCallback)
4344
: module(module), triple(triple), chip(chip), features(features),
4445
optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
4546
linkedLlvmIRCallback(linkedLlvmIRCallback),
@@ -254,8 +255,12 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
254255
}
255256
setDataLayoutAndTriple(*llvmModule);
256257

257-
if (initialLlvmIRCallback)
258-
initialLlvmIRCallback(*llvmModule);
258+
if (initialLlvmIRCallback) {
259+
if (failed(initialLlvmIRCallback(*llvmModule))) {
260+
getOperation().emitError() << "InitialLLVMIRCallback failed.";
261+
return std::nullopt;
262+
}
263+
}
259264

260265
// Link bitcode files.
261266
handleModulePreLink(*llvmModule);
@@ -269,15 +274,23 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
269274
handleModulePostLink(*llvmModule);
270275
}
271276

272-
if (linkedLlvmIRCallback)
273-
linkedLlvmIRCallback(*llvmModule);
277+
if (linkedLlvmIRCallback) {
278+
if (failed(linkedLlvmIRCallback(*llvmModule))) {
279+
getOperation().emitError() << "LinkedLLVMIRCallback failed.";
280+
return std::nullopt;
281+
}
282+
}
274283

275284
// Optimize the module.
276285
if (failed(optimizeModule(*llvmModule, optLevel)))
277286
return std::nullopt;
278287

279-
if (optimizedLlvmIRCallback)
280-
optimizedLlvmIRCallback(*llvmModule);
288+
if (optimizedLlvmIRCallback) {
289+
if (failed(optimizedLlvmIRCallback(*llvmModule))) {
290+
getOperation().emitError() << "OptimizedLLVMIRCallback failed.";
291+
return std::nullopt;
292+
}
293+
}
281294

282295
// Return the serialized object.
283296
return moduleToObject(*llvmModule);

mlir/lib/Target/LLVM/NVVM/Target.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -707,8 +707,12 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
707707
return std::nullopt;
708708
}
709709

710-
if (isaCallback)
711-
isaCallback(serializedISA.value());
710+
if (isaCallback) {
711+
if (failed(isaCallback(serializedISA.value()))) {
712+
getOperation().emitError() << "ISACallback failed.";
713+
return std::nullopt;
714+
}
715+
}
712716

713717
#define DEBUG_TYPE "serialize-to-isa"
714718
LDBG() << "PTX for module: " << getOperation().getNameAttr() << "\n"

mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -176,26 +176,32 @@ TEST_F(MLIRTargetLLVMNVVM,
176176
ASSERT_TRUE(!!serializer);
177177

178178
std::string initialLLVMIR;
179-
auto initialCallback = [&initialLLVMIR](llvm::Module &module) {
179+
auto initialCallback =
180+
[&initialLLVMIR](llvm::Module &module) -> LogicalResult {
180181
llvm::raw_string_ostream ros(initialLLVMIR);
181182
module.print(ros, nullptr);
183+
return success();
182184
};
183185

184186
std::string linkedLLVMIR;
185-
auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
187+
auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
186188
llvm::raw_string_ostream ros(linkedLLVMIR);
187189
module.print(ros, nullptr);
190+
return success();
188191
};
189192

190193
std::string optimizedLLVMIR;
191-
auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) {
194+
auto optimizedCallback =
195+
[&optimizedLLVMIR](llvm::Module &module) -> LogicalResult {
192196
llvm::raw_string_ostream ros(optimizedLLVMIR);
193197
module.print(ros, nullptr);
198+
return success();
194199
};
195200

196201
std::string isaResult;
197-
auto isaCallback = [&isaResult](llvm::StringRef isa) {
202+
auto isaCallback = [&isaResult](llvm::StringRef isa) -> LogicalResult {
198203
isaResult = isa.str();
204+
return success();
199205
};
200206

201207
gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly,
@@ -220,6 +226,34 @@ TEST_F(MLIRTargetLLVMNVVM,
220226
}
221227
}
222228

229+
// Test callback functions failure with ISA.
230+
TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(CallbackFailedWithISA)) {
231+
MLIRContext context(registry);
232+
233+
OwningOpRef<ModuleOp> module =
234+
parseSourceString<ModuleOp>(moduleStr, &context);
235+
ASSERT_TRUE(!!module);
236+
237+
NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context);
238+
239+
auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
240+
ASSERT_TRUE(!!serializer);
241+
242+
auto isaCallback = [](llvm::StringRef /*isa*/) -> LogicalResult {
243+
return failure();
244+
};
245+
246+
gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly,
247+
{}, {}, {}, {}, isaCallback);
248+
249+
for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) {
250+
std::optional<SmallVector<char, 0>> object =
251+
serializer.serializeToObject(gpuModule, options);
252+
253+
ASSERT_TRUE(object == std::nullopt);
254+
}
255+
}
256+
223257
// Test linking LLVM IR from a resource attribute.
224258
TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {
225259
MLIRContext context(registry);
@@ -261,9 +295,10 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {
261295

262296
// Hook to intercept the LLVM IR after linking external libs.
263297
std::string linkedLLVMIR;
264-
auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
298+
auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
265299
llvm::raw_string_ostream ros(linkedLLVMIR);
266300
module.print(ros, nullptr);
301+
return success();
267302
};
268303

269304
// Store the bitcode as a DenseI8ArrayAttr.

0 commit comments

Comments
 (0)