From 64fe18d2e020d4c10c74c0e3d816cdb8c54fe773 Mon Sep 17 00:00:00 2001 From: "Skrebkov, Artemy" Date: Tue, 23 Sep 2025 16:48:05 +0000 Subject: [PATCH 1/3] [mlir][SCF] Add scf.index_switch support for populateSCFStructuralTypeConversionsAndLegality --- .../Transforms/StructuralTypeConversions.cpp | 37 +++++++++++++--- .../SparseTensor/scf_1_N_conversion.mlir | 43 +++++++++++++++++++ 2 files changed, 74 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp index c0589044c26ec..ad62cfc867a27 100644 --- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -190,6 +190,31 @@ class ConvertWhileOpTypes }; } // namespace +namespace { +class ConvertIndexSwitchOpTypes + : public Structural1ToNConversionPattern { +public: + using Structural1ToNConversionPattern::Structural1ToNConversionPattern; + + std::optional + convertSourceOp(IndexSwitchOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + TypeRange dstTypes) const { + auto newOp = rewriter.create( + op.getLoc(), dstTypes, op.getArg(), op.getCases(), op.getNumCases()); + + for (unsigned i = 0u; i < op.getNumRegions(); i++) { + if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter))) + return std::nullopt; + auto &dstRegion = newOp.getRegion(i); + rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); + } + return newOp; + } +}; +} // namespace + namespace { // When the result types of a ForOp/IfOp get changed, the operand types of the // corresponding yield op need to be changed. In order to trigger the @@ -224,19 +249,19 @@ class ConvertConditionOpTypes : public OpConversionPattern { void mlir::scf::populateSCFStructuralTypeConversions( const TypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add( - typeConverter, patterns.getContext()); + ConvertWhileOpTypes, ConvertConditionOpTypes, + ConvertIndexSwitchOpTypes>(typeConverter, patterns.getContext(), + benefit); } void mlir::scf::populateSCFStructuralTypeConversionTarget( const TypeConverter &typeConverter, ConversionTarget &target) { - target.addDynamicallyLegalOp([&](Operation *op) { - return typeConverter.isLegal(op->getResultTypes()); - }); + target.addDynamicallyLegalOp( + [&](Operation *op) { return typeConverter.isLegal(op->getResults()); }); target.addDynamicallyLegalOp([&](scf::YieldOp op) { // We only have conversions for a subset of ops that use scf.yield // terminators. - if (!isa(op->getParentOp())) + if (!isa(op->getParentOp())) return true; return typeConverter.isLegal(op.getOperandTypes()); }); diff --git a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir index f5d6a08b7de31..00f13ed7c8149 100644 --- a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir @@ -86,3 +86,46 @@ func.func @while(%arg0: tensor<1024xf32, #SparseVector>, %c: i1) -> tensor<1024x } return %0: tensor<1024xf32, #SparseVector> } + +// CHECK-LABEL: func.func @index_switch( +// CHECK-SAME: %[[PRED:.*0]]: index, +// CHECK-SAME: %[[VAL_A_1:.*1]]: memref, +// CHECK-SAME: %[[VAL_A_2:.*2]]: memref, +// CHECK-SAME: %[[VAL_A_3:.*3]]: memref, +// CHECK-SAME: %[[VAL_A_4:.*4]]: !sparse_tensor.storage_specifier +// CHECK-SAME: %[[VAL_B_1:.*5]]: memref, +// CHECK-SAME: %[[VAL_B_2:.*6]]: memref, +// CHECK-SAME: %[[VAL_B_3:.*7]]: memref, +// CHECK-SAME: %[[VAL_B_4:.*8]]: !sparse_tensor.storage_specifier +// CHECK-SAME: %[[VAL_C_1:.*9]]: memref, +// CHECK-SAME: %[[VAL_C_2:.*10]]: memref, +// CHECK-SAME: %[[VAL_C_3:.*11]]: memref, +// CHECK-SAME: %[[VAL_C_4:.*12]]: !sparse_tensor.storage_specifier + +// CHECK: %[[RES:.*]]:4 = scf.index_switch %[[PRED]] +// CHECK: case 1 { +// CHECK: scf.yield %[[VAL_A_1]], %[[VAL_A_2]], %[[VAL_A_3]], %[[VAL_A_4]] +// CHECK: case 2 { +// CHECK: scf.yield %[[VAL_B_1]], %[[VAL_B_2]], %[[VAL_B_3]], %[[VAL_B_4]] +// CHECK: default { +// CHECK: scf.yield %[[VAL_C_1]], %[[VAL_C_2]], %[[VAL_C_3]], %[[VAL_C_4]] + +// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3 : +// CHECK-SAME: memref, memref, memref, !sparse_tensor.storage_specifier + +func.func @index_switch(%pred: index, %a: tensor<5xf32, #SparseVector>, + %b: tensor<5xf32, #SparseVector>, + %c: tensor<5xf32, #SparseVector>) -> tensor<5xf32, #SparseVector> { + %0 = scf.index_switch %pred -> tensor<5xf32, #SparseVector> + case 1 { + scf.yield %a : tensor<5xf32, #SparseVector> + } + case 2 { + scf.yield %b : tensor<5xf32, #SparseVector> + } + default { + scf.yield %c : tensor<5xf32, #SparseVector> + } + + return %0 : tensor<5xf32, #SparseVector> +} From c8eee8878f8aa042c7c976111a2d8c4aa80d4cf0 Mon Sep 17 00:00:00 2001 From: "Skrebkov, Artemy" Date: Tue, 23 Sep 2025 21:12:03 +0000 Subject: [PATCH 2/3] Apply review remarks --- .../Dialect/SCF/Transforms/StructuralTypeConversions.cpp | 7 +++---- mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir | 1 + 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp index ad62cfc867a27..a8287c055db1d 100644 --- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -201,12 +201,11 @@ class ConvertIndexSwitchOpTypes convertSourceOp(IndexSwitchOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter, TypeRange dstTypes) const { - auto newOp = rewriter.create( - op.getLoc(), dstTypes, op.getArg(), op.getCases(), op.getNumCases()); + auto newOp = + IndexSwitchOp::create(rewriter, op.getLoc(), dstTypes, op.getArg(), + op.getCases(), op.getNumCases()); for (unsigned i = 0u; i < op.getNumRegions(); i++) { - if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter))) - return std::nullopt; auto &dstRegion = newOp.getRegion(i); rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); } diff --git a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir index 00f13ed7c8149..515de5502f322 100644 --- a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir @@ -103,6 +103,7 @@ func.func @while(%arg0: tensor<1024xf32, #SparseVector>, %c: i1) -> tensor<1024x // CHECK-SAME: %[[VAL_C_4:.*12]]: !sparse_tensor.storage_specifier // CHECK: %[[RES:.*]]:4 = scf.index_switch %[[PRED]] +// CHECK-SAME: -> memref, memref, memref, !sparse_tensor.storage_specifier // CHECK: case 1 { // CHECK: scf.yield %[[VAL_A_1]], %[[VAL_A_2]], %[[VAL_A_3]], %[[VAL_A_4]] // CHECK: case 2 { From f960c70ca8e565520dbd6351359e8ec929f0a099 Mon Sep 17 00:00:00 2001 From: Dhawal Srivastava Date: Tue, 2 Dec 2025 14:32:45 -0800 Subject: [PATCH 3/3] Minor fix to make it compatible with vpux Signed-off-by: Dhawal Srivastava --- .../SCF/Transforms/StructuralTypeConversions.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp index a8287c055db1d..39b9ff4f53bb1 100644 --- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -201,9 +201,9 @@ class ConvertIndexSwitchOpTypes convertSourceOp(IndexSwitchOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter, TypeRange dstTypes) const { - auto newOp = - IndexSwitchOp::create(rewriter, op.getLoc(), dstTypes, op.getArg(), - op.getCases(), op.getNumCases()); + auto newOp = rewriter.create( + op.getLoc(), dstTypes, getSingleValue(adaptor.getArg()), op.getCases(), + op.getNumCases()); for (unsigned i = 0u; i < op.getNumRegions(); i++) { auto &dstRegion = newOp.getRegion(i); @@ -249,14 +249,14 @@ void mlir::scf::populateSCFStructuralTypeConversions( const TypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add(typeConverter, patterns.getContext(), - benefit); + ConvertIndexSwitchOpTypes>(typeConverter, patterns.getContext()); } void mlir::scf::populateSCFStructuralTypeConversionTarget( const TypeConverter &typeConverter, ConversionTarget &target) { - target.addDynamicallyLegalOp( - [&](Operation *op) { return typeConverter.isLegal(op->getResults()); }); + target.addDynamicallyLegalOp([&](Operation *op) { + return typeConverter.isLegal(op->getResultTypes()); + }); target.addDynamicallyLegalOp([&](scf::YieldOp op) { // We only have conversions for a subset of ops that use scf.yield // terminators.