From b91b29a4b90ba2965ba7e8ed609c12f23a1e42e8 Mon Sep 17 00:00:00 2001 From: "Misha (M3 MBP)" Date: Fri, 17 Jan 2025 17:51:40 +0100 Subject: [PATCH 1/3] tests --- mlir/test/Target/SPIRV/branch-load.mlir | 35 ++++++++++++++++++++++ mlir/test/Target/SPIRV/ssa.mlir | 40 +++++++++++++++++++++++++ mlir/test/Target/SPIRV/ssa2.mlir | 36 ++++++++++++++++++++++ 3 files changed, 111 insertions(+) create mode 100644 mlir/test/Target/SPIRV/branch-load.mlir create mode 100644 mlir/test/Target/SPIRV/ssa.mlir create mode 100644 mlir/test/Target/SPIRV/ssa2.mlir diff --git a/mlir/test/Target/SPIRV/branch-load.mlir b/mlir/test/Target/SPIRV/branch-load.mlir new file mode 100644 index 0000000000000..2a0c376e9302c --- /dev/null +++ b/mlir/test/Target/SPIRV/branch-load.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s + +// CHECK: spirv.module Logical GLSL450 requires #spirv.vce +spirv.module Logical GLSL450 requires #spirv.vce { + // CHECK: spirv.func @main() "None" + spirv.func @main() "None" { + // CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr, Function> + %0 = spirv.Variable : !spirv.ptr, Function> + spirv.Branch ^bb1 + ^bb1: // pred: ^bb0 + // CHECK: spirv.mlir.selection + spirv.mlir.selection { + // CHECK: %[[COND:.*]] = spirv.Constant true + // CHECK: spirv.BranchConditional %[[COND]] + %true = spirv.Constant true + spirv.BranchConditional %true, ^bb1, ^bb2 + ^bb1: // pred: ^bb0 + // CHECK: %[[CONST:.*]] = spirv.Constant dense<0.000000e+00> : vector<3xf32> + // CHECK: spirv.Store "Function" %[[VAR]], %[[CONST]] : vector<3xf32> + %cst_vec_3xf32 = spirv.Constant dense<0.000000e+00> : vector<3xf32> + spirv.Store "Function" %0, %cst_vec_3xf32 : vector<3xf32> + spirv.Branch ^bb2 + ^bb2: // 2 preds: ^bb0, ^bb1 + spirv.mlir.merge + } + // CHECK: %[[RESULT:.*]] = spirv.Load "Function" %[[VAR]] : vector<3xf32> + // CHECK: spirv.Return + %1 = spirv.Load "Function" %0 : vector<3xf32> + spirv.Return + } + // CHECK: spirv.EntryPoint "Fragment" @main + // CHECK: spirv.ExecutionMode @main "OriginUpperLeft" + spirv.EntryPoint "Fragment" @main + spirv.ExecutionMode @main "OriginUpperLeft" +} diff --git a/mlir/test/Target/SPIRV/ssa.mlir b/mlir/test/Target/SPIRV/ssa.mlir new file mode 100644 index 0000000000000..d3556991ba463 --- /dev/null +++ b/mlir/test/Target/SPIRV/ssa.mlir @@ -0,0 +1,40 @@ +# RUN: split-file %s %t +# RUN: spirv-as --target-env spv1.0 %t/spv.spvasm -o %t.spv +# RUN: mlir-translate --deserialize-spirv %t.spv -o - | FileCheck %s + +// CHECK: module +// CHECK: spirv.func @main +// CHECK: spirv.Variable +// CHECK: spirv.Return +//--- spv.spvasm +; SPIR-V +; Version: 1.0 +; Generator: Khronos SPIR-V Tools Assembler; 0 +; Bound: 20 +; Schema: 0 + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + %void = OpTypeVoid + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %ptr_v3f = OpTypePointer Function %v3float + %fn = OpTypeFunction %void + %float_0 = OpConstant %float 0 + %float_1 = OpConstant %float 1 + %bool = OpTypeBool + %true = OpConstantTrue %bool + %v3_zero = OpConstantComposite %v3float %float_0 %float_0 %float_0 + %main = OpFunction %void None %fn + %entry = OpLabel + %var = OpVariable %ptr_v3f Function + OpSelectionMerge %merge None + OpBranchConditional %true %then %merge + %then = OpLabel + OpStore %var %v3_zero + OpBranch %merge + %merge = OpLabel + %load = OpLoad %v3float %var + OpReturn + OpFunctionEnd \ No newline at end of file diff --git a/mlir/test/Target/SPIRV/ssa2.mlir b/mlir/test/Target/SPIRV/ssa2.mlir new file mode 100644 index 0000000000000..e50d5cb571bbc --- /dev/null +++ b/mlir/test/Target/SPIRV/ssa2.mlir @@ -0,0 +1,36 @@ +# RUN: split-file %s %t +# RUN: spirv-as --target-env spv1.0 %t/spv.spvasm -o %t.spv +# RUN: mlir-translate --deserialize-spirv %t.spv -o - | FileCheck %s +// CHECK: module +// CHECK: spirv.func @main +// CHECK: spirv.Variable +// CHECK: spirv.Return +//--- spv.spvasm +; SPIR-V +; Version: 1.0 +; Generator: Khronos SPIR-V Tools Assembler; 0 +; Bound: 20 +; Schema: 0 + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + %void = OpTypeVoid + %float = OpTypeFloat 32 + %ptr_f = OpTypePointer Function %float + %fn = OpTypeFunction %void + %float_1 = OpConstant %float 1.0 + %bool = OpTypeBool + %true = OpConstantTrue %bool + %main = OpFunction %void None %fn + %entry = OpLabel + %var = OpVariable %ptr_f Function + OpSelectionMerge %merge None + OpBranchConditional %true %then %merge + %then = OpLabel + OpStore %var %float_1 + OpBranch %merge + %merge = OpLabel + %load = OpLoad %float %var + OpReturn + OpFunctionEnd \ No newline at end of file From f61b88e59a45c2a9b74693bc7a76c894fc533e46 Mon Sep 17 00:00:00 2001 From: "Misha (M3 MBP)" Date: Fri, 17 Jan 2025 17:52:16 +0100 Subject: [PATCH 2/3] fix spirv -> mlir translation ssa handling --- .../SPIRV/Deserialization/Deserializer.cpp | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 04469f1933819..ee62b7da66fc2 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -1937,8 +1937,16 @@ LogicalResult ControlFlowStructurizer::structurize() { << "[cf] block " << block << " is a function entry block\n"); } - for (auto &op : *block) + for (auto &op : *block) { + if (auto varOp = dyn_cast(op)) { + if (varOp.getStorageClass() == spirv::StorageClass::Function) { // This prevents %1 variable duplication in composite4anti + // For function-scoped variables, ensure proper mapping but maintain their original location + mapper.map(&op, &op); + continue; + } + } newBlock->push_back(op.clone(mapper)); + } } // Go through all ops and remap the operands. @@ -2006,6 +2014,12 @@ LogicalResult ControlFlowStructurizer::structurize() { // the SelectionOp/LoopOp's region, there is no escape for it: // SelectionOp/LooOp does not support yield values right now. for (auto *block : constructBlocks) { + block->walk([&](spirv::VariableOp varOp) { + if (varOp.getStorageClass() == spirv::StorageClass::Function) { + // Move function variables to the entry block to preserve their lifetime + varOp->moveBefore(&body.front().front()); + } + }); for (Operation &op : *block) if (!op.use_empty()) return op.emitOpError( @@ -2070,6 +2084,12 @@ LogicalResult ControlFlowStructurizer::structurize() { } } + if (auto selectionOp = llvm::dyn_cast(op)) { + selectionOp.walk([&](spirv::VariableOp varOp) { + varOp->moveBefore(&op->getParentRegion()->front().front()); + }); + } + LLVM_DEBUG(logger.startLine() << "[cf] after structurizing construct with header block " << headerBlock << ":\n" From 0dd1f10e78d7729f5ea62073d17ca1eda15ba443 Mon Sep 17 00:00:00 2001 From: "Misha (M3 MBP)" Date: Fri, 17 Jan 2025 17:57:41 +0100 Subject: [PATCH 3/3] cleanup comments --- mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index ee62b7da66fc2..41008a1c73bfe 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -1939,8 +1939,7 @@ LogicalResult ControlFlowStructurizer::structurize() { for (auto &op : *block) { if (auto varOp = dyn_cast(op)) { - if (varOp.getStorageClass() == spirv::StorageClass::Function) { // This prevents %1 variable duplication in composite4anti - // For function-scoped variables, ensure proper mapping but maintain their original location + if (varOp.getStorageClass() == spirv::StorageClass::Function) { mapper.map(&op, &op); continue; } @@ -2016,7 +2015,6 @@ LogicalResult ControlFlowStructurizer::structurize() { for (auto *block : constructBlocks) { block->walk([&](spirv::VariableOp varOp) { if (varOp.getStorageClass() == spirv::StorageClass::Function) { - // Move function variables to the entry block to preserve their lifetime varOp->moveBefore(&body.front().front()); } });