Skip to content

Commit ce94f04

Browse files
committed
pass op to the callback
1 parent 909d916 commit ce94f04

File tree

7 files changed

+99
-72
lines changed

7 files changed

+99
-72
lines changed

mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,13 @@ class TargetOptions {
5555
StringRef cmdOptions = {}, StringRef elfSection = {},
5656
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
5757
function_ref<SymbolTable *()> getSymbolTableCallback = {},
58-
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
59-
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
60-
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
61-
function_ref<LogicalResult(StringRef)> isaCallback = {});
58+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
59+
initialLlvmIRCallback = {},
60+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
61+
linkedLlvmIRCallback = {},
62+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
63+
optimizedLlvmIRCallback = {},
64+
function_ref<LogicalResult(Operation *op, StringRef)> isaCallback = {});
6265

6366
/// Returns the typeID.
6467
TypeID getTypeID() const;
@@ -97,20 +100,22 @@ class TargetOptions {
97100

98101
/// Returns the callback invoked with the initial LLVM IR for the device
99102
/// module.
100-
function_ref<LogicalResult(llvm::Module &)> getInitialLlvmIRCallback() const;
103+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
104+
getInitialLlvmIRCallback() const;
101105

102106
/// Returns the callback invoked with LLVM IR for the device module
103107
/// after linking the device libraries.
104-
function_ref<LogicalResult(llvm::Module &)> getLinkedLlvmIRCallback() const;
108+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
109+
getLinkedLlvmIRCallback() const;
105110

106111
/// Returns the callback invoked with LLVM IR for the device module after
107112
/// LLVM optimizations but before codegen.
108-
function_ref<LogicalResult(llvm::Module &)>
113+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
109114
getOptimizedLlvmIRCallback() const;
110115

111116
/// Returns the callback invoked with the target ISA for the device,
112117
/// for example PTX assembly.
113-
function_ref<LogicalResult(StringRef)> getISACallback() const;
118+
function_ref<LogicalResult(Operation *op, StringRef)> getISACallback() const;
114119

115120
/// Returns the default compilation target: `CompilationTarget::Fatbin`.
116121
static CompilationTarget getDefaultCompilationTarget();
@@ -128,10 +133,13 @@ class TargetOptions {
128133
StringRef elfSection = {},
129134
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
130135
function_ref<SymbolTable *()> getSymbolTableCallback = {},
131-
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
132-
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
133-
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
134-
function_ref<LogicalResult(StringRef)> isaCallback = {});
136+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
137+
initialLlvmIRCallback = {},
138+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
139+
linkedLlvmIRCallback = {},
140+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
141+
optimizedLlvmIRCallback = {},
142+
function_ref<LogicalResult(Operation *op, StringRef)> isaCallback = {});
135143

