Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2158,13 +2158,72 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
return success();
}

LogicalResult spirv::Deserializer::splitConditionalBlocks() {
// Create a copy, so we can modify keys in the original.
BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo;
for (auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
it != e; ++it) {
auto &[block, mergeInfo] = *it;

// Skip processing loop regions. For loop regions continueBlock is non-null.
if (mergeInfo.continueBlock)
continue;

if (!block->mightHaveTerminator())
continue;

Operation *terminator = block->getTerminator();
assert(terminator);

if (!isa<spirv::BranchConditionalOp>(terminator))
continue;

// Do not split blocks that only contain a conditional branch, i.e., block
// size is <= 1.
if (block->begin() != block->end() &&
std::next(block->begin()) != block->end()) {
Block *newBlock = block->splitBlock(terminator);
OpBuilder builder(block, block->end());
builder.create<spirv::BranchOp>(block->getParent()->getLoc(), newBlock);

// If the split block was a merge block of another region we need to
// update the map.
for (auto it = blockMergeInfo.begin(); it != blockMergeInfo.end(); ++it) {
auto &[ignore, mergeInfo] = *it;
if (mergeInfo.mergeBlock == block) {
mergeInfo.mergeBlock = newBlock;
}
}

// After splitting we need to update the map to use the new block as a
// header.
blockMergeInfo.erase(block);
blockMergeInfo.try_emplace(newBlock, mergeInfo);
}
}

return success();
}

