Skip to content

Conversation

@dcaballe
Copy link
Contributor

@dcaballe dcaballe commented Jun 7, 2025

This PR is part of the last step to remove vector.extractelement and vector.insertelement ops (RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops).
It updates the Sparse Vectorizer to use vector.extract and vector.insert instead of
vector.extractelement and vector.insertelement.

…se vectorizer

This PR is part of the last step to remove `vector.extractelement` and `vector.insertelement` ops.
It updates the Sparse Vectorizer to use `vector.extract` and `vector.insert` instead of
`vector.extractelement` and `vector.insertelement`.
@llvmbot
Copy link
Member

llvmbot commented Jun 7, 2025

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes

This PR is part of the last step to remove vector.extractelement and vector.insertelement ops.
It updates the Sparse Vectorizer to use vector.extract and vector.insert instead of
vector.extractelement and vector.insertelement.


Full diff: https://github.com/llvm/llvm-project/pull/143270.diff

5 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp (+37-20)
  • (modified) mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir (+1-1)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_vector.mlir (+3-3)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir (+1-1)
  • (modified) mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir (+5-5)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index 3d963dea2f572..482720ba8aed1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -198,12 +198,12 @@ static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
   case vector::CombiningKind::ADD:
   case vector::CombiningKind::XOR:
     // Initialize reduction vector to: | 0 | .. | 0 | r |
-    return rewriter.create<vector::InsertElementOp>(
+    return rewriter.create<vector::InsertOp>(
         loc, r, constantZero(rewriter, loc, vtp),
         constantIndex(rewriter, loc, 0));
   case vector::CombiningKind::MUL:
     // Initialize reduction vector to: | 1 | .. | 1 | r |
-    return rewriter.create<vector::InsertElementOp>(
+    return rewriter.create<vector::InsertOp>(
         loc, r, constantOne(rewriter, loc, vtp),
         constantIndex(rewriter, loc, 0));
   case vector::CombiningKind::AND:
@@ -628,31 +628,48 @@ struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
   const VL vl;
 };
 
+static LogicalResult cleanReducChain(PatternRewriter &rewriter, Operation *op,
+                                     Value inp) {
+  if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
+    if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
+      if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
+        rewriter.replaceOp(op, redOp.getVector());
+        return success();
+      }
+    }
+  }
+  return failure();
+}
+
 /// Reduction chain cleanup.
 ///   v = for { }
-///   s = vsum(v)               v = for { }
-///   u = expand(s)       ->    for (v) { }
+///   s = vsum(v)                  v = for { }
+///   u = broadcast(s)       ->    for (v) { }
 ///   for (u) { }
-template <typename VectorOp>
-struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
+struct ReducChainBroadcastRewriter : public OpRewritePattern<vector::BroadcastOp> {
 public:
-  using OpRewritePattern<VectorOp>::OpRewritePattern;
+  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(VectorOp op,
+  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                 PatternRewriter &rewriter) const override {
-    Value inp = op.getSource();
-    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
-      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
-        if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
-          rewriter.replaceOp(op, redOp.getVector());
-          return success();
-        }
-      }
-    }
-    return failure();
+    return cleanReducChain(rewriter, op, op.getSource());
   }
 };
 
+/// Reduction chain cleanup.
+///   v = for { }
+///   s = vsum(v)               v = for { }
+///   u = insert(s)       ->    for (v) { }
+///   for (u) { }
+struct ReducChainInsertRewriter : public OpRewritePattern<vector::InsertOp> {
+public:
+  using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::InsertOp op,
+                                PatternRewriter &rewriter) const override {
+    return cleanReducChain(rewriter, op, op.getValueToStore());
+  }
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -668,6 +685,6 @@ void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
   vector::populateVectorStepLoweringPatterns(patterns);
   patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                               enableVLAVectorization, enableSIMDIndex32);
