@@ -190,6 +190,31 @@ class ConvertWhileOpTypes
190190};
191191} // namespace
192192
193+ namespace {
194+ class ConvertIndexSwitchOpTypes
195+ : public Structural1ToNConversionPattern<IndexSwitchOp,
196+ ConvertIndexSwitchOpTypes> {
197+ public:
198+ using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
199+
200+ std::optional<IndexSwitchOp>
201+ convertSourceOp (IndexSwitchOp op, OneToNOpAdaptor adaptor,
202+ ConversionPatternRewriter &rewriter,
203+ TypeRange dstTypes) const {
204+ auto newOp = rewriter.create <IndexSwitchOp>(
205+ op.getLoc (), dstTypes, op.getArg (), op.getCases (), op.getNumCases ());
206+
207+ for (unsigned i = 0u ; i < op.getNumRegions (); i++) {
208+ if (failed (rewriter.convertRegionTypes (&op.getRegion (i), *typeConverter)))
209+ return std::nullopt ;
210+ auto &dstRegion = newOp.getRegion (i);
211+ rewriter.inlineRegionBefore (op.getRegion (i), dstRegion, dstRegion.end ());
212+ }
213+ return newOp;
214+ }
215+ };
216+ } // namespace
217+
193218namespace {
194219// When the result types of a ForOp/IfOp get changed, the operand types of the
195220// corresponding yield op need to be changed. In order to trigger the
@@ -224,19 +249,19 @@ class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
224249void mlir::scf::populateSCFStructuralTypeConversions (
225250 const TypeConverter &typeConverter, RewritePatternSet &patterns) {
226251 patterns.add <ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
227- ConvertWhileOpTypes, ConvertConditionOpTypes>(
228- typeConverter, patterns.getContext ());
252+ ConvertWhileOpTypes, ConvertConditionOpTypes,
253+ ConvertIndexSwitchOpTypes>(typeConverter, patterns.getContext (),
254+ benefit);
229255}
230256
231257void mlir::scf::populateSCFStructuralTypeConversionTarget (
232258 const TypeConverter &typeConverter, ConversionTarget &target) {
233- target.addDynamicallyLegalOp <ForOp, IfOp>([&](Operation *op) {
234- return typeConverter.isLegal (op->getResultTypes ());
235- });
259+ target.addDynamicallyLegalOp <ForOp, IfOp, IndexSwitchOp>(
260+ [&](Operation *op) { return typeConverter.isLegal (op->getResults ()); });
236261 target.addDynamicallyLegalOp <scf::YieldOp>([&](scf::YieldOp op) {
237262 // We only have conversions for a subset of ops that use scf.yield
238263 // terminators.
239- if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp ()))
264+ if (!isa<ForOp, IfOp, WhileOp, IndexSwitchOp >(op->getParentOp ()))
240265 return true ;
241266 return typeConverter.isLegal (op.getOperandTypes ());
242267 });
0 commit comments