136144
/// Path to the target toolkit.
137145
std::string toolkitPath;
@@ -154,19 +162,22 @@ class TargetOptions {
154162
function_ref<SymbolTable *()> getSymbolTableCallback;
155163

156164
/// Callback invoked with the initial LLVM IR for the device module.
157-
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback;
165+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
166+
initialLlvmIRCallback;
158167

159168
/// Callback invoked with LLVM IR for the device module after
160169
/// linking the device libraries.
161-
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback;
170+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
171+
linkedLlvmIRCallback;
162172

163173
/// Callback invoked with LLVM IR for the device module after
164174
/// LLVM optimizations but before codegen.
165-
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback;
175+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
176+
optimizedLlvmIRCallback;
166177

167178
/// Callback invoked with the target ISA for the device,
168179
/// for example PTX assembly.
169-
function_ref<LogicalResult(StringRef)> isaCallback;
180+
function_ref<LogicalResult(Operation *op, StringRef)> isaCallback;
170181

171182
private:
172183
TypeID typeID;

mlir/include/mlir/Target/LLVM/ModuleToObject.h

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,13 @@ class ModuleToObject {
3232
ModuleToObject(
3333
Operation &module, StringRef triple, StringRef chip,
3434
StringRef features = {}, int optLevel = 3,
35-
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
36-
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
37-
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
38-
function_ref<LogicalResult(StringRef)> isaCallback = {});
35+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
36+
initialLlvmIRCallback = {},
37+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
38+
linkedLlvmIRCallback = {},
39+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
40+
optimizedLlvmIRCallback = {},
41+
function_ref<LogicalResult(Operation *op, StringRef)> isaCallback = {});
3942
virtual ~ModuleToObject();
4043

4144
/// Returns the operation being serialized.
@@ -120,19 +123,22 @@ class ModuleToObject {
120123
int optLevel;
121124

122125
/// Callback invoked with the initial LLVM IR for the device module.
123-
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback;
126+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
127+
initialLlvmIRCallback;
124128

125129
/// Callback invoked with LLVM IR for the device module after
126130
/// linking the device libraries.
127-
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback;
131+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
132+
linkedLlvmIRCallback;
128133

129134
/// Callback invoked with LLVM IR for the device module after
130135
/// LLVM optimizations but before codegen.
131-
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback;
136+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
137+
optimizedLlvmIRCallback;
132138

133139
/// Callback invoked with the target ISA for the device,
134140
/// for example PTX assembly.
135-
function_ref<LogicalResult(StringRef)> isaCallback;
141+
function_ref<LogicalResult(Operation *op, StringRef)> isaCallback;
136142

137143
private:
138144
/// The TargetMachine created for the given Triple, if available.

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2652,10 +2652,13 @@ TargetOptions::TargetOptions(
26522652
StringRef cmdOptions, StringRef elfSection,
26532653
CompilationTarget compilationTarget,
26542654
function_ref<SymbolTable *()> getSymbolTableCallback,
2655-
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
2656-
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
2657-
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
2658-
function_ref<LogicalResult(StringRef)> isaCallback)
2655+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
2656+
initialLlvmIRCallback,
2657+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
2658+
linkedLlvmIRCallback,
2659+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
2660+
optimizedLlvmIRCallback,
2661+
function_ref<LogicalResult(Operation *op, StringRef)> isaCallback)
26592662
: TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
26602663
cmdOptions, elfSection, compilationTarget,
26612664
getSymbolTableCallback, initialLlvmIRCallback,
@@ -2667,10 +2670,13 @@ TargetOptions::TargetOptions(
26672670
StringRef cmdOptions, StringRef elfSection,
26682671
CompilationTarget compilationTarget,
26692672
function_ref<SymbolTable *()> getSymbolTableCallback,
2670-
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
2671-
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
2672-
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
2673-
function_ref<LogicalResult(StringRef)> isaCallback)
2673+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
2674+
initialLlvmIRCallback,
2675+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
2676+
linkedLlvmIRCallback,
2677+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
2678+
optimizedLlvmIRCallback,
2679+
function_ref<LogicalResult(Operation *op, StringRef)> isaCallback)
26742680
: toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
26752681
cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
26762682
compilationTarget(compilationTarget),
@@ -2696,22 +2702,23 @@ SymbolTable *TargetOptions::getSymbolTable() const {
26962702
return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
26972703
}
26982704

2699-
function_ref<LogicalResult(llvm::Module &)>
2705+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
27002706
TargetOptions::getInitialLlvmIRCallback() const {
27012707
return initialLlvmIRCallback;
27022708
}
27032709

2704-
function_ref<LogicalResult(llvm::Module &)>
2710+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
27052711
TargetOptions::getLinkedLlvmIRCallback() const {
27062712
return linkedLlvmIRCallback;
27072713
}
27082714

2709-
function_ref<LogicalResult(llvm::Module &)>
2715+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
27102716
TargetOptions::getOptimizedLlvmIRCallback() const {
27112717
return optimizedLlvmIRCallback;
27122718
}
27132719