-  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
-               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
+  patterns.add<ReducChainInsertRewriter, ReducChainBroadcastRewriter>(
+      patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir b/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir
index 2475aa5139da4..b2dfbeb53fde8 100755
--- a/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir
@@ -22,7 +22,7 @@
 // CHECK-NOVEC:       }
 //
 // CHECK-VEC-LABEL: func.func @sum_reduction
-// CHECK-VEC:       vector.insertelement
+// CHECK-VEC:       vector.insert
 // CHECK-VEC:       scf.for
 // CHECK-VEC:         vector.create_mask
 // CHECK-VEC:         vector.maskedload
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
index 364ba6e71ff3b..64235c7227800 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
@@ -241,7 +241,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
 // CHECK-VEC16-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
 // CHECK-VEC16-DAG:   %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
 // CHECK-VEC16:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
-// CHECK-VEC16:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
+// CHECK-VEC16:       %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<16xf32>
 // CHECK-VEC16:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
 // CHECK-VEC16:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
 // CHECK-VEC16:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
@@ -258,7 +258,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
 // CHECK-VEC16-IDX32-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
 // CHECK-VEC16-IDX32-DAG:   %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
 // CHECK-VEC16-IDX32:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
-// CHECK-VEC16-IDX32:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
+// CHECK-VEC16-IDX32:       %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<16xf32>
 // CHECK-VEC16-IDX32:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
 // CHECK-VEC16-IDX32:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
 // CHECK-VEC16-IDX32:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
@@ -278,7 +278,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
 // CHECK-VEC4-SVE:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
 // CHECK-VEC4-SVE:       %[[vscale:.*]] = vector.vscale
 // CHECK-VEC4-SVE:       %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
-// CHECK-VEC4-SVE:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<[4]xf32>
+// CHECK-VEC4-SVE:       %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<[4]xf32>
 // CHECK-VEC4-SVE:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<[4]xf32>) {
 // CHECK-VEC4-SVE:         %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]]
 // CHECK-VEC4-SVE:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
index f4b565c7f9c8a..0ab72897d7bc3 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
@@ -82,7 +82,7 @@
 // CHECK:               %[[VAL_57:.*]] = arith.select %[[VAL_39]], %[[VAL_56]], %[[VAL_32]] : index
 // CHECK:               scf.yield %[[VAL_55]], %[[VAL_57]], %[[VAL_58:.*]] : index, index, f64
 // CHECK:             } attributes {"Emitted from" = "linalg.generic"}
-// CHECK:             %[[VAL_59:.*]] = vector.insertelement %[[VAL_60:.*]]#2, %[[VAL_4]]{{\[}}%[[VAL_6]] : index] : vector<8xf64>
+// CHECK:             %[[VAL_59:.*]] = vector.insert %[[VAL_60:.*]]#2, %[[VAL_4]] [0] : f64 into vector<8xf64>
 // CHECK:             %[[VAL_61:.*]] = scf.for %[[VAL_62:.*]] = %[[VAL_60]]#0 to %[[VAL_21]] step %[[VAL_3]] iter_args(%[[VAL_63:.*]] = %[[VAL_59]]) -> (vector<8xf64>) {
 // CHECK:               %[[VAL_64:.*]] = affine.min #map(%[[VAL_21]], %[[VAL_62]]){{\[}}%[[VAL_3]]]
 // CHECK:               %[[VAL_65:.*]] = vector.create_mask %[[VAL_64]] : vector<8xi1>
diff --git a/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir b/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
index 01b717090e87a..6effbbf98abb7 100644
--- a/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
+++ b/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
@@ -172,7 +172,7 @@ func.func @sparse_reduction_ori_accumulator_on_rhs(%argx: tensor<i13>,
 // CHECK-ON:           %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
 // CHECK-ON:           %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK-ON:           %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON:           %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_4]]{{\[}}%[[VAL_3]] : index] : vector<8xi32>
+// CHECK-ON:           %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_4]] [0] : i32 into vector<8xi32>
 // CHECK-ON:           %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
 // CHECK-ON:             %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON:             %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -247,7 +247,7 @@ func.func @sparse_reduction_subi(%argx: tensor<i32>,
 // CHECK-ON:  %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
 // CHECK-ON:  %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON:  %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON:  %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xi32>
+// CHECK-ON:  %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : i32 into vector<8xi32>
 // CHECK-ON:  %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
 // CHECK-ON:    %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON:    %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -323,7 +323,7 @@ func.func @sparse_reduction_xor(%argx: tensor<i32>,
 // CHECK-ON:   %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
 // CHECK-ON:   %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON:   %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON:   %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xi32>
+// CHECK-ON:   %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : i32 into vector<8xi32>
 // CHECK-ON:   %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
 // CHECK-ON:     %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON:     %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -399,7 +399,7 @@ func.func @sparse_reduction_addi(%argx: tensor<i32>,
 // CHECK-ON:   %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<f32>
 // CHECK-ON:   %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON:   %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON:   %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xf32>
+// CHECK-ON:   %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : f32 into vector<8xf32>
 // CHECK-ON:   %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xf32>) {
 // CHECK-ON:     %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON:     %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -475,7 +475,7 @@ func.func @sparse_reduction_subf(%argx: tensor<f32>,
 // CHECK-ON:   %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<f32>
 // CHECK-ON:   %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON:   %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON:   %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xf32>
+// CHECK-ON:   %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : f32 into vector<8xf32>
 // CHECK-ON:   %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xf32>) {
 // CHECK-ON:     %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON:     %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>

@llvmbot
Copy link
Member

llvmbot commented Jun 7, 2025

@llvm/pr-subscribers-mlir-sparse

Author: Diego Caballero (dcaballe)

Changes

This PR is part of the last step to remove vector.extractelement and vector.insertelement ops.
It updates the Sparse Vectorizer to use vector.extract and vector.insert instead of
vector.extractelement and vector.insertelement.


Full diff: https://github.com/llvm/llvm-project/pull/143270.diff

5 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp (+37-20)
  • (modified) mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir (+1-1)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_vector.mlir (+3-3)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir (+1-1)
  • (modified) mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir (+5-5)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index 3d963dea2f572..482720ba8aed1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -198,12 +198,12 @@ static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
   case vector::CombiningKind::ADD:
   case vector::CombiningKind::XOR:
     // Initialize reduction vector to: | 0 | .. | 0 | r |
-    return rewriter.create<vector::InsertElementOp>(
+    return rewriter.create<vector::InsertOp>(
         loc, r, constantZero(rewriter, loc, vtp),
         constantIndex(rewriter, loc, 0));
   case vector::CombiningKind::MUL:
     // Initialize reduction vector to: | 1 | .. | 1 | r |
-    return rewriter.create<vector::InsertElementOp>(
+    return rewriter.create<vector::InsertOp>(
         loc, r, constantOne(rewriter, loc, vtp),
         constantIndex(rewriter, loc, 0));
   case vector::CombiningKind::AND:
@@ -628,31 +628,48 @@ struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
   const VL vl;
 };
 
+static LogicalResult cleanReducChain(PatternRewriter &rewriter, Operation *op,
+                                     Value inp) {
+  if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
+    if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
+      if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
+        rewriter.replaceOp(op, redOp.getVector());
+        return success();
+      }
+    }
+  }
+  return failure();
+}
+
 /// Reduction chain cleanup.
 ///   v = for { }
