Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2263,23 +2263,22 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() {
if (!isa<spirv::BranchConditionalOp>(terminator))
continue;

// Do not split blocks that only contain a conditional branch, i.e., block
// size is <= 1.
if (block->begin() != block->end() &&
std::next(block->begin()) != block->end()) {
// Check if the current header block is a merge block of another construct.
bool splitHeaderMergeBlock = false;
for (const auto &[_, mergeInfo] : blockMergeInfo) {
if (mergeInfo.mergeBlock == block)
splitHeaderMergeBlock = true;
}

// Do not split a block that only contains a conditional branch, unless it
// is also a merge block of another construct - in that case we want to
// split the block. We do not want two constructs to share header / merge
// block.
if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
Block *newBlock = block->splitBlock(terminator);
OpBuilder builder(block, block->end());
builder.create<spirv::BranchOp>(block->getParent()->getLoc(), newBlock);

// If the split block was a merge block of another region we need to
// update the map.
for (auto it = blockMergeInfo.begin(); it != blockMergeInfo.end(); ++it) {
auto &[ignore, mergeInfo] = *it;
if (mergeInfo.mergeBlock == block) {
mergeInfo.mergeBlock = newBlock;
}
}

// After splitting we need to update the map to use the new block as a
// header.
blockMergeInfo.erase(block);
Expand Down
6 changes: 4 additions & 2 deletions mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,10 @@ class Deserializer {
return opBuilder.getStringAttr(attrName);
}

// Move a conditional branch into a separate basic block to avoid sinking
// defs that are required outside a selection region.
/// Move a conditional branch into a separate basic block to avoid unnecessary
/// sinking of defs that may be required outside a selection region. This
/// function also ensures that a single block cannot be a header block of one
/// selection construct and the merge block of another.
LogicalResult splitConditionalBlocks();

//===--------------------------------------------------------------------===//
Expand Down
71 changes: 71 additions & 0 deletions mlir/test/Target/SPIRV/consecutive-selection.spv
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
; RUN: %if spirv-tools %{ spirv-as --target-env spv1.0 %s -o - | mlir-translate --deserialize-spirv - -o - | FileCheck %s %}

; COM: The purpose of this test is to check that in the case where two selections
; COM: regions share a header / merge block, this block is split and the selection
; COM: regions are not incorrectly nested.

; CHECK: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
; CHECK: spirv.func @main() "None" {
; CHECK: spirv.mlir.selection {
; CHECK-NEXT: spirv.BranchConditional {{.*}}, ^[[bb:.+]], ^[[bb:.+]]
; CHECK-NEXT: ^[[bb:.+]]
; CHECK: spirv.Branch ^[[bb:.+]]
; CHECK-NEXT: ^[[bb:.+]]:
; CHECK-NEXT: spirv.mlir.merge
; CHECK-NEXT: }
; CHECK: spirv.mlir.selection {
; CHECK-NEXT: spirv.BranchConditional {{.*}}, ^[[bb:.+]], ^[[bb:.+]]
; CHECK-NEXT: ^[[bb:.+]]
; CHECK: spirv.Branch ^[[bb:.+]]
; CHECK-NEXT: ^[[bb:.+]]:
; CHECK-NEXT: spirv.mlir.merge
; CHECK-NEXT: }
; CHECK: spirv.Return
; CHECK-NEXT: }
; CHECK: }

OpCapability Shader
%2 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main" %colorOut
OpExecutionMode %main OriginUpperLeft
OpDecorate %colorOut Location 0
%void = OpTypeVoid
%4 = OpTypeFunction %void
%float = OpTypeFloat 32
%v4float = OpTypeVector %float 4
%fun_v4float = OpTypePointer Function %v4float
%float_1 = OpConstant %float 1
%float_0 = OpConstant %float 0
%13 = OpConstantComposite %v4float %float_1 %float_0 %float_0 %float_1
%out_v4float = OpTypePointer Output %v4float
%colorOut = OpVariable %out_v4float Output
%uint = OpTypeInt 32 0
%uint_0 = OpConstant %uint 0
%out_float = OpTypePointer Output %float
%bool = OpTypeBool
%25 = OpConstantComposite %v4float %float_1 %float_1 %float_0 %float_1
%main = OpFunction %void None %4
%6 = OpLabel
%color = OpVariable %fun_v4float Function
OpStore %color %13
%19 = OpAccessChain %out_float %colorOut %uint_0
%20 = OpLoad %float %19
%22 = OpFOrdEqual %bool %20 %float_1
OpSelectionMerge %24 None
OpBranchConditional %22 %23 %24
%23 = OpLabel
OpStore %color %25
OpBranch %24
%24 = OpLabel
%30 = OpFOrdEqual %bool %20 %float_1
OpSelectionMerge %32 None
OpBranchConditional %30 %31 %32
%31 = OpLabel
OpStore %color %25
OpBranch %32
%32 = OpLabel
%26 = OpLoad %v4float %color
OpStore %colorOut %26
OpReturn
OpFunctionEnd