Skip to content

Commit 52ad728

Browse files
committed
Add optional attributes kernelFunc and kernelModule to specify the kernel function name or kernel module name.
1 parent 3128e6e commit 52ad728

File tree

2 files changed

+40
-15
lines changed

2 files changed

+40
-15
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,8 @@ def GPU_LaunchOp : GPU_Op<"launch", [
804804
Optional<Index>:$clusterSizeY,
805805
Optional<Index>:$clusterSizeZ,
806806
Optional<I32>:$dynamicSharedMemorySize,
807-
OptionalAttr<SymbolRefAttr>:$outlineModule)>,
807+
OptionalAttr<SymbolRefAttr>:$kernelFunc,
808+
OptionalAttr<SymbolRefAttr>:$kernelModule)>,
808809
Results<(outs Optional<GPU_AsyncToken>:$asyncToken)> {
809810
let summary = "GPU kernel launch operation";
810811

@@ -838,9 +839,8 @@ def GPU_LaunchOp : GPU_Op<"launch", [
838839
- a variadic number of Workgroup memory attributions.
839840
- a variadic number of Private memory attributions.
840841

841-
The `outline_module` attribute is optional and specifies a module in which
842-
the kernel should be outlined. When this attribute is present, the kernel is
843-
outlined into the specified module instead of the default behavior.
842+
The `kernelFunc` and `kernelModule` attributes are optional and specifies the kernel name and a module in whichthe kernel should be outlined.
843+
844844

845845
Syntax:
846846

@@ -1036,9 +1036,14 @@ def GPU_LaunchOp : GPU_Op<"launch", [
10361036
return "workgroup_attributions";
10371037
}
10381038

1039-
/// Checks if the outline_module attribute is present.
1040-
bool hasOutlineModule() {
1041-
return getOutlineModule().has_value();
1039+
/// Checks if the kernel func name attribute is present.
1040+
bool hasKernelFuncName() {
1041+
return getKernelFunc().has_value();
1042+
}
1043+
1044+
/// Checks if the kernel module name attribute is present.
1045+
bool hasKernelModuleName() {
1046+
return getKernelModule().has_value();
10421047
}
10431048
}];
10441049

mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -365,14 +365,14 @@ class GpuKernelOutliningPass
365365
auto funcWalkResult = func.walk([&](gpu::LaunchOp op) {
366366
SetVector<Value> operands;
367367
std::string kernelFnName;
368-
if (auto outlineModuleAttr = op->getAttrOfType<SymbolRefAttr>("outline_module")) {
369-
kernelFnName = outlineModuleAttr.getRootReference().str();
370-
llvm::errs() << "outlined module name = " << kernelFnName << "\n";
368+
if (op.hasKernelFuncName()) {
369+
kernelFnName = op->getAttrOfType<mlir::SymbolRefAttr>("kernelFunc").getRootReference().str();
370+
llvm::errs() << "use provided kernel func name = " << kernelFnName << "\n";
371371
} else {
372372
kernelFnName =
373373
Twine(op->getParentOfType<SymbolOpInterface>().getName(), "_kernel")
374374
.str();
375-
llvm::errs() << "original module name = " << kernelFnName << "\n";
375+
llvm::errs() << "use default kernel func name = " << kernelFnName << "\n";
376376
}
377377

378378
gpu::GPUFuncOp outlinedFunc =
@@ -381,7 +381,7 @@ class GpuKernelOutliningPass
381381
// Create nested module and insert outlinedFunc. The module will
382382
// originally get the same name as the function, but may be renamed on
383383
// insertion into the parent module.
384-
auto kernelModule = createKernelModule(outlinedFunc, symbolTable);
384+
auto kernelModule = createKernelModule(op, outlinedFunc, symbolTable);
385385
symbolTable.insert(kernelModule, insertPt);
386386

387387
// Potentially changes signature, pulling in constants.
@@ -402,16 +402,34 @@ class GpuKernelOutliningPass
402402

403403
private:
404404
/// Returns a gpu.module containing kernelFunc and all callees (recursive).
405-
gpu::GPUModuleOp createKernelModule(gpu::GPUFuncOp kernelFunc,
405+
gpu::GPUModuleOp createKernelModule(gpu::LaunchOp gpuLaunchOp, gpu::GPUFuncOp kernelFunc,
406406
const SymbolTable &parentSymbolTable) {
407407
// TODO: This code cannot use an OpBuilder because it must be inserted into
408408
// a SymbolTable by the caller. SymbolTable needs to be refactored to
409409
// prevent manual building of Ops with symbols in code using SymbolTables
410410
// and then this needs to use the OpBuilder.
411411
auto *context = getOperation().getContext();
412412
OpBuilder builder(context);
413-
auto kernelModule = builder.create<gpu::GPUModuleOp>(kernelFunc.getLoc(),
414-
kernelFunc.getName());
413+
std::string kernelModuleName;
414+
if (gpuLaunchOp.hasKernelModuleName()) {
415+
kernelModuleName = gpuLaunchOp->getAttrOfType<mlir::SymbolRefAttr>("kernelModule").getRootReference().str();
416+
llvm::errs() << "use provided kernel module name = " << kernelModuleName << "\n";
417+
} else {
418+
kernelModuleName = kernelFunc.getName();
419+
llvm::errs() << "use default kernel module name = " << kernelModuleName << "\n";
420+
}
421+
422+
gpu::GPUModuleOp kernelModule;
423+
// Check if the module already exists in the symbol table
424+
if (auto existingModule = parentSymbolTable.lookup<gpu::GPUModuleOp>(kernelModuleName)) {
425+
llvm::errs() << "Reusing existing kernel module: " << kernelModuleName << "\n";
426+
kernelModule = existingModule;
427+
} else {
428+
// If not found, create a new GPU module
429+
llvm::errs() << "Creating new kernel module: " << kernelModuleName << "\n";
430+
kernelModule = builder.create<gpu::GPUModuleOp>(kernelFunc.getLoc(),
431+
kernelModuleName);
432+
}
415433

416434
// If a valid data layout spec was provided, attach it to the kernel module.
417435
// Otherwise, the default data layout will be used.
@@ -439,6 +457,8 @@ class GpuKernelOutliningPass
439457
}
440458
}
441459

460+
//llvm::errs() << "kernelModule:\n" << kernelModule << "\n";
461+
442462
return kernelModule;
443463
}
444464

0 commit comments

Comments
 (0)