Skip to content

Conversation

@IgWod-IMG
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Nov 20, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Igor Wodiany (IgWod-IMG)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/168899.diff

3 Files Affected:

  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+17)
  • (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+14-1)
  • (modified) mlir/test/Target/SPIRV/selection.mlir (+58)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 252be796488c5..c91e7fc0e748c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2831,6 +2831,23 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
             branchCondOp.getFalseBlock());
 
       branchCondOp.erase();
+    } else if (auto switchOp = dyn_cast<spirv::SwitchOp>(op)) {
+      if (target == switchOp.getDefaultTarget()) {
+        SmallVector<ValueRange> targetOperands(switchOp.getTargetOperands());
+        DenseIntElementsAttr literals =
+            switchOp.getLiterals().value_or(DenseIntElementsAttr());
+        spirv::SwitchOp::create(
+            opBuilder, switchOp.getLoc(), switchOp.getSelector(),
+            switchOp.getDefaultTarget(), blockArgs, literals,
+            switchOp.getTargets(), targetOperands);
+        switchOp.erase();
+      } else {
+        SuccessorRange targets = switchOp.getTargets();
+        auto it = llvm::find(targets, target);
+        assert(it != targets.end());
+        size_t index = std::distance(targets.begin(), it);
+        switchOp.getTargetOperandsMutable(index).assign(blockArgs);
+      }
     } else {
       return emitError(unknownLoc, "unimplemented terminator for Phi creation");
     }
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 4e03a809bd0bc..153e9c770e464 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1443,7 +1443,20 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
         assert(branchCondOp.getFalseTarget() == block);
         blockOperands = branchCondOp.getFalseTargetOperands();
       }
-
+      assert(!blockOperands->empty() &&
+             "expected non-empty block operand range");
+      predecessors.emplace_back(spirvPredecessor, *blockOperands);
+    } else if (auto switchOp = dyn_cast<spirv::SwitchOp>(terminator)) {
+      std::optional<OperandRange> blockOperands;
+      if (block == switchOp.getDefaultTarget()) {
+        blockOperands = switchOp.getDefaultOperands();
+      } else {
+        SuccessorRange targets = switchOp.getTargets();
+        auto it = llvm::find(targets, block);
+        assert(it != targets.end());
+        size_t index = std::distance(targets.begin(), it);
+        blockOperands = switchOp.getTargetOperands(index);
+      }
       assert(!blockOperands->empty() &&
              "expected non-empty block operand range");
       predecessors.emplace_back(spirvPredecessor, *blockOperands);
diff --git a/mlir/test/Target/SPIRV/selection.mlir b/mlir/test/Target/SPIRV/selection.mlir
index 3f762920015aa..d0ad118b01c8a 100644
--- a/mlir/test/Target/SPIRV/selection.mlir
+++ b/mlir/test/Target/SPIRV/selection.mlir
@@ -288,3 +288,61 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
   spirv.EntryPoint "GLCompute" @main
   spirv.ExecutionMode @main "LocalSize", 1, 1, 1
 }
+
+// -----
+
+// Selection with switch and block operands
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.5, [Shader], []> {
+// CHECK-LABEL: @selection_switch_operands
+  spirv.func @selection_switch_operands(%selector : si32) "None" {
+    %cst1 = spirv.Constant 1.000000e+00 : f32
+    %vec0 = spirv.Undef : vector<3xf32>
+// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[0 : i32] : f32 into vector<3xf32>
+    %vec1 = spirv.CompositeInsert %cst1, %vec0[0 : i32] : f32 into vector<3xf32>
+    spirv.Branch ^bb1
+  ^bb1:
+// CHECK: {{%.*}} = spirv.mlir.selection -> vector<3xf32> {
+    %vec4 = spirv.mlir.selection -> vector<3xf32> {
+// CHECK-NEXT: spirv.Switch {{%.*}} : si32, [
+// CHECK-NEXT: default: ^[[DEFAULT:.+]]({{%.*}} : vector<3xf32>),
+// CHECK-NEXT: 0: ^[[CASE0:.+]]({{%.*}} : vector<3xf32>),
+// CHECK-NEXT: 1: ^[[CASE1:.+]]({{%.*}} : vector<3xf32>)
+      spirv.Switch %selector : si32, [
+        default: ^bb3(%vec1 : vector<3xf32>),
+        0: ^bb1(%vec1 : vector<3xf32>),
+        1: ^bb2(%vec1 : vector<3xf32>)
+      ]
+// CHECK: ^[[CASE0]]({{%.*}}: vector<3xf32>)
+    ^bb1(%vecbb1: vector<3xf32>):
+      %cst3 = spirv.Constant 3.000000e+00 : f32
+// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32] : f32 into vector<3xf32>
+      %vec2 = spirv.CompositeInsert %cst3, %vecbb1[1 : i32] : f32 into vector<3xf32>
+// CHECK-NEXT: spirv.Branch ^[[DEFAULT]]({{%.*}} : vector<3xf32>)
+      spirv.Branch ^bb3(%vec2 : vector<3xf32>)
+// CHECK-NEXT: ^[[CASE1]]({{%.*}}: vector<3xf32>)
+    ^bb2(%vecbb2: vector<3xf32>):
+      %cst4 = spirv.Constant 4.000000e+00 : f32
+// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32] : f32 into vector<3xf32>
+      %vec3 = spirv.CompositeInsert %cst4, %vecbb2[1 : i32] : f32 into vector<3xf32>
+// CHECK-NEXT: spirv.Branch ^[[DEFAULT]]({{%.*}} : vector<3xf32>)
+      spirv.Branch ^bb3(%vec3 : vector<3xf32>)
+// CHECK-NEXT: ^[[DEFAULT]]({{%.*}}: vector<3xf32>)
+    ^bb3(%vecbb3: vector<3xf32>):
+// CHECK-NEXT: spirv.mlir.merge {{%.*}} : vector<3xf32>
+      spirv.mlir.merge %vecbb3 : vector<3xf32>
+// CHECK-NEXT: }
+    }
+    %cst2 = spirv.Constant 2.000000e+00 : f32
+// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[2 : i32] : f32 into vector<3xf32>
+    %vec5 = spirv.CompositeInsert %cst2, %vec4[2 : i32] : f32 into vector<3xf32>
+    spirv.Return
+  }
+
+  spirv.func @main() -> () "None" {
+    spirv.Return
+  }
+
+  spirv.EntryPoint "GLCompute" @main
+  spirv.ExecutionMode @main "LocalSize", 1, 1, 1
+}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants