Skip to content

Commit 7d886fa

Browse files
authored
[mlir][gpu] Update attribute definitions in gpu::LaunchOp (#152106)
`gpu::LaunchOp` is updated the following way: - Change the attribute type of kernel function and module from `SymbolRefAttr` to `FlatSymbolRefAttr` to avoid nested symbol references. - Rename variables from camel case (kernelFunc, kernelModule) to lower case (function, module) and update the syntax. - `LaunchOp::build` support passing `module` and `function` attributes.
1 parent ffdaf85 commit 7d886fa

File tree

5 files changed

+98
-24
lines changed

5 files changed

+98
-24
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -804,8 +804,8 @@ def GPU_LaunchOp : GPU_Op<"launch", [
804804
Optional<Index>:$clusterSizeY,
805805
Optional<Index>:$clusterSizeZ,
806806
Optional<I32>:$dynamicSharedMemorySize,
807-
OptionalAttr<SymbolRefAttr>:$kernelFunc,
808-
OptionalAttr<SymbolRefAttr>:$kernelModule)>,
807+
OptionalAttr<FlatSymbolRefAttr>:$module,
808+
OptionalAttr<FlatSymbolRefAttr>:$function)>,
809809
Results<(outs Optional<GPU_AsyncToken>:$asyncToken)> {
810810
let summary = "GPU kernel launch operation";
811811

@@ -839,7 +839,7 @@ def GPU_LaunchOp : GPU_Op<"launch", [
839839
- a variadic number of Workgroup memory attributions.
840840
- a variadic number of Private memory attributions.
841841

842-
The `kernelFunc` and `kernelModule` attributes are optional and specifies
842+
The `function` and `module` attributes are optional and specifies
843843
the kernel name and a module in which the kernel should be outlined.
844844

845845
Syntax:
@@ -850,6 +850,8 @@ def GPU_LaunchOp : GPU_Op<"launch", [
850850
`blocks` `(` ssa-id-list `)` `in` ssa-reassignment
851851
`threads` `(` ssa-id-list `)` `in` ssa-reassignment
852852
(dynamic_shared_memory_size ssa-use)?
853+
(`module(` symbol-ref-id `)`)?
854+
(`function(` symbol-ref-id `)`)?
853855
memory-attribution
854856
region attr-dict?
855857
ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
@@ -907,6 +909,14 @@ def GPU_LaunchOp : GPU_Op<"launch", [
907909
// sizes are immediately usable inside body region.
908910
"some_op"(%cx, %bx, %tx) : (index, index, index) -> ()
909911
}
912+
913+
// Launch with module and function attributes.
914+
gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %0, %sz_by = %1, %sz_bz = %2)
915+
threads(%tx, %ty, %tz) in (%sz_tx = %3, %sz_ty = %4, %sz_tz = %5)
916+
module(@kernel_module) function(@kernel_func) {
917+
"some_op"(%bx, %tx) : (index, index) -> ()
918+
%42 = load %val1[%bx] : memref<?xf32, 1>
919+
}
910920
```
911921

912922
Rationale: using operation/block arguments gives analyses a clear way of
@@ -931,7 +941,9 @@ def GPU_LaunchOp : GPU_Op<"launch", [
931941
CArg<"TypeRange", "{}">:$privateAttributions,
932942
CArg<"Value", "nullptr">:$clusterSizeX,
933943
CArg<"Value", "nullptr">:$clusterSizeY,
934-
CArg<"Value", "nullptr">:$clusterSizeZ)>
944+
CArg<"Value", "nullptr">:$clusterSizeZ,
945+
CArg<"FlatSymbolRefAttr", "nullptr">:$module,
946+
CArg<"FlatSymbolRefAttr", "nullptr">:$function)>,
935947
];
936948

937949
let extraClassDeclaration = [{

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -756,7 +756,8 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
756756
Type asyncTokenType, ValueRange asyncDependencies,
757757
TypeRange workgroupAttributions,
758758
TypeRange privateAttributions, Value clusterSizeX,
759-
Value clusterSizeY, Value clusterSizeZ) {
759+
Value clusterSizeY, Value clusterSizeZ,
760+
FlatSymbolRefAttr module, FlatSymbolRefAttr function) {
760761
OpBuilder::InsertionGuard g(builder);
761762

762763
// Add a WorkGroup attribution attribute. This attribute is required to
@@ -781,6 +782,12 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
781782
if (dynamicSharedMemorySize)
782783
result.addOperands(dynamicSharedMemorySize);
783784

785+
// Add optional module and function attributes.
786+
if (module)
787+
result.addAttribute(getModuleAttrName(result.name), module);
788+
if (function)
789+
result.addAttribute(getFunctionAttrName(result.name), function);
790+
784791
// Create a kernel body region with kNumConfigRegionAttributes + N memory
785792
// attributions, where the first kNumConfigRegionAttributes arguments have
786793
// `index` type and the rest have the same types as the data operands.
@@ -944,6 +951,21 @@ void LaunchOp::print(OpAsmPrinter &p) {
944951
p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
945952
<< getDynamicSharedMemorySize();
946953

954+
// Print optional module attribute.
955+
StringRef moduleAttrName = getModuleAttrName();
956+
if (auto module = getModule()) {
957+
p << ' ' << moduleAttrName << '(';
958+
p.printSymbolName(*module);
959+
p << ')';
960+
}
961+
// Print optional function attribute.
962+
StringRef functionAttrName = getFunctionAttrName();
963+
if (auto function = getFunction()) {
964+
p << ' ' << functionAttrName << '(';
965+
p.printSymbolName(*function);
966+
p << ')';
967+
}
968+
947969
printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
948970
printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
949971

@@ -952,7 +974,8 @@ void LaunchOp::print(OpAsmPrinter &p) {
952974
p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
953975
p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
954976
LaunchOp::getOperandSegmentSizeAttr(),
955-
getNumWorkgroupAttributionsAttrName()});
977+
getNumWorkgroupAttributionsAttrName(),
978+
moduleAttrName, functionAttrName});
956979
}
957980

958981
// Parse the size assignment blocks for blocks and threads. These have the form
@@ -990,6 +1013,9 @@ parseSizeAssignment(OpAsmParser &parser,
9901013
/// `clusters` `(` ssa-id-list `)` `in` ssa-reassignment (Optional)
9911014
/// `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
9921015
/// `threads` `(` ssa-id-list `)` `in` ssa-reassignment
1016+
/// (`dynamic_shared_memory_size` ssa-use)?
1017+
/// (`module(` symbol-ref-id `)`)?
1018+
/// (`function(` symbol-ref-id `)`)?
9931019
/// memory-attribution
9941020
/// region attr-dict?
9951021
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
@@ -1060,6 +1086,27 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
10601086
return failure();
10611087
}
10621088

1089+
// Parse optional module attribute.
1090+
StringRef moduleAttrName = getModuleAttrName(result.name);
1091+
if (succeeded(parser.parseOptionalKeyword(moduleAttrName))) {
1092+
FlatSymbolRefAttr moduleSymbol;
1093+
if (parser.parseLParen() ||
1094+
parser.parseAttribute(moduleSymbol, Type(), moduleAttrName,
1095+
result.attributes) ||
1096+
parser.parseRParen())
1097+
return failure();
1098+
}
1099+
// Parse optional function attribute.
1100+
StringRef functionAttrName = getFunctionAttrName(result.name);
1101+
if (succeeded(parser.parseOptionalKeyword(functionAttrName))) {
1102+
FlatSymbolRefAttr funcSymbol;
1103+
if (parser.parseLParen() ||
1104+
parser.parseAttribute(funcSymbol, Type(), functionAttrName,
1105+
result.attributes) ||
1106+
parser.parseRParen())
1107+
return failure();
1108+
}
1109+
10631110
// Create the region arguments, it has kNumConfigRegionAttributes arguments
10641111
// that correspond to block/thread identifiers and grid/block sizes, all
10651112
// having `index` type, a variadic number of WorkGroup Attributions and

mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -356,8 +356,8 @@ class GpuKernelOutliningPass
356356
auto funcWalkResult = func.walk([&](gpu::LaunchOp op) {
357357
SetVector<Value> operands;
358358
std::string kernelFnName;
359-
if (op.getKernelFunc()) {
360-
kernelFnName = op.getKernelFunc()->getRootReference().str();
359+
if (op.getFunction()) {
360+
kernelFnName = op.getFunction()->str();
361361
} else {
362362
kernelFnName =
363363
Twine(op->getParentOfType<SymbolOpInterface>().getName(),
@@ -403,9 +403,8 @@ class GpuKernelOutliningPass
403403
OpBuilder builder(context);
404404
std::string kernelModuleName;
405405
gpu::GPUModuleOp kernelModule;
406-
if (gpuLaunchOp.getKernelModule()) {
407-
kernelModuleName =
408-
gpuLaunchOp.getKernelModule()->getRootReference().str();
406+
if (gpuLaunchOp.getModule()) {
407+
kernelModuleName = gpuLaunchOp.getModule()->str();
409408
kernelModule =
410409
parentSymbolTable.lookup<gpu::GPUModuleOp>(kernelModuleName);
411410
} else {

mlir/test/Dialect/GPU/ops.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,18 @@ module attributes {gpu.container_module} {
1717
return
1818
}
1919

20+
// CHECK-LABEL:func @launch_with_module_func_attr(%{{.*}}: index)
21+
func.func @launch_with_module_func_attr(%sz : index) {
22+
// CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) module(@test_module) function(@test_kernel_func)
23+
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
24+
threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz)
25+
module(@test_module) function(@test_kernel_func) {
26+
// CHECK: gpu.terminator
27+
gpu.terminator
28+
}
29+
return
30+
}
31+
2032
// CHECK-LABEL:func @args(%{{.*}}: index, %{{.*}}: index, %{{.*}}: f32, %{{.*}}: memref<?xf32, 1>) {
2133
func.func @args(%blk : index, %thrd : index, %float : f32, %data : memref<?xf32,1>) {
2234
// CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}})

mlir/test/Dialect/GPU/outlining.mlir

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ func.func @launch_cluster() {
509509
// CHECK-NEXT: = memref.load %[[KERNEL_ARG1]][%[[TID]]] : memref<?xf32, 1>
510510

511511
// -----
512-
// This test tests the two optional attributes kernelModule and kernelFunc for gpu.launch
512+
// This test tests the two optional attributes `module` and `function` for gpu.launch
513513
// CHECK-LABEL: func.func @testKernelAttributes()
514514
// CHECK: gpu.launch_func @test_module::@test_kernel_func blocks in (%[[GRID_X:.*]], %[[GRID_Y:.*]], %[[GRID_Z:.*]]) threads in (%[[BLOCK_X:.*]], %[[BLOCK_Y:.*]], %[[BLOCK_Z:.*]])
515515
// CHECK: gpu.module @test_module
@@ -523,15 +523,16 @@ func.func @testKernelAttributes() {
523523
%bDimZ = arith.constant 8 : index
524524

525525
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY, %grid_z = %gDimZ)
526-
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ) {
526+
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ)
527+
module(@test_module) function(@test_kernel_func) {
527528
"some_op"(%bx, %tx) : (index, index) -> ()
528529
gpu.terminator
529-
} {kernelModule = @test_module, kernelFunc = @test_kernel_func}
530+
}
530531
return
531532
}
532533

533534
// -----
534-
// This test tests the two optional attributes kernelModule and kernelFunc for gpu.launch, when kernelModule already exists.
535+
// This test tests the two optional attributes `module` and `function` for gpu.launch, when kernelModule already exists.
535536

536537
// CHECK-LABEL: gpu.module @existing_module
537538
// CHECK: gpu.func @test_kernel_func()
@@ -556,15 +557,16 @@ func.func @testExistingModule() {
556557
%bDimZ = arith.constant 8 : index
557558

558559
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY, %grid_z = %gDimZ)
559-
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ) {
560+
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ)
561+
module(@existing_module) function(@test_kernel_func) {
560562
"some_op"(%bx, %tx) : (index, index) -> ()
561563
gpu.terminator
562-
} {kernelModule = @existing_module, kernelFunc = @test_kernel_func}
564+
}
563565
return
564566
}
565567

566568
// -----
567-
// This test tests the optional attribute kernelModule for gpu.launch.
569+
// This test tests the optional attribute `module` for gpu.launch.
568570
// CHECK-LABEL: func.func @testKernelModuleOnly()
569571
// CHECK: gpu.launch_func @test_module::@testKernelModuleOnly_kernel blocks in (%[[GRID_X:.*]], %[[GRID_Y:.*]], %[[GRID_Z:.*]]) threads in (%[[BLOCK_X:.*]], %[[BLOCK_Y:.*]], %[[BLOCK_Z:.*]])
570572
// CHECK: gpu.module @test_module
@@ -578,15 +580,16 @@ func.func @testKernelModuleOnly() {
578580
%bDimZ = arith.constant 8 : index
579581

580582
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY, %grid_z = %gDimZ)
581-
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ) {
583+
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ)
584+
module(@test_module) {
582585
"some_op"(%bx, %tx) : (index, index) -> ()
583586
gpu.terminator
584-
} {kernelModule = @test_module}
587+
}
585588
return
586589
}
587590

