Skip to content

Commit 2edee0b

Browse files
authored
[mlir][gpu] Support outlining nested gpu.launch (#152696)
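This PR fixes a crash in `GpuKernelOutliningPass` that occurred when the pass encountered a symbol reference that was not a `FlatSymbolRefAttr` (for example, a nested reference such as `@nested_launch_kernel::@nested_launch_kernel`). The pass now uses the leaf reference of the symbol instead of casting, which enables outlining of nested `gpu.launch` operations. Fixes #149318.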
1 parent 04081ca commit 2edee0b

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -431,8 +431,7 @@ class GpuKernelOutliningPass
431431
if (std::optional<SymbolTable::UseRange> symbolUses =
432432
SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
433433
for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
434-
StringRef symbolName =
435-
cast<FlatSymbolRefAttr>(symbolUse.getSymbolRef()).getValue();
434+
StringAttr symbolName = symbolUse.getSymbolRef().getLeafReference();
436435
if (symbolTable.lookup(symbolName))
437436
continue;
438437

mlir/test/Dialect/GPU/outlining.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,3 +634,29 @@ func.func @testNoAttributes() {
634634
}
635635
return
636636
}
637+
638+
// -----
639+
640+
// This test verifies that nested `gpu.launch` operations are outlined.
641+
642+
// CHECK-LABEL: func.func @nested_launch(
643+
// CHECK-SAME: %[[ARG0:.*]]: index) {
644+
// CHECK: gpu.launch_func @nested_launch_kernel_0::@nested_launch_kernel blocks in (%[[ARG0]], %[[ARG0]], %[[ARG0]]) threads in (%[[ARG0]], %[[ARG0]], %[[ARG0]]) args(%[[ARG0]] : index)
645+
// CHECK: gpu.module @nested_launch_kernel
646+
// CHECK: gpu.func @nested_launch_kernel() kernel
647+
// CHECK: "some_op"
648+
// CHECK: gpu.module @nested_launch_kernel_0
649+
// CHECK: gpu.func @nested_launch_kernel(%[[VAL_0:.*]]: index) kernel
650+
// CHECK: gpu.launch_func @nested_launch_kernel::@nested_launch_kernel blocks in (%[[VAL_0]], %[[VAL_0]], %[[VAL_0]]) threads in (%[[VAL_0]], %[[VAL_0]], %[[VAL_0]])
651+
func.func @nested_launch(%sz : index) {
652+
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
653+
threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
654+
gpu.launch blocks(%bx1, %by1, %bz1) in (%grid_x1 = %sz, %grid_y1 = %sz, %grid_z1 = %sz)
655+
threads(%tx1, %ty1, %tz1) in (%block_x1 = %sz, %block_y1 = %sz, %block_z1 = %sz) {
656+
"some_op"(%bx1, %tx1) : (index, index) -> ()
657+
gpu.terminator
658+
}
659+
gpu.terminator
660+
}
661+
return
662+
}

0 commit comments

Comments
 (0)