Conversation

@banach-space
Contributor

This addresses a TODO and an earlier review comment from the original PR
where the pattern was introduced:
* llvm#134389 (comment)

The pattern is relatively straightforward and has not been updated since it
landed, so it seems reasonable to promote it to a canonicalization.

Note: this change only moves the existing pattern into canonicalization
infrastructure; it does not add or remove any functionality.
@llvmbot
Member

llvmbot commented Dec 8, 2025

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

This addresses a TODO and an earlier review comment from the original PR
where the pattern was introduced:

The pattern is relatively straightforward and has not been updated since it
landed, so it seems reasonable to promote it to a canonicalization.

Note: this change only moves the existing pattern into canonicalization
infrastructure; it does not add or remove any functionality.


Patch is 21.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/171198.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/Vector.td (+2-1)
  • (modified) mlir/lib/Dialect/Vector/IR/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+96-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+1-2)
  • (added) mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir (+131)
  • (modified) mlir/test/Dialect/Vector/vector-sink.mlir (+1-127)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir (+1-2)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/Vector.td b/mlir/include/mlir/Dialect/Vector/IR/Vector.td
index 5125ae7c13717..beb6bedb908e9 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/Vector.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/Vector.td
@@ -23,7 +23,8 @@ def Vector_Dialect : Dialect {
   let hasConstantMaterializer = 1;
   let dependentDialects = [
     "arith::ArithDialect",
-    "ub::UBDialect"
+    "ub::UBDialect",
+    "memref::MemRefDialect"
   ];
 }
 
diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
index 0248896e096a0..9cf4fedbbe978 100644
--- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRVectorDialect
   LINK_LIBS PUBLIC
   MLIRAffineDialect
   MLIRArithDialect
+  MLIRMemRefDialect
   MLIRControlFlowInterfaces
   MLIRDataLayoutInterfaces
   MLIRDestinationStyleOpInterface
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2789f63555524..eceeed9d03f4b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2226,6 +2226,99 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
   }
 };
 
+/// Check if the element type is suitable for vector.load/store sinking.
+/// Element type must be index or byte-aligned integer or floating-point type.
+static bool isSupportedMemSinkElementType(Type type) {
+  if (isa<IndexType>(type))
+    return true;
+
+  return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
+}
+
+/// Pattern to rewrite `vector.extract(vector.load) -> vector/memref.load`.
+/// Only index and byte-aligned integer and floating-point element types are
+/// supported for now.
+///
+/// Example:
+/// ```
+///  vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+///  vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
+/// ```
+///
+/// Note: this is considered beneficial only in single-use cases.
+class ExtractOpFromLoad final : public OpRewritePattern<ExtractOp> {
+public:
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loadOp = op.getSource().getDefiningOp<vector::LoadOp>();
+    if (!loadOp)
+      return rewriter.notifyMatchFailure(op, "expected a load op");
+
+    // Checking for single use so we won't duplicate load ops.
+    if (!loadOp->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    VectorType loadVecType = loadOp.getVectorType();
+    if (loadVecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    MemRefType memType = loadOp.getMemRefType();
+
+    // Non-byte-aligned types are tricky and may require special handling,
+    // ignore them for now.
+    if (!isSupportedMemSinkElementType(memType.getElementType()))
+      return rewriter.notifyMatchFailure(op, "unsupported element type");
+
+    int64_t rankOffset = memType.getRank() - loadVecType.getRank();
+    if (rankOffset < 0)
+      return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
+
+    auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
+    int64_t finalRank = 0;
+    if (extractVecType)
+      finalRank = extractVecType.getRank();
+
+    SmallVector<Value> indices = loadOp.getIndices();
+    SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
+
+    // There may be memory stores between the load and the extract op, so we
+    // need to make sure that the new load op is inserted at the same place as
+    // the original load op.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(loadOp);
+    Location loc = loadOp.getLoc();
+    ArithIndexingBuilder idxBuilderf(rewriter, loc);
+    for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
+      OpFoldResult pos = extractPos[i - rankOffset];
+      if (isZeroInteger(pos))
+        continue;
+
+      Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
+      indices[i] = idxBuilderf.add(indices[i], offset);
+    }
+
+    Value base = loadOp.getBase();
+    if (extractVecType) {
+      rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
+                                                  indices);
+    } else {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
+    }
+    // We checked for single use so we can safely erase the load op.
+    rewriter.eraseOp(loadOp);
+    return success();
+  }
+};
+
 // Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
 public:
@@ -2363,7 +2456,9 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+  results
+      .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractOpFromLoad>(
+          context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 726da1e9a3d14..2e1f8ff38dbf6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2384,8 +2384,7 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
 void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
                                                     PatternBenefit benefit) {
   // TODO: Consider converting these patterns to canonicalizations.
-  patterns.add<ExtractOpFromLoad, StoreOpFromBroadcast>(patterns.getContext(),
-                                                        benefit);
+  patterns.add<StoreOpFromBroadcast>(patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
new file mode 100644
index 0000000000000..c140bcb3af8ad
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-extract.mlir
@@ -0,0 +1,131 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s
+
+// This file contains some tests of folding/canonicalizing vector.extract
+
+//-----------------------------------------------------------------------------
+// [Pattern: ExtractOpFromLoad]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: @extract_load_scalar
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @extract_load_scalar(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+// CHECK:   return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_index
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xindex>, %[[ARG1:.*]]: index)
+func.func @extract_load_index(%arg0: memref<?xindex>, %arg1: index) -> index {
+// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xindex>
+// CHECK:   return %[[RES]] : index
+  %0 = vector.load %arg0[%arg1] : memref<?xindex>, vector<4xindex>
+  %1 = vector.extract %0[0] : index from vector<4xindex>
+  return %1 : index
+}
+
+// CHECK-LABEL: @extract_load_scalar_non_zero_off
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @extract_load_scalar_non_zero_off(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
+// CHECK:   return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_scalar_dyn_off
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_scalar_dyn_off(%arg0: memref<?xf32>, %arg1: index, %arg2: index) -> f32 {
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG1]], %[[ARG2]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
+// CHECK:   return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[%arg2] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_vec_non_zero_off
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_vec_non_zero_off(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> vector<4xf32> {
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[OFF]], %[[ARG2]]] : memref<?x?xf32>, vector<4xf32>
+// CHECK:   return %[[RES]] : vector<4xf32>
+  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extract_load_scalar_non_zero_off_2d_src_memref
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_scalar_non_zero_off_2d_src_memref(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> f32 {
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[OFF]]] : memref<?x?xf32>
+// CHECK:   return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_vec_high_rank
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+func.func @extract_load_vec_high_rank(%arg0: memref<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> vector<4xf32> {
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[OFF]], %[[ARG3]]] : memref<?x?x?xf32>, vector<4xf32>
+// CHECK:   return %[[RES]] : vector<4xf32>
+  %0 = vector.load %arg0[%arg1, %arg2, %arg3] : memref<?x?x?xf32>, vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_extract_load_scalar_from_memref_of_vec
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xvector<4xf32>>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_scalar_from_memref_of_vec(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xvector<4xf32>>, vector<4xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
+// CHECK:   return %[[EXT]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xvector<4xf32>>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @negative_extract_load_scalar_from_memref_of_i1
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xi1>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_scalar_from_memref_of_i1(%arg0: memref<?xi1>, %arg1: index) -> i1 {
+// Subbyte types are tricky, ignore them for now.
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xi1>, vector<8xi1>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : i1 from vector<8xi1>
+// CHECK:   return %[[EXT]] : i1
+  %0 = vector.load %arg0[%arg1] : memref<?xi1>, vector<8xi1>
+  %1 = vector.extract %0[0] : i1 from vector<8xi1>
+  return %1 : i1
+}
+
+// CHECK-LABEL: @negative_extract_load_no_single_use
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_no_single_use(%arg0: memref<?xf32>, %arg1: index) -> (f32, vector<4xf32>) {
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
+// CHECK:   return %[[EXT]], %[[RES]] : f32, vector<4xf32>
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1, %0 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_extract_load_scalable
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<[1]xf32>
+// CHECK:   return %[[EXT]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
+  %1 = vector.extract %0[0] : f32 from vector<[1]xf32>
+  return %1 : f32
+}
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index beaba52af1841..5d6ea5147fa73 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -11,6 +11,7 @@
 // CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
 // CHECK:           return %[[BCAST]] : vector<1x4xindex>
 
+
 func.func @broadcast_scalar_with_bcast(%arg1: index, %arg2: index) -> vector<1x4xindex> {
   %0 = vector.broadcast %arg1 : index to vector<1x4xindex>
   %1 = vector.broadcast %arg2 : index to vector<1x4xindex>
@@ -651,133 +652,6 @@ func.func @negative_extract_dynamic_pos(%arg0: vector<4xf32>, %arg1 : vector<4xf
   return %2 : f32
 }
 
-//-----------------------------------------------------------------------------
-// [Pattern: ExtractOpFromLoad]
-//-----------------------------------------------------------------------------
-
-// CHECK-LABEL: @extract_load_scalar
-//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
-func.func @extract_load_scalar(%arg0: memref<?xf32>, %arg1: index) -> f32 {
-// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>
-// CHECK:   return %[[RES]] : f32
-  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
-  %1 = vector.extract %0[0] : f32 from vector<4xf32>
-  return %1 : f32
-}
-
-// CHECK-LABEL: @extract_load_index
-//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xindex>, %[[ARG1:.*]]: index)
-func.func @extract_load_index(%arg0: memref<?xindex>, %arg1: index) -> index {
-// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xindex>
-// CHECK:   return %[[RES]] : index
-  %0 = vector.load %arg0[%arg1] : memref<?xindex>, vector<4xindex>
-  %1 = vector.extract %0[0] : index from vector<4xindex>
-  return %1 : index
-}
-
-// CHECK-LABEL: @extract_load_scalar_non_zero_off
-//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
-func.func @extract_load_scalar_non_zero_off(%arg0: memref<?xf32>, %arg1: index) -> f32 {
-// CHECK:   %[[C1:.*]] = arith.constant 1 : index
-// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
-// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
-// CHECK:   return %[[RES]] : f32
-  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
-  %1 = vector.extract %0[1] : f32 from vector<4xf32>
-  return %1 : f32
-}
-
-// CHECK-LABEL: @extract_load_scalar_dyn_off
-//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
-func.func @extract_load_scalar_dyn_off(%arg0: memref<?xf32>, %arg1: index, %arg2: index) -> f32 {
-// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG1]], %[[ARG2]] overflow<nsw> : index
-// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
-// CHECK:   return %[[RES]] : f32
-  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
-  %1 = vector.extract %0[%arg2] : f32 from vector<4xf32>
-  return %1 : f32
-}
-
-// CHECK-LABEL: @extract_load_vec_non_zero_off
-//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
-func.func @extract_load_vec_non_zero_off(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> vector<4xf32> {
-// CHECK:   %[[C1:.*]] = arith.constant 1 : index
-// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
-// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[OFF]], %[[ARG2]]] : memref<?x?xf32>, vector<4xf32>
-// CHECK:   return %[[RES]] : vector<4xf32>
-  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<2x4xf32>
-  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
-  return %1 : vector<4xf32>
-}
-
-// CHECK-LABEL: @extract_load_scalar_non_zero_off_2d_src_memref
-//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
-func.func @extract_load_scalar_non_zero_off_2d_src_memref(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> f32 {
-// CHECK:   %[[C1:.*]] = arith.constant 1 : index
-// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
-// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[OFF]]] : memref<?x?xf32>
-// CHECK:   return %[[RES]] : f32
-  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<4xf32>
-  %1 = vector.extract %0[1] : f32 from vector<4xf32>
-  return %1 : f32
-}
-
-// CHECK-LABEL: @extract_load_vec_high_rank
-//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
-func.func @extract_load_vec_high_rank(%arg0: memref<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> vector<4xf32> {
-// CHECK:   %[[C1:.*]] = arith.constant 1 : index
-// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
-// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[OFF]], %[[ARG3]]] : memref<?x?x?xf32>, vector<4xf32>
-// CHECK:   return %[[RES]] : vector<4xf32>
-  %0 = vector.load %arg0[%arg1, %arg2, %arg3] : memref<?x?x?xf32>, vector<2x4xf32>
-  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
-  return %1 : vector<4xf32>
-}
-
-// CHECK-LABEL: @negative_extract_load_scalar_from_memref_of_vec
-//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xvector<4xf32>>, %[[ARG1:.*]]: index)
-func.func @negative_extract_load_scalar_from_memref_of_vec(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
-// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xvector<4xf32>>, vector<4xf32>
-// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
-// CHECK:   return %[[EXT]] : f32
-  %0 = vector.load %arg0[%arg1] : memref<?xvector<4xf32>>, vector<4xf32>
-  %1 = vector.extract %0[0] : f32 from vector<4xf32>
-  return %1 : f32
-}
-
-// CHECK-LABEL: @negative_extract_load_scalar_from_memref_of_i1
-//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xi1>, %[[ARG1:.*]]: index)
-func.func @negative_extract_load_scalar_from_memref_of_i1(%arg0: memref<?xi1>, %arg1: index) -> i1 {
-// Subbyte types are tricky, ignore them for now.
-// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xi1>, vector<8xi1>
-// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : i1 from vector<8xi1>
-// CHECK:   return %[[EXT]] : i1
-  %0 = vector.load %arg0[%arg1] : memref<?xi1>, vector<8xi1>
-  %1 = vector.extract %0[0] : i1 from vector<8xi1>
-  return %1 : i1
-}
-
-// CHECK-LABEL: @negative_extract_load_no_single_use
-//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
-func.func @negative_extract_load_no_single_use(%arg0: memref<?xf32>, %arg1: index) -> (f32, vector<4xf32>) {
-// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf...
[truncated]

@joker-eph
Collaborator

The dependency on memref dialect is a bit unfortunate here, isn't it?

@banach-space
Contributor Author

The dependency on memref dialect is a bit unfortunate here, isn't it?

Agreed. I guess the other dialect dependencies also come from canonicalizations?

Member

@kuhar kuhar left a comment

I wrote some comments under the original PR that need to be addressed: #134389 (comment)

The two main issues are losing alignment info, like @dcaballe said:

For that example, I would expect the alignment information to be explicit somewhere as vector.load doesn’t have any default alignment. In the presence of no alignment information, I’m still not sure this transformation is dropping information.

We now have alignment attributes for this.

And second, breaking programs that depend on OOB semantics.
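
To make the alignment concern concrete, here is a rough sketch (illustrative only; the `alignment` attribute spelling and the value names are assumptions, not taken from this patch):

```mlir
// Before: the access is a single 4xf32 load, and any alignment fact attached to
// it (e.g. 16 bytes for the whole vector) is explicit on that one op.
%v = vector.load %mem[%i] { alignment = 16 } : memref<?xf32>, vector<4xf32>
%e = vector.extract %v[1] : f32 from vector<4xf32>

// After the rewrite only a scalar is loaded; unless the pattern re-derives an
// alignment for the new op, the wider-access alignment fact is dropped.
%c1  = arith.constant 1 : index
%off = arith.addi %i, %c1 overflow<nsw> : index
%e2  = memref.load %mem[%off] : memref<?xf32>
```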

@Hardcode84
Contributor

re OOB semantics

    Representation-wise, the 'vector.load' operation permits out-of-bounds
    reads. Support and implementation of out-of-bounds vector loads is
    target-specific. No assumptions should be made on the value of elements
    loaded out of bounds. Not all targets may support out-of-bounds vector
    loads.

I don't think the current semantics is particularly useful in general; I'd prefer it were specified explicitly via a flag per vector op. This can also be potentially useful during lowering (e.g., whether or not to use buffer ops on AMDGPU).

@kuhar
Member

kuhar commented Dec 8, 2025

I don't think current semantics is particularly useful in general, I'd prefer it was specified explicitly via flag per vector op. This can also be potentially useful during lowering (e.g. use or not buffer ops on AMDGPU).

+1, I think to make progress on this we'd have to refine the docs first and decide what happens on partially-OOB accesses. The docs say that we can't assume anything about OOB elements, but they don't specify what happens with the in-bounds elements.

Contributor

@krzysz00 krzysz00 left a comment

OK, so going off the definition of vector.load, I don't think this is canonical, and I'd argue it shouldn't be.

I'm going to link #134734 here, where a related issue with how partially out-of-bounds accesses are handled has been in a holding pattern for a while.

The current notes on vector.load re out-of-bounds values are

Representation-wise, the ‘vector.load’ operation permits out-of-bounds reads. Support and implementation of out-of-bounds vector loads is target-specific. No assumptions should be made on the value of elements loaded out of bounds. Not all targets may support out-of-bounds vector loads.

This isn't the most concrete, but my read here is that "any vector.load where some vector element is out of bounds can have arbitrary behavior, not necessarily limited to the value of that element".

That is, a target where the behavior is "if any accessed element is out of bounds, the entire access is deemed out of bounds" is valid. Under that behavior, scoping down to an extract could change the result of an access.

And I do, in fact, have an example of these semantics.

On AMD, when loading more than a (32-bit) word of elements from a buffer resource, bounds checks are performed per-word with, critically, saturating arithmetic, and out-of-bounds loads return 0.

That is, suppose I have a resource containing 8 floats (32 bytes), `[f0, f1, ..., f7]`, annotated with the correct bounds.

Then,

%e1 = call <4 x float> @llvm.amdgcn.raw.ptr.buffer.load(rsrc, i32 6 * 4, ...)

will yield %e1 = <f6, f7, 0, 0>. Applying this pattern would return either the element in question or 0.

However, with

%e2 = call <4 x float> @llvm.amdgcn.raw.buffer.load(rsrc, i32 -4, ...)

I will not get the vector <0, f0, f1, f2>; I will instead get <0, 0, 0, 0>.

Implementing the former behavior (currently needed for some Vulkan conformance) is, in the backend, "strict buffer OOB mode", and it implies substantial pessimizations of the generated code in order to get the correct behavior just in case someone did `vector.load %memref[%value_that_is_negative_1] : memref<?xf32>, vector<4xf32>`.

Given that the traditional default is for this sort of scalarization to not be performed, and that that's important for performance, I conclude that this sort of "propagating"/"poisonous" behavior for vector.load's out of bounds accesses is a valid implementation of vector.load, and therefore that this pattern isn't canonical.

(That is, you can do this rewrite ... if you have reason to believe you're in a context where out of bounds accesses either can't happen or are handled elementwise)
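
A minimal MLIR sketch of that divergence (illustrative only; `%mem` and the constant index are placeholders, and the "whole access is zeroed" behavior is the strict buffer OOB mode described above):

```mlir
// %mem holds [f0, f1, ..., f7]; the base index is -1, so the 4xf32 access
// straddles the lower bound (partially out of bounds).
%i = arith.constant -1 : index
%v = vector.load %mem[%i] : memref<8xf32>, vector<4xf32>
// Under "any OOB element poisons the whole access": %v = <0, 0, 0, 0>, so %e is 0.
%e = vector.extract %v[1] : f32 from vector<4xf32>

// After the rewrite the access becomes a scalar load at index %i + 1 = 0, which
// is fully in bounds and returns f0: an observable difference.
%c1  = arith.constant 1 : index
%off = arith.addi %i, %c1 overflow<nsw> : index
%e2  = memref.load %mem[%off] : memref<8xf32>
```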

@kuhar
Member

kuhar commented Dec 11, 2025

I think this could be a path forward without restricting OOB behavior in the general case:

  1. Clarify the op semantics to say that partial OOB is implementation-defined
  2. Add a new inbounds attribute that introduces an assumption that there won't be any OOB access
  3. Enable this canon pattern in the presence of inbounds and take care of adjusting alignment attributes (if present) so that the backend can promote back to wider aligned accesses if it wants to

inbounds can either be set by an analysis or from a frontend.
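
A rough sketch of how items 2 and 3 could fit together (the `inbounds` spelling is hypothetical, no such attribute exists on vector.load today, and the value names are placeholders):

```mlir
// Hypothetical `inbounds` unit attribute: asserts that no element of the access
// is out of bounds, so the whole-access OOB concern does not apply here.
%v = vector.load %mem[%i] { inbounds } : memref<?xf32>, vector<4xf32>
%e = vector.extract %v[1] : f32 from vector<4xf32>

// With that guarantee, folding the extract into a narrower load (this PR's
// pattern) would be safe; per item 3, any alignment attribute would also need
// to be adjusted so the backend can still widen the access if profitable.
%c1  = arith.constant 1 : index
%off = arith.addi %i, %c1 overflow<nsw> : index
%e2  = memref.load %mem[%off] : memref<?xf32>
```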

@krzysz00
Contributor

I think the inbounds attribute is a good move in any case; I've had it on my long-tail "someone really ought to" list for a while because it'll allow us to emit `gep inbounds nuw` when lowering vector.load to LLVM.
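
For reference, a sketch of the LLVM-level payoff (assuming the usual lowering of a unit-stride vector.load to a plain vector load through a GEP; the flags below are what an inbounds guarantee would license, not what is emitted today):

```llvm
; Address of element %idx; `inbounds nuw` is justified only when the access is
; known to stay within the allocation (inbounds) and the offset arithmetic
; cannot wrap (nuw).
%ptr = getelementptr inbounds nuw float, ptr %base, i64 %idx
%v   = load <4 x float>, ptr %ptr, align 4
```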
