Skip to content

Commit 3128e6e

Browse files
committed
Add optional attribute outline_module to gpu.launch
1 parent 4639a9a commit 3128e6e

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -803,7 +803,8 @@ def GPU_LaunchOp : GPU_Op<"launch", [
803803
Optional<Index>:$clusterSizeX,
804804
Optional<Index>:$clusterSizeY,
805805
Optional<Index>:$clusterSizeZ,
806-
Optional<I32>:$dynamicSharedMemorySize)>,
806+
Optional<I32>:$dynamicSharedMemorySize,
807+
OptionalAttr<SymbolRefAttr>:$outlineModule)>,
807808
Results<(outs Optional<GPU_AsyncToken>:$asyncToken)> {
808809
let summary = "GPU kernel launch operation";
809810

@@ -837,6 +838,10 @@ def GPU_LaunchOp : GPU_Op<"launch", [
837838
- a variadic number of Workgroup memory attributions.
838839
- a variadic number of Private memory attributions.
839840

841+
The `outline_module` attribute is optional and specifies a module in which
842+
the kernel should be outlined. When this attribute is present, the kernel is
843+
outlined into the specified module instead of the default behavior.
844+
840845
Syntax:
841846

842847
```
@@ -1030,6 +1035,11 @@ def GPU_LaunchOp : GPU_Op<"launch", [
10301035
static StringRef getNumWorkgroupAttributionsAttrName() {
10311036
return "workgroup_attributions";
10321037
}
1038+
1039+
/// Checks if the outline_module attribute is present.
1040+
bool hasOutlineModule() {
1041+
return getOutlineModule().has_value();
1042+
}
10331043
}];
10341044

10351045
let hasCanonicalizer = 1;

mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,9 +364,16 @@ class GpuKernelOutliningPass
364364
Block::iterator insertPt(func->getNextNode());
365365
auto funcWalkResult = func.walk([&](gpu::LaunchOp op) {
366366
SetVector<Value> operands;
367-
std::string kernelFnName =
368-
Twine(op->getParentOfType<SymbolOpInterface>().getName(), "_kernel")
369-
.str();
367+
std::string kernelFnName;
368+
if (auto outlineModuleAttr = op->getAttrOfType<SymbolRefAttr>("outline_module")) {
369+
kernelFnName = outlineModuleAttr.getRootReference().str();
370+
llvm::errs() << "outlined module name = " << kernelFnName << "\n";
371+
} else {
372+
kernelFnName =
373+
Twine(op->getParentOfType<SymbolOpInterface>().getName(), "_kernel")
374+
.str();
375+
llvm::errs() << "original module name = " << kernelFnName << "\n";
376+
}
370377

371378
gpu::GPUFuncOp outlinedFunc =
372379
outlineKernelFuncImpl(op, kernelFnName, operands);

0 commit comments

Comments
 (0)