Skip to content

Commit 4ff9063

Browse files
committed
More changes based on review feedback.
More tests.
1 parent 9110630 commit 4ff9063

File tree

3 files changed

+106
-6
lines changed

3 files changed

+106
-6
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -840,7 +840,7 @@ def GPU_LaunchOp : GPU_Op<"launch", [
840840
- a variadic number of Private memory attributions.
841841

842842
The `kernelFunc` and `kernelModule` attributes are optional and specify
843-
the kernel name and a module in whichthe kernel should be outlined.
843+
the kernel name and a module in which the kernel should be outlined.
844844

845845
Syntax:
846846

mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -411,14 +411,12 @@ class GpuKernelOutliningPass
411411
auto *context = getOperation().getContext();
412412
OpBuilder builder(context);
413413
std::string kernelModuleName;
414-
gpu::GPUModuleOp kernelModule = nullptr;
414+
gpu::GPUModuleOp kernelModule;
415415
if (gpuLaunchOp.getKernelModule()) {
416416
kernelModuleName =
417417
gpuLaunchOp.getKernelModule()->getRootReference().str();
418-
if (auto existingModule =
419-
parentSymbolTable.lookup<gpu::GPUModuleOp>(kernelModuleName)) {
420-
kernelModule = existingModule;
421-
}
418+
kernelModule =
419+
parentSymbolTable.lookup<gpu::GPUModuleOp>(kernelModuleName);
422420
} else {
423421
kernelModuleName = kernelFunc.getName();
424422
}

mlir/test/Dialect/GPU/outlining.mlir

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,3 +529,105 @@ func.func @testKernelAttributes() {
529529
} {kernelModule = @test_module, kernelFunc = @test_kernel_func}
530530
return
531531
}
532+
533+
// -----
534+
// This test checks the two optional attributes kernelModule and kernelFunc for gpu.launch when the module named by kernelModule already exists.
535+
536+
// CHECK-LABEL: gpu.module @existing_module
537+
// CHECK: gpu.func @test_kernel_func()
538+
// CHECK: gpu.func @test_kernel_func_0()
539+
// CHECK-NOT: gpu.module @testExistingModule_kernel
540+
// CHECK-NOT: gpu.func @testExistingModule_kernel()
541+
// CHECK: func.func @testExistingModule()
542+
// CHECK: gpu.launch_func @existing_module::@test_kernel_func_0 blocks in (%[[GRID_X:.*]], %[[GRID_Y:.*]], %[[GRID_Z:.*]]) threads in (%[[BLOCK_X:.*]], %[[BLOCK_Y:.*]], %[[BLOCK_Z:.*]])
543+
544+
gpu.module @existing_module {
545+
gpu.func @test_kernel_func() {
546+
gpu.return
547+
}
548+
}
549+
550+
func.func @testExistingModule() {
551+
%gDimX = arith.constant 8 : index
552+
%gDimY = arith.constant 12 : index
553+
%gDimZ = arith.constant 16 : index
554+
%bDimX = arith.constant 32 : index
555+
%bDimY = arith.constant 16 : index
556+
%bDimZ = arith.constant 8 : index
557+
558+
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY, %grid_z = %gDimZ)
559+
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ) {
560+
"some_op"(%bx, %tx) : (index, index) -> ()
561+
gpu.terminator
562+
} {kernelModule = @existing_module, kernelFunc = @test_kernel_func}
563+
return
564+
}
565+
566+
// -----
567+
// This test checks gpu.launch when only the optional attribute kernelModule is specified.
568+
// CHECK-LABEL: func.func @testKernelModuleOnly()
569+
// CHECK: gpu.launch_func @test_module::@testKernelModuleOnly_kernel blocks in (%[[GRID_X:.*]], %[[GRID_Y:.*]], %[[GRID_Z:.*]]) threads in (%[[BLOCK_X:.*]], %[[BLOCK_Y:.*]], %[[BLOCK_Z:.*]])
570+
// CHECK: gpu.module @test_module
571+
// CHECK: gpu.func @testKernelModuleOnly_kernel()
572+
func.func @testKernelModuleOnly() {
573+
%gDimX = arith.constant 8 : index
574+
%gDimY = arith.constant 12 : index
575+
%gDimZ = arith.constant 16 : index
576+
%bDimX = arith.constant 32 : index
577+
%bDimY = arith.constant 16 : index
578+
%bDimZ = arith.constant 8 : index
579+
580+
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY, %grid_z = %gDimZ)
581+
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ) {
582+
"some_op"(%bx, %tx) : (index, index) -> ()
583+
gpu.terminator
584+
} {kernelModule = @test_module}
585+
return
586+
}
587+
588+
// -----
589+
// This test checks gpu.launch when only the optional attribute kernelFunc is specified.
590+
// CHECK-LABEL: func.func @testKernelFuncOnly()
591+
// CHECK: gpu.launch_func @test_kernel_func::@test_kernel_func blocks in (%[[GRID_X:.*]], %[[GRID_Y:.*]], %[[GRID_Z:.*]]) threads in (%[[BLOCK_X:.*]], %[[BLOCK_Y:.*]], %[[BLOCK_Z:.*]])
592+
593+
// CHECK: gpu.module @test_kernel_func
594+
// CHECK: gpu.func @test_kernel_func()
595+
func.func @testKernelFuncOnly() {
596+
%gDimX = arith.constant 8 : index
597+
%gDimY = arith.constant 12 : index
598+
%gDimZ = arith.constant 16 : index
599+
%bDimX = arith.constant 32 : index
600+
%bDimY = arith.constant 16 : index
601+
%bDimZ = arith.constant 8 : index
602+
603+
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY, %grid_z = %gDimZ)
604+
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ) {
605+
"some_op"(%bx, %tx) : (index, index) -> ()
606+
gpu.terminator
607+
} {kernelFunc = @test_kernel_func}
608+
return
609+
}
610+
611+
612+
// -----
613+
// This test tests gpu.launch when optional attributes kernelModule and kernelFunc are not specified.
614+
// CHECK-LABEL: func.func @testNoAttributes()
615+
// CHECK: gpu.launch_func @testNoAttributes_kernel::@testNoAttributes_kernel blocks in (%[[GRID_X:.*]], %[[GRID_Y:.*]], %[[GRID_Z:.*]]) threads in (%[[BLOCK_X:.*]], %[[BLOCK_Y:.*]], %[[BLOCK_Z:.*]])
616+
617+
// CHECK: gpu.module @testNoAttributes_kernel
618+
// CHECK: gpu.func @testNoAttributes_kernel()
619+
func.func @testNoAttributes() {
620+
%gDimX = arith.constant 8 : index
621+
%gDimY = arith.constant 12 : index
622+
%gDimZ = arith.constant 16 : index
623+
%bDimX = arith.constant 32 : index
624+
%bDimY = arith.constant 16 : index
625+
%bDimZ = arith.constant 8 : index
626+
627+
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY, %grid_z = %gDimZ)
628+
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ) {
629+
"some_op"(%bx, %tx) : (index, index) -> ()
630+
gpu.terminator
631+
}
632+
return
633+
}

0 commit comments

Comments
 (0)