Skip to content

Conversation

@dcaballe
Copy link
Contributor

@dcaballe dcaballe commented Jul 8, 2025

It generates a linearized version of the vector.extract for scalar cases.

@llvmbot
Copy link
Member

llvmbot commented Jul 8, 2025

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes

It generates a linearized version of the vector.extract for scalar cases.


Full diff: https://github.com/llvm/llvm-project/pull/147440.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+36-13)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+13-13)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 7cac1cbafdd64..8b232aafbca9d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -395,15 +395,32 @@ struct LinearizeVectorShuffle final
   }
 };
 
-/// This pattern converts the ExtractOp to a ShuffleOp that works on a
-/// linearized vector.
-/// Following,
-///   vector.extract %source [ position ]
-/// is converted to :
-///   %source_1d = vector.shape_cast %source
-///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
-///   %out_nd = vector.shape_cast %out_1d
-/// `shuffle_indices_1d` is computed using the position of the original extract.
+/// This pattern linearizes `vector.extract` operations. It generates a 1-D
+/// version of the `vector.extract` operation when extracting a scalar from a
+/// vector. It generates a 1-D `vector.shuffle` operation when extracting a
+/// subvector from a larger vector.
+///
+/// Example #1:
+///
+///     %0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
+///
+///   is converted to:
+///
+///     %0 = vector.shape_cast %arg0 : vector<2x8x2xf32> to vector<32xf32>
+///     %1 = vector.shuffle %0, %0 [16, 17, 18, 19, 20, 21, 22, 23,
+///                                 24, 25, 26, 27, 28, 29, 30, 31] :
+///            vector<32xf32>, vector<32xf32>
+///     %2 = vector.shape_cast %1 : vector<16xf32> to vector<8x2xf32>
+///
+/// Example #2:
+///
+///     %0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32>
+///
+///   is converted to:
+///
+///     %0 = vector.shape_cast %arg0 : vector<2x4xi32> to vector<8xi32>
+///     %1 = vector.extract %0[6] : i32 from vector<8xi32>
+///
 struct LinearizeVectorExtract final
     : public OpConversionPattern<vector::ExtractOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -413,10 +430,6 @@ struct LinearizeVectorExtract final
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // Skip if result is not a vector type
-    if (!isa<VectorType>(extractOp.getType()))
-      return rewriter.notifyMatchFailure(extractOp,
-                                         "scalar extract not supported");
     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
     assert(dstTy && "expected 1-D vector type");
 
@@ -436,6 +449,16 @@ struct LinearizeVectorExtract final
       linearizedOffset += offsets[i] * size;
     }
 
+    if (!isa<VectorType>(extractOp.getType())) {
+      // Scalar case: generate a 1-D extract.
+      Value result = rewriter.createOrFold<vector::ExtractOp>(
+          extractOp.getLoc(), adaptor.getVector(), linearizedOffset);
+      rewriter.replaceOp(extractOp, result);
+      return success();
+    }
+
+    // Vector case: generate a shuffle.
+
     llvm::SmallVector<int64_t, 2> indices(size);
     std::iota(indices.begin(), indices.end(), linearizedOffset);
     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 894171500d9d6..cbc15f34918f6 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -264,6 +264,19 @@ func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -
 
 // -----
 
+// CHECK-LABEL: test_vector_extract_scalar
+// CHECK-SAME: (%[[ARG:.*]]: vector<2x4xi32>) -> i32 {
+func.func @test_vector_extract_scalar(%arg0 : vector<2x4xi32>) -> i32 {
+
+  // CHECK: %[[SRC_1D:.*]] = vector.shape_cast %[[ARG]] : vector<2x4xi32> to vector<8xi32>
+  // CHECK: %[[EXTRACT_1D:.*]] = vector.extract %[[SRC_1D]][6] : i32 from vector<8xi32>
+  // CHECK: return %[[EXTRACT_1D]] : i32
+  %0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32>
+  return %0 : i32
+}
+
+// -----
+
 // CHECK-LABEL: test_vector_extract
 // CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> {
 func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
@@ -341,19 +354,6 @@ func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector
 
 // -----
 
-// CHECK-LABEL: test_vector_extract_scalar
-func.func @test_vector_extract_scalar(%idx : index) {
-  %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
-
-  // CHECK-NOT: vector.shuffle
-  // CHECK:     vector.extract
-  // CHECK-NOT: vector.shuffle
-  %0 = vector.extract %cst[%idx] : i32 from vector<4xi32>
-  return
-}
-
-// -----
-
 // CHECK-LABEL: test_vector_bitcast
 // CHECK-SAME: %[[ARG_0:.*]]: vector<4x4xf32>
 func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> {

@llvmbot
Copy link
Member

llvmbot commented Jul 8, 2025

@llvm/pr-subscribers-mlir-vector

Author: Diego Caballero (dcaballe)

Changes

It generates a linearized version of the vector.extract for scalar cases.


Full diff: https://github.com/llvm/llvm-project/pull/147440.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+36-13)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+13-13)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 7cac1cbafdd64..8b232aafbca9d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -395,15 +395,32 @@ struct LinearizeVectorShuffle final
   }
 };
 
-/// This pattern converts the ExtractOp to a ShuffleOp that works on a
-/// linearized vector.
-/// Following,
-///   vector.extract %source [ position ]
-/// is converted to :
-///   %source_1d = vector.shape_cast %source
-///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
-///   %out_nd = vector.shape_cast %out_1d
-/// `shuffle_indices_1d` is computed using the position of the original extract.
+/// This pattern linearizes `vector.extract` operations. It generates a 1-D
+/// version of the `vector.extract` operation when extracting a scalar from a
+/// vector. It generates a 1-D `vector.shuffle` operation when extracting a
+/// subvector from a larger vector.
+///
+/// Example #1:
+///
+///     %0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
+///
+///   is converted to:
+///
+///     %0 = vector.shape_cast %arg0 : vector<2x8x2xf32> to vector<32xf32>
+///     %1 = vector.shuffle %0, %0 [16, 17, 18, 19, 20, 21, 22, 23,
+///                                 24, 25, 26, 27, 28, 29, 30, 31] :
+///            vector<32xf32>, vector<32xf32>
+///     %2 = vector.shape_cast %1 : vector<16xf32> to vector<8x2xf32>
+///
+/// Example #2:
+///
+///     %0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32>
+///
+///   is converted to:
+///
+///     %0 = vector.shape_cast %arg0 : vector<2x4xi32> to vector<8xi32>
+///     %1 = vector.extract %0[6] : i32 from vector<8xi32>
+///
 struct LinearizeVectorExtract final
     : public OpConversionPattern<vector::ExtractOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -413,10 +430,6 @@ struct LinearizeVectorExtract final
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // Skip if result is not a vector type
-    if (!isa<VectorType>(extractOp.getType()))
-      return rewriter.notifyMatchFailure(extractOp,
-                                         "scalar extract not supported");
     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
     assert(dstTy && "expected 1-D vector type");
 
@@ -436,6 +449,16 @@ struct LinearizeVectorExtract final
       linearizedOffset += offsets[i] * size;
     }
 
+    if (!isa<VectorType>(extractOp.getType())) {
+      // Scalar case: generate a 1-D extract.
+      Value result = rewriter.createOrFold<vector::ExtractOp>(
+          extractOp.getLoc(), adaptor.getVector(), linearizedOffset);
+      rewriter.replaceOp(extractOp, result);
+      return success();
+    }
+
+    // Vector case: generate a shuffle.
+
     llvm::SmallVector<int64_t, 2> indices(size);
     std::iota(indices.begin(), indices.end(), linearizedOffset);
     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 894171500d9d6..cbc15f34918f6 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -264,6 +264,19 @@ func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -
 
 // -----
 
+// CHECK-LABEL: test_vector_extract_scalar
+// CHECK-SAME: (%[[ARG:.*]]: vector<2x4xi32>) -> i32 {
+func.func @test_vector_extract_scalar(%arg0 : vector<2x4xi32>) -> i32 {
+
+  // CHECK: %[[SRC_1D:.*]] = vector.shape_cast %[[ARG]] : vector<2x4xi32> to vector<8xi32>
+  // CHECK: %[[EXTRACT_1D:.*]] = vector.extract %[[SRC_1D]][6] : i32 from vector<8xi32>
+  // CHECK: return %[[EXTRACT_1D]] : i32
+  %0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32>
+  return %0 : i32
+}
+
+// -----
+
 // CHECK-LABEL: test_vector_extract
 // CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> {
 func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
@@ -341,19 +354,6 @@ func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector
 
 // -----
 
-// CHECK-LABEL: test_vector_extract_scalar
-func.func @test_vector_extract_scalar(%idx : index) {
-  %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
-
-  // CHECK-NOT: vector.shuffle
-  // CHECK:     vector.extract
-  // CHECK-NOT: vector.shuffle
-  %0 = vector.extract %cst[%idx] : i32 from vector<4xi32>
-  return
-}
-
-// -----
-
 // CHECK-LABEL: test_vector_bitcast
 // CHECK-SAME: %[[ARG_0:.*]]: vector<4x4xf32>
 func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> {

@dcaballe dcaballe requested review from Hardcode84 and nbpatel July 9, 2025 18:52
Copy link
Contributor

@nbpatel nbpatel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Just one nit:maybe adaptor.getVector() can be promoted to a variable

dcaballe added 2 commits July 11, 2025 22:35
Generate a linearized version of the `vector.extract` for these cases.
@dcaballe dcaballe force-pushed the vector-linearize-scalar-extract branch from 781f0ca to 03eb768 Compare July 11, 2025 22:45
@dcaballe dcaballe merged commit ace1c83 into llvm:main Jul 11, 2025
8 of 9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants