Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1937,8 +1937,15 @@ LogicalResult ControlFlowStructurizer::structurize() {
<< "[cf] block " << block << " is a function entry block\n");
}

for (auto &op : *block)
for (auto &op : *block) {
if (auto varOp = dyn_cast<spirv::VariableOp>(op)) {
if (varOp.getStorageClass() == spirv::StorageClass::Function) {
mapper.map(&op, &op);
continue;
}
}
newBlock->push_back(op.clone(mapper));
}
}

// Go through all ops and remap the operands.
Expand Down Expand Up @@ -2006,6 +2013,11 @@ LogicalResult ControlFlowStructurizer::structurize() {
// the SelectionOp/LoopOp's region, there is no escape for it:
// SelectionOp/LooOp does not support yield values right now.
Copy link
Contributor Author

@mishaobu mishaobu Jan 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixes this(?)

for (auto *block : constructBlocks) {
block->walk([&](spirv::VariableOp varOp) {
if (varOp.getStorageClass() == spirv::StorageClass::Function) {
varOp->moveBefore(&body.front().front());
}
});
for (Operation &op : *block)
if (!op.use_empty())
return op.emitOpError(
Expand Down Expand Up @@ -2070,6 +2082,12 @@ LogicalResult ControlFlowStructurizer::structurize() {
}
}

if (auto selectionOp = llvm::dyn_cast<spirv::SelectionOp>(op)) {
selectionOp.walk([&](spirv::VariableOp varOp) {
varOp->moveBefore(&op->getParentRegion()->front().front());
});
}

LLVM_DEBUG(logger.startLine()
<< "[cf] after structurizing construct with header block "
<< headerBlock << ":\n"
Expand Down
35 changes: 35 additions & 0 deletions mlir/test/Target/SPIRV/branch-load.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s

// CHECK: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []>
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK: spirv.func @main() "None"
spirv.func @main() "None" {
// CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr<vector<3xf32>, Function>
%0 = spirv.Variable : !spirv.ptr<vector<3xf32>, Function>
spirv.Branch ^bb1
^bb1: // pred: ^bb0
// CHECK: spirv.mlir.selection
spirv.mlir.selection {
// CHECK: %[[COND:.*]] = spirv.Constant true
// CHECK: spirv.BranchConditional %[[COND]]
%true = spirv.Constant true
spirv.BranchConditional %true, ^bb1, ^bb2
^bb1: // pred: ^bb0
// CHECK: %[[CONST:.*]] = spirv.Constant dense<0.000000e+00> : vector<3xf32>
// CHECK: spirv.Store "Function" %[[VAR]], %[[CONST]] : vector<3xf32>
%cst_vec_3xf32 = spirv.Constant dense<0.000000e+00> : vector<3xf32>
spirv.Store "Function" %0, %cst_vec_3xf32 : vector<3xf32>
spirv.Branch ^bb2
^bb2: // 2 preds: ^bb0, ^bb1
spirv.mlir.merge
}
// CHECK: %[[RESULT:.*]] = spirv.Load "Function" %[[VAR]] : vector<3xf32>
// CHECK: spirv.Return
%1 = spirv.Load "Function" %0 : vector<3xf32>
spirv.Return
}
// CHECK: spirv.EntryPoint "Fragment" @main
// CHECK: spirv.ExecutionMode @main "OriginUpperLeft"
spirv.EntryPoint "Fragment" @main
spirv.ExecutionMode @main "OriginUpperLeft"
}
40 changes: 40 additions & 0 deletions mlir/test/Target/SPIRV/ssa.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# RUN: split-file %s %t
# RUN: spirv-as --target-env spv1.0 %t/spv.spvasm -o %t.spv
# RUN: mlir-translate --deserialize-spirv %t.spv -o - | FileCheck %s

// CHECK: module
// CHECK: spirv.func @main
// CHECK: spirv.Variable
// CHECK: spirv.Return
//--- spv.spvasm
; SPIR-V
; Version: 1.0
; Generator: Khronos SPIR-V Tools Assembler; 0
; Bound: 20
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main"
OpExecutionMode %main OriginUpperLeft
%void = OpTypeVoid
%float = OpTypeFloat 32
%v3float = OpTypeVector %float 3
%ptr_v3f = OpTypePointer Function %v3float
%fn = OpTypeFunction %void
%float_0 = OpConstant %float 0
%float_1 = OpConstant %float 1
%bool = OpTypeBool
%true = OpConstantTrue %bool
%v3_zero = OpConstantComposite %v3float %float_0 %float_0 %float_0
%main = OpFunction %void None %fn
%entry = OpLabel
%var = OpVariable %ptr_v3f Function
OpSelectionMerge %merge None
OpBranchConditional %true %then %merge
%then = OpLabel
OpStore %var %v3_zero
OpBranch %merge
%merge = OpLabel
%load = OpLoad %v3float %var
OpReturn
OpFunctionEnd
36 changes: 36 additions & 0 deletions mlir/test/Target/SPIRV/ssa2.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# RUN: split-file %s %t
# RUN: spirv-as --target-env spv1.0 %t/spv.spvasm -o %t.spv
# RUN: mlir-translate --deserialize-spirv %t.spv -o - | FileCheck %s
// CHECK: module
// CHECK: spirv.func @main
// CHECK: spirv.Variable
// CHECK: spirv.Return
//--- spv.spvasm
; SPIR-V
; Version: 1.0
; Generator: Khronos SPIR-V Tools Assembler; 0
; Bound: 20
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main"
OpExecutionMode %main OriginUpperLeft
%void = OpTypeVoid
%float = OpTypeFloat 32
%ptr_f = OpTypePointer Function %float
%fn = OpTypeFunction %void
%float_1 = OpConstant %float 1.0
%bool = OpTypeBool
%true = OpConstantTrue %bool
%main = OpFunction %void None %fn
%entry = OpLabel
%var = OpVariable %ptr_f Function
OpSelectionMerge %merge None
OpBranchConditional %true %then %merge
%then = OpLabel
OpStore %var %float_1
OpBranch %merge
%merge = OpLabel
%load = OpLoad %float %var
OpReturn
OpFunctionEnd