Skip to content

Commit 26c711e

Browse files
authored
set-spirv-api-attrs: Walk and find gpu module. (#1100)
1 parent 6ec8419 commit 26c711e

File tree

3 files changed

+52
-26
lines changed

3 files changed

+52
-26
lines changed

include/imex/Transforms/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def SetSPIRVCapabilities : Pass<"set-spirv-capabilities"> {
6565
];
6666
}
6767

68-
def SetSPIRVAbiAttribute : Pass<"set-spirv-abi-attrs", "::mlir::gpu::GPUModuleOp"> {
68+
def SetSPIRVAbiAttribute : Pass<"set-spirv-abi-attrs"> {
6969
let summary = "Sets Spirv Abi attribute";
7070
let constructor = "imex::createSetSPIRVAbiAttributePass()";
7171
let dependentDialects = ["::mlir::gpu::GPUDialect",

lib/Transforms/SetSPIRVAbiAttribute.cpp

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -53,31 +53,34 @@ class SetSPIRVAbiAttributePass
5353
}
5454

5555
void runOnOperation() override {
56-
auto gpuModule = getOperation();
5756
auto *context = &getContext();
5857
auto attrName =
5958
mlir::StringAttr::get(context, mlir::spirv::getEntryPointABIAttrName());
60-
if (m_clientAPI == "opencl") {
61-
auto abi = mlir::spirv::getEntryPointABIAttr(context);
62-
for (const auto &gpuFunc : gpuModule.getOps<mlir::gpu::GPUFuncOp>()) {
63-
if (!mlir::gpu::GPUDialect::isKernel(gpuFunc) ||
64-
gpuFunc->getAttr(attrName))
65-
continue;
59+
auto op = getOperation();
60+
mlir::StringRef clientAPI = m_clientAPI;
61+
op->walk([&](mlir::gpu::GPUModuleOp gpuModule) {
62+
if (clientAPI == "opencl") {
63+
auto abi = mlir::spirv::getEntryPointABIAttr(context);
64+
for (const auto &gpuFunc : gpuModule.getOps<mlir::gpu::GPUFuncOp>()) {
65+
if (!mlir::gpu::GPUDialect::isKernel(gpuFunc) ||
66+
gpuFunc->getAttr(attrName))
67+
continue;
6668

67-
gpuFunc->setAttr("VectorComputeFunctionINTEL",
68-
mlir::UnitAttr::get(context));
69-
gpuFunc->setAttr(attrName, abi);
70-
}
71-
} else if (m_clientAPI == "vulkan") {
72-
auto abi = mlir::spirv::getEntryPointABIAttr(context, {1, 1, 1});
73-
for (const auto &gpuFunc : gpuModule.getOps<mlir::gpu::GPUFuncOp>()) {
74-
if (!mlir::gpu::GPUDialect::isKernel(gpuFunc) ||
75-
gpuFunc->getAttr(attrName))
76-
continue;
69+
gpuFunc->setAttr("VectorComputeFunctionINTEL",
70+
mlir::UnitAttr::get(context));
71+
gpuFunc->setAttr(attrName, abi);
72+
}
73+
} else if (clientAPI == "vulkan") {
74+
auto abi = mlir::spirv::getEntryPointABIAttr(context, {1, 1, 1});
75+
for (const auto &gpuFunc : gpuModule.getOps<mlir::gpu::GPUFuncOp>()) {
76+
if (!mlir::gpu::GPUDialect::isKernel(gpuFunc) ||
77+
gpuFunc->getAttr(attrName))
78+
continue;
7779

78-
gpuFunc->setAttr(attrName, abi);
80+
gpuFunc->setAttr(attrName, abi);
81+
}
7982
}
80-
}
83+
});
8184
}
8285

8386
private:
Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
// RUN: imex-opt --set-spirv-abi-attrs='client-api=opencl' %s | FileCheck %s --check-prefix=OPENCL
2-
// RUN: imex-opt --set-spirv-abi-attrs='client-api=vulkan' %s | FileCheck %s --check-prefix=VULKAN
1+
// RUN: imex-opt --split-input-file --set-spirv-abi-attrs='client-api=opencl' %s | FileCheck %s --check-prefix=OPENCL
2+
// RUN: imex-opt --split-input-file --set-spirv-abi-attrs='client-api=vulkan' %s | FileCheck %s --check-prefix=VULKAN
33

44
gpu.module @main_kernel {
55
gpu.func @main_kernel(%arg0: memref<8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>) kernel {
6-
7-
// OPENCL: gpu.func @main_kernel(%arg0: memref<8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
8-
// VULKAN: gpu.func @main_kernel(%arg0: memref<8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
9-
6+
// OPENCL-LABEL: gpu.func @main_kernel(%arg0: memref<8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>) kernel attributes
7+
// OPENCL-SAME: {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
8+
// VULKAN-LABEL: gpu.func @main_kernel(%arg0: memref<8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>) kernel attributes
9+
// VULKAN-SAME: {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
1010
cf.br ^bb1
1111
^bb1: // pred: ^bb0
1212
%0 = gpu.block_id x
@@ -17,3 +17,26 @@ gpu.module @main_kernel {
1717
gpu.return
1818
}
1919
}
20+
21+
// -----
22+
23+
module {
24+
module attributes {gpu.container_module} {
25+
func.func @run(%arg0: memref<4096x4096xf16>) -> memref<4096x4096xf16> attributes {llvm.emit_c_interface} {
26+
%c1 = arith.constant 1 : index
27+
gpu.launch_func @run_kernel::@run_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%arg0 : memref<4096x4096xf16>)
28+
return %arg0 : memref<4096x4096xf16>
29+
}
30+
gpu.module @run_kernel {
31+
// OPENCL-LABEL: gpu.func @run_kernel(%arg0: memref<4096x4096xf16>) kernel attributes
32+
// OPENCL-SAME: {VectorComputeFunctionINTEL, known_block_size = array<i32: 1, 1, 1>, known_grid_size = array<i32: 1, 1, 1>,
33+
// OPENCL-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<>} {
34+
// VULKAN-LABEL: gpu.func @run_kernel(%arg0: memref<4096x4096xf16>) kernel attributes
35+
// VULKAN-SAME: {known_block_size = array<i32: 1, 1, 1>, known_grid_size = array<i32: 1, 1, 1>,
36+
// VULKAN-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
37+
gpu.func @run_kernel(%arg0: memref<4096x4096xf16>) kernel attributes {known_block_size = array<i32: 1, 1, 1>, known_grid_size = array<i32: 1, 1, 1>} {
38+
gpu.return
39+
}
40+
}
41+
}
42+
}

0 commit comments

Comments
 (0)