LogicalResult spirv::Deserializer::structurizeControlFlow() {
LLVM_DEBUG({
logger.startLine()
<< "//----- [cf] start structurizing control flow -----//\n";
logger.indent();
});

LLVM_DEBUG({
logger.startLine() << "[cf] split conditional blocks\n";
logger.startLine() << "\n";
});

if (failed(splitConditionalBlocks())) {
return failure();
}

// TODO: This loop is non-deterministic. Iteration order may vary between runs
// for the same shader as the key to the map is a pointer. See:
// https://github.com/llvm/llvm-project/issues/128547
while (!blockMergeInfo.empty()) {
Block *headerBlock = blockMergeInfo.begin()->first;
BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,10 @@ class Deserializer {
return opBuilder.getStringAttr(attrName);
}

// Move a conditional branch into a separate basic block to avoid sinking
// defs that are required outside a selection region.
LogicalResult splitConditionalBlocks();

//===--------------------------------------------------------------------===//
// Type
//===--------------------------------------------------------------------===//
Expand Down
2 changes: 2 additions & 0 deletions mlir/test/Target/SPIRV/selection.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%var = spirv.Variable : !spirv.ptr<i1, Function>
// CHECK-NEXT: spirv.Branch ^[[BB:.+]]
// CHECK-NEXT: ^[[BB]]:
// CHECK: spirv.Branch ^[[BB:.+]]
// CHECK-NEXT: ^[[BB]]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for following up on this objective! :)

Just a bit concerned about the added overhead here and in mlir/test/Target/SPIRV/loop.mlir. Is this creating extra unused branches?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good point, but I don't think it'll be much of an issue outside running --test-spirv-roundtrip that is only used for testing. The splitting only happens in the deserializer, so it won't affect any lowering that use SPIR-V, etc. Then when it comes to deserialization, it shouldn't be an issue when going directly from SPIR-V generated from outside MLIR (GLSL, etc.). I would expect the code to always have something meaningful to split without introducing superfluous blocks.

Now why do we get extra blocks with the roundtrip? I serialized the MLIR code from the test you commented and got (trimmed):

          %4 = OpLabel
         %12 = OpVariable %_ptr_Function_bool Function
               OpBranch %13
         %13 = OpLabel
               OpSelectionMerge %16 None
               OpBranchConditional %true %14 %15
         %14 = OpLabel
               OpStore %12 %true
               OpBranch %16
         %15 = OpLabel
               OpStore %12 %false
               OpBranch %16
         %16 = OpLabel

After serializing we are getting an extra block (%13) that would be unlikely to be present in non-MLIR generated SPIR-V, as OpSeletionMerge would be part of the predeceasing block. Actually, if you think about this is something my patch do, it isolates OpSeletionMerge and OpBranchConditional, so it shows my approach in deserialization matches how serializer works. Now because the block is already split, deserializng it again does further splitting creating superfluous blocks.

So, yes extra blocks are possible, but I think that would only happen if the input SPIR-V is already split and I don't think that would happen often with the upstream code. But even if it does happen, this only introduces some direct branches, which I believe are easy to optimise somewhere down the line - just collapse blocks together.

Hope that makes sense!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really appreciate the rigor, fwiw my entire workflow relies on doing spirv roundtrip.

Basically I am perf sensitive, and want to do spirv instrumentation & optimization passes at MLIR level before going back to spirv for deployment & this currently works very well.

Tested your PR with my use case, and figured I'd throw out a demo of what happens after roundtrip (easier to read as glsl):

Starting point:

float raymarchwater(vec3 camera, vec3 start, vec3 end, float depth) {
    vec3 pos = start;
    vec3 dir = normalize(end - start);
    for(int i=0; i < 64; i++) {
        float height = getwaves(pos.xz, ITERATIONS_RAYMARCH) * depth - depth;
        if(height + 0.01 > pos.y) {
            return distance(pos, camera);
        }
        pos += dir * (pos.y - height);
    }
    return distance(start, camera);
}

Roundtrip from #123371 [My draft]:

highp float raymarchwater(vec3 _201, vec3 _202, vec3 _203, float _204)
{
    vec3 _206 = _202;
    vec3 _207 = normalize(_203 - _202);
    for (int _208 = 0; _208 < 64; _208++)
    {
        vec2 _210 = _206.xz;
        int _211 = 12;
        float _232 = getwaves(_210, _211);
        float _209 = (_232 * _204) - _204;
        if ((_209 + 0.00999999977648258209228515625) > _206.y)
        {
            return distance(_206, _201);
        }
        _206 += (_207 * (_206.y - _209));
    }
    return distance(_202, _201);
}

Roundtrip from this PR:

highp float raymarchwater(vec3 _202, vec3 _203, vec3 _204, float _205)
{
    vec3 _207 = _203;
    vec3 _208 = normalize(_204 - _203);
    int _209 = 0;
    for (;;)
    {
        if (_209 < 64)
        {
            vec2 _211 = _207.xz;
            int _212 = 12;
            float _232 = getwaves(_211, _212);
            float _210 = (_232 * _205) - _205;
            if ((_210 + 0.00999999977648258209228515625) > _207.y)
            {
                return distance(_207, _202);
            }
            _207 += (_208 * (_207.y - _210));
            _209++;
            continue;
        }
        else
        {
            break;
        }
    }
    return distance(_203, _202);
}

To my naiive understanding, this still seems concerning, but eager to defer to your or @kuhar 's judgement on the matter -- perhaps I need to perform a loop detection pass if your version gets merged(?) or perhaps my perceived worry is actually just insignificant.

Copy link
Contributor Author

@IgWod-IMG IgWod-IMG Feb 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for sharing it. I am a bit confused, because I didn't intend to affect loops, so I started looking into a simpler example and it seems to me that the loop continue block (https://mlir.llvm.org/docs/Dialects/SPIR-V/#loop) gets split into two. It wasn't my intention:

    // Do not split loop headers
    if (auto it = blockMergeInfo.find(block); it != blockMergeInfo.end()) {
      if (it->second.continueBlock) {
        return;
      }
    }

(I incorrectly called it a loop header here)

Let me investigate what's happing first and then I'll come back to you. There is no point in engaging into a deeper discussion when the problem may lie in an incorrect implementation :)

EDIT: Actually, this code may be doing what it intended (as header can have 2 outgoing edges) and the comment is correct, and I may need to handle continue blocks as well. Anyway, I need to re-think what I have done here.


// CHECK-NEXT: spirv.mlir.selection {
spirv.mlir.selection {
Expand Down
60 changes: 60 additions & 0 deletions mlir/test/Target/SPIRV/selection.spv
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
; RUN: %if spirv-tools %{ spirv-as --target-env spv1.0 %s -o - | mlir-translate --deserialize-spirv - -o - | FileCheck %s %}

; COM: The purpose of this test is to check that a variable (in this case %color) that
; COM: is defined before a selection region and used both in the selection region and
; COM: after it, is not sunk into that selection region by the deserializer. If the
; COM: variable is sunk, then it cannot be accessed outside the region and causes
; COM: control-flow structurization to fail.

; CHECK: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
; CHECK: spirv.func @main() "None" {
; CHECK: spirv.Variable : !spirv.ptr<vector<4xf32>, Function>
; CHECK: spirv.mlir.selection {
; CHECK-NEXT: spirv.BranchConditional {{.*}}, ^[[bb:.+]], ^[[bb:.+]]
; CHECK-NEXT: ^[[bb:.+]]
; CHECK: spirv.Branch ^[[bb:.+]]
; CHECK-NEXT: ^[[bb:.+]]:
; CHECK-NEXT: spirv.mlir.merge
; CHECK-NEXT: }
; CHECK: spirv.Return
; CHECK-NEXT: }
; CHECK: }

OpCapability Shader
%2 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main" %colorOut
OpExecutionMode %main OriginUpperLeft
OpDecorate %colorOut Location 0
%void = OpTypeVoid
%4 = OpTypeFunction %void
%float = OpTypeFloat 32
%v4float = OpTypeVector %float 4
%fun_v4float = OpTypePointer Function %v4float
%float_1 = OpConstant %float 1
%float_0 = OpConstant %float 0
%13 = OpConstantComposite %v4float %float_1 %float_0 %float_0 %float_1
%out_v4float = OpTypePointer Output %v4float
%colorOut = OpVariable %out_v4float Output
%uint = OpTypeInt 32 0
%uint_0 = OpConstant %uint 0
%out_float = OpTypePointer Output %float
%bool = OpTypeBool
%25 = OpConstantComposite %v4float %float_1 %float_1 %float_0 %float_1
%main = OpFunction %void None %4
%6 = OpLabel
%color = OpVariable %fun_v4float Function
OpStore %color %13
%19 = OpAccessChain %out_float %colorOut %uint_0
%20 = OpLoad %float %19
%22 = OpFOrdEqual %bool %20 %float_1
OpSelectionMerge %24 None
OpBranchConditional %22 %23 %24
%23 = OpLabel
OpStore %color %25
OpBranch %24
%24 = OpLabel
%26 = OpLoad %v4float %color
OpStore %colorOut %26
OpReturn
OpFunctionEnd
1 change: 1 addition & 0 deletions mlir/test/lit.cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
".test",
".pdll",
".c",
".spv",
]

# test_source_root: The root path where tests are located.
Expand Down