diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 04469f1933819..8ebe8d54b041c 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -2158,6 +2158,53 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() { return success(); } +LogicalResult spirv::Deserializer::splitConditionalBlocks() { + // Create a copy, so we can modify keys in the original. + BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo; + for (auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end(); + it != e; ++it) { + auto &[block, mergeInfo] = *it; + + // Skip processing loop regions. For loop regions continueBlock is non-null. + if (mergeInfo.continueBlock) + continue; + + if (!block->mightHaveTerminator()) + continue; + + Operation *terminator = block->getTerminator(); + assert(terminator); + + if (!isa(terminator)) + continue; + + // Do not split blocks that only contain a conditional branch, i.e., block + // size is <= 1. + if (block->begin() != block->end() && + std::next(block->begin()) != block->end()) { + Block *newBlock = block->splitBlock(terminator); + OpBuilder builder(block, block->end()); + builder.create(block->getParent()->getLoc(), newBlock); + + // If the split block was a merge block of another region we need to + // update the map. + for (auto it = blockMergeInfo.begin(); it != blockMergeInfo.end(); ++it) { + auto &[ignore, mergeInfo] = *it; + if (mergeInfo.mergeBlock == block) { + mergeInfo.mergeBlock = newBlock; + } + } + + // After splitting we need to update the map to use the new block as a + // header. + blockMergeInfo.erase(block); + blockMergeInfo.try_emplace(newBlock, mergeInfo); + } + } + + return success(); +} + LogicalResult spirv::Deserializer::structurizeControlFlow() { LLVM_DEBUG({ logger.startLine() @@ -2165,6 +2212,18 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() { logger.indent(); }); + LLVM_DEBUG({ + logger.startLine() << "[cf] split conditional blocks\n"; + logger.startLine() << "\n"; + }); + + if (failed(splitConditionalBlocks())) { + return failure(); + } + + // TODO: This loop is non-deterministic. Iteration order may vary between runs + // for the same shader as the key to the map is a pointer. See: + // https://github.com/llvm/llvm-project/issues/128547 while (!blockMergeInfo.empty()) { Block *headerBlock = blockMergeInfo.begin()->first; BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second; diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h index 264d580c40f09..8dd35aa876726 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -246,6 +246,10 @@ class Deserializer { return opBuilder.getStringAttr(attrName); } + // Move a conditional branch into a separate basic block to avoid sinking + // defs that are required outside a selection region. + LogicalResult splitConditionalBlocks(); + //===--------------------------------------------------------------------===// // Type //===--------------------------------------------------------------------===// diff --git a/mlir/test/Target/SPIRV/selection.mlir b/mlir/test/Target/SPIRV/selection.mlir index f1d35d74dba15..24abb12998d06 100644 --- a/mlir/test/Target/SPIRV/selection.mlir +++ b/mlir/test/Target/SPIRV/selection.mlir @@ -105,6 +105,8 @@ spirv.module Logical GLSL450 requires #spirv.vce { %var = spirv.Variable : !spirv.ptr // CHECK-NEXT: spirv.Branch ^[[BB:.+]] // CHECK-NEXT: ^[[BB]]: +// CHECK: spirv.Branch ^[[BB:.+]] +// CHECK-NEXT: ^[[BB]]: // CHECK-NEXT: spirv.mlir.selection { spirv.mlir.selection { diff --git a/mlir/test/Target/SPIRV/selection.spv b/mlir/test/Target/SPIRV/selection.spv new file mode 100644 index 0000000000000..9642d0a44fb59 --- /dev/null +++ b/mlir/test/Target/SPIRV/selection.spv @@ -0,0 +1,60 @@ +; RUN: %if spirv-tools %{ spirv-as --target-env spv1.0 %s -o - | mlir-translate --deserialize-spirv - -o - | FileCheck %s %} + +; COM: The purpose of this test is to check that a variable (in this case %color) that +; COM: is defined before a selection region and used both in the selection region and +; COM: after it, is not sunk into that selection region by the deserializer. If the +; COM: variable is sunk, then it cannot be accessed outside the region and causes +; COM: control-flow structurization to fail. + +; CHECK: spirv.module Logical GLSL450 requires #spirv.vce { +; CHECK: spirv.func @main() "None" { +; CHECK: spirv.Variable : !spirv.ptr, Function> +; CHECK: spirv.mlir.selection { +; CHECK-NEXT: spirv.BranchConditional {{.*}}, ^[[bb:.+]], ^[[bb:.+]] +; CHECK-NEXT: ^[[bb:.+]] +; CHECK: spirv.Branch ^[[bb:.+]] +; CHECK-NEXT: ^[[bb:.+]]: +; CHECK-NEXT: spirv.mlir.merge +; CHECK-NEXT: } +; CHECK: spirv.Return +; CHECK-NEXT: } +; CHECK: } + + OpCapability Shader + %2 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %colorOut + OpExecutionMode %main OriginUpperLeft + OpDecorate %colorOut Location 0 + %void = OpTypeVoid + %4 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%fun_v4float = OpTypePointer Function %v4float + %float_1 = OpConstant %float 1 + %float_0 = OpConstant %float 0 + %13 = OpConstantComposite %v4float %float_1 %float_0 %float_0 %float_1 +%out_v4float = OpTypePointer Output %v4float + %colorOut = OpVariable %out_v4float Output + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %out_float = OpTypePointer Output %float + %bool = OpTypeBool + %25 = OpConstantComposite %v4float %float_1 %float_1 %float_0 %float_1 + %main = OpFunction %void None %4 + %6 = OpLabel + %color = OpVariable %fun_v4float Function + OpStore %color %13 + %19 = OpAccessChain %out_float %colorOut %uint_0 + %20 = OpLoad %float %19 + %22 = OpFOrdEqual %bool %20 %float_1 + OpSelectionMerge %24 None + OpBranchConditional %22 %23 %24 + %23 = OpLabel + OpStore %color %25 + OpBranch %24 + %24 = OpLabel + %26 = OpLoad %v4float %color + OpStore %colorOut %26 + OpReturn + OpFunctionEnd diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py index 32b2f8b53d5fa..8578c76969e74 100644 --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -43,6 +43,7 @@ ".test", ".pdll", ".c", + ".spv", ] # test_source_root: The root path where tests are located.