@@ -185,6 +185,31 @@ class ConvertWhileOpTypes
185185};
186186} // namespace
187187
188+ namespace {
189+ class ConvertIndexSwitchOpTypes
190+ : public Structural1ToNConversionPattern<IndexSwitchOp,
191+ ConvertIndexSwitchOpTypes> {
192+ public:
193+ using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
194+
195+ std::optional<IndexSwitchOp>
196+ convertSourceOp (IndexSwitchOp op, OneToNOpAdaptor adaptor,
197+ ConversionPatternRewriter &rewriter,
198+ TypeRange dstTypes) const {
199+ auto newOp = rewriter.create <IndexSwitchOp>(
200+ op.getLoc (), dstTypes, op.getArg (), op.getCases (), op.getNumCases ());
201+
202+ for (unsigned i = 0u ; i < op.getNumRegions (); i++) {
203+ if (failed (rewriter.convertRegionTypes (&op.getRegion (i), *typeConverter)))
204+ return std::nullopt ;
205+ auto &dstRegion = newOp.getRegion (i);
206+ rewriter.inlineRegionBefore (op.getRegion (i), dstRegion, dstRegion.end ());
207+ }
208+ return newOp;
209+ }
210+ };
211+ } // namespace
212+
188213namespace {
189214// When the result types of a ForOp/IfOp get changed, the operand types of the
190215// corresponding yield op need to be changed. In order to trigger the
@@ -220,18 +245,19 @@ void mlir::scf::populateSCFStructuralTypeConversions(
220245 const TypeConverter &typeConverter, RewritePatternSet &patterns,
221246 PatternBenefit benefit) {
222247 patterns.add <ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
223- ConvertWhileOpTypes, ConvertConditionOpTypes>(
224- typeConverter, patterns.getContext (), benefit);
248+ ConvertWhileOpTypes, ConvertConditionOpTypes,
249+ ConvertIndexSwitchOpTypes>(typeConverter, patterns.getContext (),
250+ benefit);
225251}
226252
227253void mlir::scf::populateSCFStructuralTypeConversionTarget (
228254 const TypeConverter &typeConverter, ConversionTarget &target) {
229- target.addDynamicallyLegalOp <ForOp, IfOp>(
255+ target.addDynamicallyLegalOp <ForOp, IfOp, IndexSwitchOp >(
230256 [&](Operation *op) { return typeConverter.isLegal (op->getResults ()); });
231257 target.addDynamicallyLegalOp <scf::YieldOp>([&](scf::YieldOp op) {
232258 // We only have conversions for a subset of ops that use scf.yield
233259 // terminators.
234- if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp ()))
260+ if (!isa<ForOp, IfOp, WhileOp, IndexSwitchOp >(op->getParentOp ()))
235261 return true ;
236262 return typeConverter.isLegal (op.getOperands ());
237263 });
0 commit comments