Skip to content

Commit 1150e8e

Browse files
authored
[mlir::spirv] Support scf.if in mlir-vulkan-runner (#75367)
1. Register SCFDialect in mlir-vulkan-runner 2. Add SCFToSPIRV in GPUToSPIRVPass to lower scf. Fixes llvm/llvm-project#74939
1 parent 38c9390 commit 1150e8e

File tree

3 files changed

+60
-2
lines changed

3 files changed

+60
-2
lines changed

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
1818
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
1919
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
20+
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
2021
#include "mlir/Dialect/Func/IR/FuncOps.h"
2122
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
2223
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
@@ -126,6 +127,8 @@ void GPUToSPIRVPass::runOnOperation() {
126127

127128
// TODO: Change SPIR-V conversion to be progressive and remove the following
128129
// patterns.
130+
ScfToSPIRVContext scfContext;
131+
populateSCFToSPIRVPatterns(typeConverter, scfContext, patterns);
129132
mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
130133
populateMemRefToSPIRVPatterns(typeConverter, patterns);
131134
populateFuncToSPIRVPatterns(typeConverter, patterns);
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// RUN: mlir-vulkan-runner %s --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
2+
3+
// CHECK: [3.3, 3.3, 3.3, 3.3, 0, 0, 0, 0]
4+
module attributes {
5+
gpu.container_module,
6+
spirv.target_env = #spirv.target_env<
7+
#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
8+
} {
9+
gpu.module @kernels {
10+
gpu.func @kernel_add(%arg0 : memref<8xf32>, %arg1 : memref<8xf32>, %arg2 : memref<8xf32>)
11+
kernel attributes { spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
12+
%0 = gpu.block_id x
13+
%limit = arith.constant 4 : index
14+
%cond = arith.cmpi slt, %0, %limit : index
15+
scf.if %cond {
16+
%1 = memref.load %arg0[%0] : memref<8xf32>
17+
%2 = memref.load %arg1[%0] : memref<8xf32>
18+
%3 = arith.addf %1, %2 : f32
19+
memref.store %3, %arg2[%0] : memref<8xf32>
20+
}
21+
gpu.return
22+
}
23+
}
24+
25+
func.func @main() {
26+
%arg0 = memref.alloc() : memref<8xf32>
27+
%arg1 = memref.alloc() : memref<8xf32>
28+
%arg2 = memref.alloc() : memref<8xf32>
29+
%0 = arith.constant 0 : i32
30+
%1 = arith.constant 1 : i32
31+
%2 = arith.constant 2 : i32
32+
%value0 = arith.constant 0.0 : f32
33+
%value1 = arith.constant 1.1 : f32
34+
%value2 = arith.constant 2.2 : f32
35+
%arg3 = memref.cast %arg0 : memref<8xf32> to memref<?xf32>
36+
%arg4 = memref.cast %arg1 : memref<8xf32> to memref<?xf32>
37+
%arg5 = memref.cast %arg2 : memref<8xf32> to memref<?xf32>
38+
call @fillResource1DFloat(%arg3, %value1) : (memref<?xf32>, f32) -> ()
39+
call @fillResource1DFloat(%arg4, %value2) : (memref<?xf32>, f32) -> ()
40+
call @fillResource1DFloat(%arg5, %value0) : (memref<?xf32>, f32) -> ()
41+
42+
%cst1 = arith.constant 1 : index
43+
%cst8 = arith.constant 8 : index
44+
gpu.launch_func @kernels::@kernel_add
45+
blocks in (%cst8, %cst1, %cst1) threads in (%cst1, %cst1, %cst1)
46+
args(%arg0 : memref<8xf32>, %arg1 : memref<8xf32>, %arg2 : memref<8xf32>)
47+
%arg6 = memref.cast %arg5 : memref<?xf32> to memref<*xf32>
48+
call @printMemrefF32(%arg6) : (memref<*xf32>) -> ()
49+
return
50+
}
51+
func.func private @fillResource1DFloat(%0 : memref<?xf32>, %1 : f32)
52+
func.func private @printMemrefF32(%ptr : memref<*xf32>)
53+
}
54+

mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h"
2828
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2929
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
30+
#include "mlir/Dialect/SCF/IR/SCF.h"
3031
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
3132
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
3233
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
@@ -105,8 +106,8 @@ int main(int argc, char **argv) {
105106
mlir::DialectRegistry registry;
106107
registry.insert<mlir::arith::ArithDialect, mlir::LLVM::LLVMDialect,
107108
mlir::gpu::GPUDialect, mlir::spirv::SPIRVDialect,
108-
mlir::func::FuncDialect, mlir::memref::MemRefDialect,
109-
mlir::vector::VectorDialect>();
109+
mlir::scf::SCFDialect, mlir::func::FuncDialect,
110+
mlir::memref::MemRefDialect, mlir::vector::VectorDialect>();
110111
mlir::registerBuiltinDialectTranslation(registry);
111112
mlir::registerLLVMDialectTranslation(registry);
112113

0 commit comments

Comments
 (0)