2714-
function_ref<LogicalResult(StringRef)> TargetOptions::getISACallback() const {
2720+
function_ref<LogicalResult(Operation *op, StringRef)>
2721+
TargetOptions::getISACallback() const {
27152722
return isaCallback;
27162723
}
27172724

mlir/lib/Target/LLVM/ModuleToObject.cpp

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,13 @@ using namespace mlir::LLVM;
3737
ModuleToObject::ModuleToObject(
3838
Operation &module, StringRef triple, StringRef chip, StringRef features,
3939
int optLevel,
40-
function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
41-
function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
42-
function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
43-
function_ref<LogicalResult(StringRef)> isaCallback)
40+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
41+
initialLlvmIRCallback,
42+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
43+
linkedLlvmIRCallback,
44+
function_ref<LogicalResult(Operation *op, llvm::Module &)>
45+
optimizedLlvmIRCallback,
46+
function_ref<LogicalResult(Operation *op, StringRef)> isaCallback)
4447
: module(module), triple(triple), chip(chip), features(features),
4548
optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
4649
linkedLlvmIRCallback(linkedLlvmIRCallback),
@@ -255,12 +258,9 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
255258
}
256259
setDataLayoutAndTriple(*llvmModule);
257260

258-
if (initialLlvmIRCallback) {
259-
if (failed(initialLlvmIRCallback(*llvmModule))) {
260-
getOperation().emitError() << "InitialLLVMIRCallback failed.";
261+
if (initialLlvmIRCallback)
262+
if (failed(initialLlvmIRCallback(&getOperation(), *llvmModule)))
261263
return std::nullopt;
262-
}
263-
}
264264

265265
// Link bitcode files.
266266
handleModulePreLink(*llvmModule);
@@ -274,23 +274,17 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
274274
handleModulePostLink(*llvmModule);
275275
}
276276

277-
if (linkedLlvmIRCallback) {
278-
if (failed(linkedLlvmIRCallback(*llvmModule))) {
279-
getOperation().emitError() << "LinkedLLVMIRCallback failed.";
277+
if (linkedLlvmIRCallback)
278+
if (failed(linkedLlvmIRCallback(&getOperation(), *llvmModule)))
280279
return std::nullopt;
281-
}
282-
}
283280

284281
// Optimize the module.
285282
if (failed(optimizeModule(*llvmModule, optLevel)))
286283
return std::nullopt;
287284

288-
if (optimizedLlvmIRCallback) {
289-
if (failed(optimizedLlvmIRCallback(*llvmModule))) {
290-
getOperation().emitError() << "OptimizedLLVMIRCallback failed.";
285+
if (optimizedLlvmIRCallback)
286+
if (failed(optimizedLlvmIRCallback(&getOperation(), *llvmModule)))
291287
return std::nullopt;
292-
}
293-
}
294288

295289
// Return the serialized object.
296290
return moduleToObject(*llvmModule);

mlir/lib/Target/LLVM/NVVM/Target.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -707,12 +707,9 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
707707
return std::nullopt;
708708
}
709709

710-
if (isaCallback) {
711-
if (failed(isaCallback(serializedISA.value()))) {
712-
getOperation().emitError() << "ISACallback failed.";
710+
if (isaCallback)
711+
if (failed(isaCallback(getOperation(), serializedISA.value())))
713712
return std::nullopt;
714-
}
715-
}
716713

717714
#define DEBUG_TYPE "serialize-to-isa"
718715
LDBG() << "PTX for module: " << getOperation().getNameAttr() << "\n"

mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -177,29 +177,33 @@ TEST_F(MLIRTargetLLVMNVVM,
177177

178178
std::string initialLLVMIR;
179179
auto initialCallback =
180-
[&initialLLVMIR](llvm::Module &module) -> LogicalResult {
180+
[&initialLLVMIR](Operation * /*op*/,
181+
llvm::Module &module) -> LogicalResult {
181182
llvm::raw_string_ostream ros(initialLLVMIR);
182183
module.print(ros, nullptr);
183184
return success();
184185
};
185186

186187
std::string linkedLLVMIR;
187-
auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
188+
auto linkedCallback = [&linkedLLVMIR](Operation * /*op*/,
189+
llvm::Module &module) -> LogicalResult {
188190
llvm::raw_string_ostream ros(linkedLLVMIR);
189191
module.print(ros, nullptr);
190192
return success();
191193
};
192194

193195
std::string optimizedLLVMIR;
194196
auto optimizedCallback =
195-
[&optimizedLLVMIR](llvm::Module &module) -> LogicalResult {
197+
[&optimizedLLVMIR](Operation * /*op*/,
198+
llvm::Module &module) -> LogicalResult {
196199
llvm::raw_string_ostream ros(optimizedLLVMIR);
197200
module.print(ros, nullptr);
198201
return success();
199202
};
200203

201204
std::string isaResult;
202-
auto isaCallback = [&isaResult](llvm::StringRef isa) -> LogicalResult {
205+
auto isaCallback = [&isaResult](Operation * /*op*/,
206+
llvm::StringRef isa) -> LogicalResult {
203207
isaResult = isa.str();
204208
return success();
205209
};
@@ -239,7 +243,8 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(CallbackFailedWithISA)) {
239243
auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
240244
ASSERT_TRUE(!!serializer);
241245

242-
auto isaCallback = [](llvm::StringRef /*isa*/) -> LogicalResult {
246+
auto isaCallback = [](Operation * /*op*/,
247+
llvm::StringRef /*isa*/) -> LogicalResult {
243248
return failure();
244249
};
245250

@@ -295,7 +300,8 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {
295300

296301
// Hook to intercept the LLVM IR after linking external libs.
297302
std::string linkedLLVMIR;
298-
auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
303+
auto linkedCallback = [&linkedLLVMIR](Operation * /*op*/,
304+
llvm::Module &module) -> LogicalResult {
299305
llvm::raw_string_ostream ros(linkedLLVMIR);
300306
module.print(ros, nullptr);
301307
return success();

mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithInitialLLVMIR)) {
169169

170170
std::string initialLLVMIR;
171171
auto initialCallback =
172-
[&initialLLVMIR](llvm::Module &module) -> LogicalResult {
172+
[&initialLLVMIR](Operation * /*op*/,
173+
llvm::Module &module) -> LogicalResult {
173174
llvm::raw_string_ostream ros(initialLLVMIR);
174175
module.print(ros, nullptr);
175176
return success();
@@ -198,7 +199,8 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithLinkedLLVMIR)) {
198199
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
199200

200201
std::string linkedLLVMIR;
201-
auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
202+
auto linkedCallback = [&linkedLLVMIR](Operation * /*op*/,
203+
llvm::Module &module) -> LogicalResult {
202204
llvm::raw_string_ostream ros(linkedLLVMIR);
203205
module.print(ros, nullptr);
204206
return success();
@@ -229,7 +231,8 @@ TEST_F(MLIRTargetLLVM,
229231

230232
std::string optimizedLLVMIR;
231233
auto optimizedCallback =
232-
[&optimizedLLVMIR](llvm::Module &module) -> LogicalResult {
234+
[&optimizedLLVMIR](Operation * /*op*/,
235+
llvm::Module &module) -> LogicalResult {
233236
llvm::raw_string_ostream ros(optimizedLLVMIR);
234237
module.print(ros, nullptr);
235238
return success();
@@ -257,7 +260,8 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithInitialLLVMIR)) {
257260
IntegerAttr target = builder.getI32IntegerAttr(0);
258261
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
259262

260-
auto initialCallback = [](llvm::Module & /*module*/) -> LogicalResult {
263+
auto initialCallback = [](Operation * /*op*/,
264+
llvm::Module & /*module*/) -> LogicalResult {
261265
return failure();
262266
};
263267

@@ -281,7 +285,8 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithLinkedLLVMIR)) {
281285
IntegerAttr target = builder.getI32IntegerAttr(0);
282286
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
283287

284-
auto linkedCallback = [](llvm::Module & /*module*/) -> LogicalResult {
288+
auto linkedCallback = [](Operation * /*op*/,
289+
llvm::Module & /*module*/) -> LogicalResult {
285290
return failure();
286291
};
287292

@@ -305,7 +310,8 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithOptimizedLLVMIR)) {
305310
IntegerAttr target = builder.getI32IntegerAttr(0);
306311
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
307312

308-
auto optimizedCallback = [](llvm::Module & /*module*/) -> LogicalResult {
313+
auto optimizedCallback = [](Operation * /*op*/,
314+
llvm::Module & /*module*/) -> LogicalResult {
309315
return failure();
310316
};
311317

0 commit comments

Comments
 (0)