From 45621c70d913766691f87878d2ccb02e3ccf27b0 Mon Sep 17 00:00:00 2001 From: Igor Wodiany Date: Mon, 25 Nov 2024 16:41:03 +0000 Subject: [PATCH] [mlir][spirv] Split conditional basic blocks during deserialization With the current design some of the values are sank into a selection region, despite them being also used outside that region. This is because the current deserializer logic sinks the entire basic block containing a conditional branch forming a header of a selection construct, without accounting for some values being used outside. This manifests as (for example): ``` :0: error: 'spirv.Variable' op failed control flow structurization: it has uses outside of the enclosing selection/loop construct :0: note: see current operation: %0 = "spirv.Variable"()<{storage_class = #spirv.storage_class}> : () -> !spirv.ptr, Function> ``` The proposed solution to this problem is to split the conditional basic block into two, one block containing just the conditional branch, and other the rest of instructions. By doing this, the logic that structures selection regions, only sinks the comparison, keeping the rest of instructions outside the selection region. A SPIR-V test is required, as the problem can happen only during deserialization and cannot be tested with `--test-spirv-roundtrip`. An MLIR test exhibiting the problematic behaviour would be an incorrect MLIR in the first place. This solution is proposed as an alternative to an unfinished PR #123371, that is unlikely to be merged in the foreseeable future, as the author "stepped away from this for a time being". There is also a Discourse thread: https://discourse.llvm.org/t/spir-v-uses-outside-the-selection-region/84494 that tried to solicit some feedback on the topic. --- .../SPIRV/Deserialization/Deserializer.cpp | 59 ++++++++++++++++++ .../SPIRV/Deserialization/Deserializer.h | 4 ++ mlir/test/Target/SPIRV/selection.mlir | 2 + mlir/test/Target/SPIRV/selection.spv | 60 +++++++++++++++++++ mlir/test/lit.cfg.py | 1 + 5 files changed, 126 insertions(+) create mode 100644 mlir/test/Target/SPIRV/selection.spv diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 04469f1933819..8ebe8d54b041c 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -2158,6 +2158,53 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() { return success(); } +LogicalResult spirv::Deserializer::splitConditionalBlocks() { + // Create a copy, so we can modify keys in the original. + BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo; + for (auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end(); + it != e; ++it) { + auto &[block, mergeInfo] = *it; + + // Skip processing loop regions. For loop regions continueBlock is non-null. + if (mergeInfo.continueBlock) + continue; + + if (!block->mightHaveTerminator()) + continue; + + Operation *terminator = block->getTerminator(); + assert(terminator); + + if (!isa(terminator)) + continue; + + // Do not split blocks that only contain a conditional branch, i.e., block + // size is <= 1. + if (block->begin() != block->end() && + std::next(block->begin()) != block->end()) { + Block *newBlock = block->splitBlock(terminator); + OpBuilder builder(block, block->end()); + builder.create(block->getParent()->getLoc(), newBlock); + + // If the split block was a merge block of another region we need to + // update the map. + for (auto it = blockMergeInfo.begin(); it != blockMergeInfo.end(); ++it) { + auto &[ignore, mergeInfo] = *it; + if (mergeInfo.mergeBlock == block) { + mergeInfo.mergeBlock = newBlock; + } + } + + // After splitting we need to update the map to use the new block as a + // header. + blockMergeInfo.erase(block); + blockMergeInfo.try_emplace(newBlock, mergeInfo); + } + } + + return success(); +} + LogicalResult spirv::Deserializer::structurizeControlFlow() { LLVM_DEBUG({ logger.startLine() @@ -2165,6 +2212,18 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() { logger.indent(); }); + LLVM_DEBUG({ + logger.startLine() << "[cf] split conditional blocks\n"; + logger.startLine() << "\n"; + }); + + if (failed(splitConditionalBlocks())) { + return failure(); + } + + // TODO: This loop is non-deterministic. Iteration order may vary between runs + // for the same shader as the key to the map is a pointer. See: + // https://github.com/llvm/llvm-project/issues/128547 while (!blockMergeInfo.empty()) { Block *headerBlock = blockMergeInfo.begin()->first; BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second; diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h index 264d580c40f09..8dd35aa876726 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -246,6 +246,10 @@ class Deserializer { return opBuilder.getStringAttr(attrName); } + // Move a conditional branch into a separate basic block to avoid sinking + // defs that are required outside a selection region. + LogicalResult splitConditionalBlocks(); + //===--------------------------------------------------------------------===// // Type //===--------------------------------------------------------------------===// diff --git a/mlir/test/Target/SPIRV/selection.mlir b/mlir/test/Target/SPIRV/selection.mlir index f1d35d74dba15..24abb12998d06 100644 --- a/mlir/test/Target/SPIRV/selection.mlir +++ b/mlir/test/Target/SPIRV/selection.mlir @@ -105,6 +105,8 @@ spirv.module Logical GLSL450 requires #spirv.vce { %var = spirv.Variable : !spirv.ptr // CHECK-NEXT: spirv.Branch ^[[BB:.+]] // CHECK-NEXT: ^[[BB]]: +// CHECK: spirv.Branch ^[[BB:.+]] +// CHECK-NEXT: ^[[BB]]: // CHECK-NEXT: spirv.mlir.selection { spirv.mlir.selection { diff --git a/mlir/test/Target/SPIRV/selection.spv b/mlir/test/Target/SPIRV/selection.spv new file mode 100644 index 0000000000000..9642d0a44fb59 --- /dev/null +++ b/mlir/test/Target/SPIRV/selection.spv @@ -0,0 +1,60 @@ +; RUN: %if spirv-tools %{ spirv-as --target-env spv1.0 %s -o - | mlir-translate --deserialize-spirv - -o - | FileCheck %s %} + +; COM: The purpose of this test is to check that a variable (in this case %color) that +; COM: is defined before a selection region and used both in the selection region and +; COM: after it, is not sunk into that selection region by the deserializer. If the +; COM: variable is sunk, then it cannot be accessed outside the region and causes +; COM: control-flow structurization to fail. + +; CHECK: spirv.module Logical GLSL450 requires #spirv.vce { +; CHECK: spirv.func @main() "None" { +; CHECK: spirv.Variable : !spirv.ptr, Function> +; CHECK: spirv.mlir.selection { +; CHECK-NEXT: spirv.BranchConditional {{.*}}, ^[[bb:.+]], ^[[bb:.+]] +; CHECK-NEXT: ^[[bb:.+]] +; CHECK: spirv.Branch ^[[bb:.+]] +; CHECK-NEXT: ^[[bb:.+]]: +; CHECK-NEXT: spirv.mlir.merge +; CHECK-NEXT: } +; CHECK: spirv.Return +; CHECK-NEXT: } +; CHECK: } + + OpCapability Shader + %2 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %colorOut + OpExecutionMode %main OriginUpperLeft + OpDecorate %colorOut Location 0 + %void = OpTypeVoid + %4 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%fun_v4float = OpTypePointer Function %v4float + %float_1 = OpConstant %float 1 + %float_0 = OpConstant %float 0 + %13 = OpConstantComposite %v4float %float_1 %float_0 %float_0 %float_1 +%out_v4float = OpTypePointer Output %v4float + %colorOut = OpVariable %out_v4float Output + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %out_float = OpTypePointer Output %float + %bool = OpTypeBool + %25 = OpConstantComposite %v4float %float_1 %float_1 %float_0 %float_1 + %main = OpFunction %void None %4 + %6 = OpLabel + %color = OpVariable %fun_v4float Function + OpStore %color %13 + %19 = OpAccessChain %out_float %colorOut %uint_0 + %20 = OpLoad %float %19 + %22 = OpFOrdEqual %bool %20 %float_1 + OpSelectionMerge %24 None + OpBranchConditional %22 %23 %24 + %23 = OpLabel + OpStore %color %25 + OpBranch %24 + %24 = OpLabel + %26 = OpLoad %v4float %color + OpStore %colorOut %26 + OpReturn + OpFunctionEnd diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py index 32b2f8b53d5fa..8578c76969e74 100644 --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -43,6 +43,7 @@ ".test", ".pdll", ".c", + ".spv", ] # test_source_root: The root path where tests are located.