588591
// -----
589-
// This test tests the optional attribute kernelFunc for gpu.launch.
592+
// This test tests the optional attribute `function` for gpu.launch.
590593
// CHECK-LABEL: func.func @testKernelFuncOnly()
591594
// CHECK: gpu.launch_func @test_kernel_func::@test_kernel_func blocks in (%[[GRID_X:.*]], %[[GRID_Y:.*]], %[[GRID_Z:.*]]) threads in (%[[BLOCK_X:.*]], %[[BLOCK_Y:.*]], %[[BLOCK_Z:.*]])
592595

@@ -601,15 +604,16 @@ func.func @testKernelFuncOnly() {
601604
%bDimZ = arith.constant 8 : index
602605

603606
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY, %grid_z = %gDimZ)
604-
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ) {
607+
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ)
608+
function(@test_kernel_func) {
605609
"some_op"(%bx, %tx) : (index, index) -> ()
606610
gpu.terminator
607-
} {kernelFunc = @test_kernel_func}
611+
}
608612
return
609613
}
610614

611615
// -----
612-
// This test tests gpu.launch when optional attributes kernelModule and kernelFunc are not specified.
616+
// This test tests gpu.launch when optional attributes `module` and `function` are not specified.
613617
// CHECK-LABEL: func.func @testNoAttributes()
614618
// CHECK: gpu.launch_func @testNoAttributes_kernel::@testNoAttributes_kernel blocks in (%[[GRID_X:.*]], %[[GRID_Y:.*]], %[[GRID_Z:.*]]) threads in (%[[BLOCK_X:.*]], %[[BLOCK_Y:.*]], %[[BLOCK_Z:.*]])
615619

0 commit comments

Comments
 (0)