@@ -219,6 +219,89 @@ MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
219219 return getArgumentsMutable ();
220220}
221221
222+ // ===----------------------------------------------------------------------===//
223+ // spirv.Switch
224+ // ===----------------------------------------------------------------------===//
225+
226+ void SwitchOp::build (OpBuilder &builder, OperationState &result, Value selector,
227+ Block *defaultTarget, ValueRange defaultOperands,
228+ DenseIntElementsAttr literals, BlockRange targets,
229+ ArrayRef<ValueRange> targetOperands) {
230+ build (builder, result, selector, defaultOperands, targetOperands, literals,
231+ defaultTarget, targets);
232+ }
233+
234+ void SwitchOp::build (OpBuilder &builder, OperationState &result, Value selector,
235+ Block *defaultTarget, ValueRange defaultOperands,
236+ ArrayRef<APInt> literals, BlockRange targets,
237+ ArrayRef<ValueRange> targetOperands) {
238+ DenseIntElementsAttr literalsAttr;
239+ if (!literals.empty ()) {
240+ ShapedType literalType = VectorType::get (
241+ static_cast <int64_t >(literals.size ()), selector.getType ());
242+ literalsAttr = DenseIntElementsAttr::get (literalType, literals);
243+ }
244+ build (builder, result, selector, defaultTarget, defaultOperands, literalsAttr,
245+ targets, targetOperands);
246+ }
247+
248+ void SwitchOp::build (OpBuilder &builder, OperationState &result, Value selector,
249+ Block *defaultTarget, ValueRange defaultOperands,
250+ ArrayRef<int32_t > literals, BlockRange targets,
251+ ArrayRef<ValueRange> targetOperands) {
252+ DenseIntElementsAttr literalsAttr;
253+ if (!literals.empty ()) {
254+ ShapedType literalType = VectorType::get (
255+ static_cast <int64_t >(literals.size ()), selector.getType ());
256+ literalsAttr = DenseIntElementsAttr::get (literalType, literals);
257+ }
258+ build (builder, result, selector, defaultTarget, defaultOperands, literalsAttr,
259+ targets, targetOperands);
260+ }
261+
262+ LogicalResult SwitchOp::verify () {
263+ std::optional<DenseIntElementsAttr> literals = getLiterals ();
264+ BlockRange targets = getTargets ();
265+
266+ if (!literals && targets.empty ())
267+ return success ();
268+
269+ Type selectorType = getSelector ().getType ();
270+ Type literalType = literals->getType ().getElementType ();
271+ if (literalType != selectorType)
272+ return emitOpError () << " 'selector' type (" << selectorType
273+ << " ) should match literals type (" << literalType
274+ << " )" ;
275+
276+ if (literals && literals->size () != static_cast <int64_t >(targets.size ()))
277+ return emitOpError () << " number of literals (" << literals->size ()
278+ << " ) should match number of targets ("
279+ << targets.size () << " )" ;
280+ return success ();
281+ }
282+
283+ SuccessorOperands SwitchOp::getSuccessorOperands (unsigned index) {
284+ assert (index < getNumSuccessors () && " invalid successor index" );
285+ return SuccessorOperands (index == 0 ? getDefaultOperandsMutable ()
286+ : getTargetOperandsMutable (index - 1 ));
287+ }
288+
289+ Block *SwitchOp::getSuccessorForOperands (ArrayRef<Attribute> operands) {
290+ std::optional<DenseIntElementsAttr> literals = getLiterals ();
291+
292+ if (!literals)
293+ return getDefaultTarget ();
294+
295+ SuccessorRange targets = getTargets ();
296+ if (auto value = dyn_cast_or_null<IntegerAttr>(operands.front ())) {
297+ for (auto [index, literal] : llvm::enumerate (literals->getValues <APInt>()))
298+ if (literal == value.getValue ())
299+ return targets[index];
300+ return getDefaultTarget ();
301+ }
302+ return nullptr ;
303+ }
304+
222305// ===----------------------------------------------------------------------===//
223306// spirv.mlir.loop
224307// ===----------------------------------------------------------------------===//
0 commit comments