-///   s = vsum(v)               v = for { }
-///   u = expand(s)       ->    for (v) { }
+///   s = vsum(v)                  v = for { }
+///   u = broadcast(s)       ->    for (v) { }
 ///   for (u) { }
-template <typename VectorOp>
-struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
+struct ReducChainBroadcastRewriter : public OpRewritePattern<vector::BroadcastOp> {
 public:
-  using OpRewritePattern<VectorOp>::OpRewritePattern;
+  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(VectorOp op,
+  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                 PatternRewriter &rewriter) const override {
-    Value inp = op.getSource();
-    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
-      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
-        if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
-          rewriter.replaceOp(op, redOp.getVector());
-          return success();
-        }
-      }
-    }
-    return failure();
+    return cleanReducChain(rewriter, op, op.getSource());
   }
 };
 
+/// Reduction chain cleanup.
+///   v = for { }
+///   s = vsum(v)               v = for { }
+///   u = insert(s)       ->    for (v) { }
+///   for (u) { }
+struct ReducChainInsertRewriter : public OpRewritePattern<vector::InsertOp> {
+public:
+  using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::InsertOp op,
+                                PatternRewriter &rewriter) const override {
+    return cleanReducChain(rewriter, op, op.getValueToStore());
+  }
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -668,6 +685,6 @@ void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
   vector::populateVectorStepLoweringPatterns(patterns);
   patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                               enableVLAVectorization, enableSIMDIndex32);
-  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
-               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
+  patterns.add<ReducChainInsertRewriter, ReducChainBroadcastRewriter>(
+      patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir b/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir
index 2475aa5139da4..b2dfbeb53fde8 100755
--- a/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir
@@ -22,7 +22,7 @@
 // CHECK-NOVEC:       }
 //
 // CHECK-VEC-LABEL: func.func @sum_reduction
-// CHECK-VEC:       vector.insertelement
+// CHECK-VEC:       vector.insert
 // CHECK-VEC:       scf.for
 // CHECK-VEC:         vector.create_mask
 // CHECK-VEC:         vector.maskedload
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
index 364ba6e71ff3b..64235c7227800 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
@@ -241,7 +241,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
 // CHECK-VEC16-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
 // CHECK-VEC16-DAG:   %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
 // CHECK-VEC16:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
-// CHECK-VEC16:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
+// CHECK-VEC16:       %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<16xf32>
 // CHECK-VEC16:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
 // CHECK-VEC16:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
 // CHECK-VEC16:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
@@ -258,7 +258,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
 // CHECK-VEC16-IDX32-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
 // CHECK-VEC16-IDX32-DAG:   %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
 // CHECK-VEC16-IDX32:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
-// CHECK-VEC16-IDX32:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
+// CHECK-VEC16-IDX32:       %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<16xf32>
 // CHECK-VEC16-IDX32:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
 // CHECK-VEC16-IDX32:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
 // CHECK-VEC16-IDX32:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
@@ -278,7 +278,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
 // CHECK-VEC4-SVE:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
 // CHECK-VEC4-SVE:       %[[vscale:.*]] = vector.vscale
 // CHECK-VEC4-SVE:       %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
-// CHECK-VEC4-SVE:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<[4]xf32>
+// CHECK-VEC4-SVE:       %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<[4]xf32>
 // CHECK-VEC4-SVE:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<[4]xf32>) {
 // CHECK-VEC4-SVE:         %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]]
 // CHECK-VEC4-SVE:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
index f4b565c7f9c8a..0ab72897d7bc3 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
@@ -82,7 +82,7 @@
 // CHECK:               %[[VAL_57:.*]] = arith.select %[[VAL_39]], %[[VAL_56]], %[[VAL_32]] : index
 // CHECK:               scf.yield %[[VAL_55]], %[[VAL_57]], %[[VAL_58:.*]] : index, index, f64
 // CHECK:             } attributes {"Emitted from" = "linalg.generic"}
-// CHECK:             %[[VAL_59:.*]] = vector.insertelement %[[VAL_60:.*]]#2, %[[VAL_4]]{{\[}}%[[VAL_6]] : index] : vector<8xf64>
+// CHECK:             %[[VAL_59:.*]] = vector.insert %[[VAL_60:.*]]#2, %[[VAL_4]] [0] : f64 into vector<8xf64>
 // CHECK:             %[[VAL_61:.*]] = scf.for %[[VAL_62:.*]] = %[[VAL_60]]#0 to %[[VAL_21]] step %[[VAL_3]] iter_args(%[[VAL_63:.*]] = %[[VAL_59]]) -> (vector<8xf64>) {
 // CHECK:               %[[VAL_64:.*]] = affine.min #map(%[[VAL_21]], %[[VAL_62]]){{\[}}%[[VAL_3]]]
 // CHECK:               %[[VAL_65:.*]] = vector.create_mask %[[VAL_64]] : vector<8xi1>
diff --git a/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir b/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
index 01b717090e87a..6effbbf98abb7 100644
--- a/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
+++ b/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir
@@ -172,7 +172,7 @@ func.func @sparse_reduction_ori_accumulator_on_rhs(%argx: tensor<i13>,
 // CHECK-ON:           %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
 // CHECK-ON:           %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK-ON:           %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON:           %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_4]]{{\[}}%[[VAL_3]] : index] : vector<8xi32>
+// CHECK-ON:           %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_4]] [0] : i32 into vector<8xi32>
 // CHECK-ON:           %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
 // CHECK-ON:             %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON:             %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -247,7 +247,7 @@ func.func @sparse_reduction_subi(%argx: tensor<i32>,
 // CHECK-ON:  %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
 // CHECK-ON:  %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON:  %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON:  %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xi32>
+// CHECK-ON:  %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : i32 into vector<8xi32>
 // CHECK-ON:  %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
 // CHECK-ON:    %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON:    %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -323,7 +323,7 @@ func.func @sparse_reduction_xor(%argx: tensor<i32>,
 // CHECK-ON:   %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
 // CHECK-ON:   %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON:   %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON:   %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xi32>
+// CHECK-ON:   %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : i32 into vector<8xi32>
 // CHECK-ON:   %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
 // CHECK-ON:     %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON:     %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -399,7 +399,7 @@ func.func @sparse_reduction_addi(%argx: tensor<i32>,
 // CHECK-ON:   %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<f32>
 // CHECK-ON:   %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON:   %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON:   %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xf32>
+// CHECK-ON:   %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : f32 into vector<8xf32>
 // CHECK-ON:   %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xf32>) {
 // CHECK-ON:     %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON:     %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -475,7 +475,7 @@ func.func @sparse_reduction_subf(%argx: tensor<f32>,
 // CHECK-ON:   %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<f32>
 // CHECK-ON:   %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON:   %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON:   %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xf32>
+// CHECK-ON:   %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : f32 into vector<8xf32>
 // CHECK-ON:   %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xf32>) {
 // CHECK-ON:     %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON:     %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>

@github-actions
Copy link

github-actions bot commented Jun 7, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@dcaballe dcaballe merged commit 1ac61c8 into llvm:main Jun 12, 2025
7 checks passed
tomtor pushed a commit to tomtor/llvm-project that referenced this pull request Jun 14, 2025
…se vectorizer (llvm#143270)

This PR is part of the last step to remove `vector.extractelement` and `vector.insertelement` ops.
RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops

It updates the Sparse Vectorizer to use `vector.extract` and `vector.insert` instead of `vector.extractelement` and `vector.insertelement`.
akuhlens pushed a commit to akuhlens/llvm-project that referenced this pull request Jun 24, 2025
…se vectorizer (llvm#143270)

This PR is part of the last step to remove `vector.extractelement` and `vector.insertelement` ops.
RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops

It updates the Sparse Vectorizer to use `vector.extract` and `vector.insert` instead of `vector.extractelement` and `vector.insertelement`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:sparse Sparse compiler in MLIR mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants