Conversation

@dchigarev
Contributor

This PR fixes the case where the source memref of a vector.transfer_read/write is not contiguous, which violates the semantics of the memref.collapse_shape op used in the lowering.

An example of a failing test
gpu.module @xevm_module {
gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> {
  %c0 = arith.constant 0.0 : f16
  %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
  %0 = vector.transfer_read %subview[%off2, %off2], %c0
    {in_bounds = [true]} : memref<256x256xf16, strided<[4096, 1], offset: ?>>, vector<8xf16>
  gpu.return %0 : vector<8xf16>
}
}

Fails with:

/home/user/llvm/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir:404:8: error: 'memref.collapse_shape' op invalid source layout map or collapsing non-contiguous dims
  %0 = vector.transfer_read %subview[%off2, %off2], %c0
       ^
/home/user/llvm/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir:404:8: note: see current operation: %8 = "memref.collapse_shape"(%2) <{reassociation = [[0, 1]]}> : (memref<256x256xf16, strided<[4096, 1], offset: ?>>) -> memref<65536xf16>

The suggestion was to replace memref.collapse_shape with memref.extract_aligned_pointer_as_index, which this PR does. Since extract_aligned_pointer_as_index applied to a subview returns the original base pointer without the subview offset, this PR also adds logic that folds the offset obtained from memref.extract_strided_metadata into the baseOffset calculation in computeOffsets.
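
In terms of the generated IR, the change for the subview example above is roughly the following (a sketch based on the new test checks; SSA names are illustrative):

// Before: collapsing the strided subview fails the memref.collapse_shape verifier.
%flat = memref.collapse_shape %subview [[0, 1]]
    : memref<256x256xf16, strided<[4096, 1], offset: ?>> into memref<65536xf16>

// After: take the raw base pointer; the subview offset (from
// memref.extract_strided_metadata) is added into the %offsets index vector.
%ptr = memref.extract_aligned_pointer_as_index %subview
    : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
%ptr_i64 = arith.index_cast %ptr : index to i64
%vec = xegpu.load %ptr_i64[%offsets], %mask : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>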

@github-actions

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository, in which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

Comment on lines 268 to +270
 static Value computeOffsets(VectorTransferOpInterface xferOp,
-                            PatternRewriter &rewriter,
-                            ArrayRef<Value> strides) {
+                            PatternRewriter &rewriter, ArrayRef<Value> strides,
+                            Value baseOffset) {
Contributor Author

computeOffsets now takes a baseOffset obtained from memref.extract_strided_metadata in computeMemrefMeta (formerly computeStrides).
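
The scalar base offset is now seeded with the memref's own offset instead of a fresh zero constant before the per-index contributions are added; for a 2-D source this looks roughly like the following (illustrative names, %memref_offset being the offset result of memref.extract_strided_metadata):

%p0   = arith.muli %idx0, %stride0 : index
%acc0 = arith.addi %memref_offset, %p0 : index
%p1   = arith.muli %idx1, %stride1 : index
%base = arith.addi %acc0, %p1 : index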

Comment on lines -318 to -319
-  Value baseOffset = nullptr;
-  if (!indices.empty()) {
Contributor Author

Made the branch unconditional since we always want to account for the baseOffset.

Signed-off-by: dchigarev <[email protected]>
+// handling both static and dynamic memrefs while applying permutation
+// transformations for XeGPU lowering.
+static std::pair<SmallVector<Value>, Value>
+computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
Contributor Author

The function that calls memref.extract_strided_metadata now also returns the memref's offset together with the strides.
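
For reference, the metadata op already produces that offset as its second result, e.g. for the subview test case (names are illustrative):

%base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview
    : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index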

Comment on lines +343 to +344
+static Value memrefToIndexPtr(VectorTransferOpInterface xferOp,
+                              PatternRewriter &rewriter) {
Contributor Author

memref.collapse_shape is replaced with memref.extract_aligned_pointer_as_index plus an arith.index_cast from index to i64.
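
In IR terms this is just the two-op sequence below (a sketch; %src stands for the transfer op's base memref):

%ptr = memref.extract_aligned_pointer_as_index %src : memref<8x16x32xf32> -> index
%ptr_i64 = arith.index_cast %ptr : index to i64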

@dchigarev
Contributor Author

@Jianhui-Li for review

@Jianhui-Li
Contributor

Overall looks good. Thanks!

Jianhui-Li marked this pull request as ready for review September 11, 2025 22:22
@llvmbot
Member

llvmbot commented Sep 11, 2025

@llvm/pr-subscribers-mlir

Author: Dmitry Chigarev (dchigarev)

Changes

This PR fixes the case where the source memref of a vector.transfer_read/write is not contiguous, which violates the semantics of the memref.collapse_shape op used in the lowering.

An example of a failing test

gpu.module @xevm_module {
gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> {
  %c0 = arith.constant 0.0 : f16
  %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
  %0 = vector.transfer_read %subview[%off2, %off2], %c0
    {in_bounds = [true]} : memref<256x256xf16, strided<[4096, 1], offset: ?>>, vector<8xf16>
  gpu.return %0 : vector<8xf16>
}
}

Fails with:

/home/user/llvm/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir:404:8: error: 'memref.collapse_shape' op invalid source layout map or collapsing non-contiguous dims
  %0 = vector.transfer_read %subview[%off2, %off2], %c0
       ^
/home/user/llvm/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir:404:8: note: see current operation: %8 = "memref.collapse_shape"(%2) <{reassociation = [[0, 1]]}> : (memref<256x256xf16, strided<[4096, 1], offset: ?>>) -> memref<65536xf16>


The suggestion was to replace memref.collapse_shape with memref.extract_aligned_pointer_as_index, which this PR does. Since extract_aligned_pointer_as_index applied to a subview returns the original base pointer without the subview offset, this PR also adds logic that folds the offset obtained from memref.extract_strided_metadata into the baseOffset calculation in computeOffsets.


Patch is 27.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/158126.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp (+52-59)
  • (modified) mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir (+61-16)
  • (modified) mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir (+61-11)
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 819c2e5973ffd..852c322cc6467 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -180,26 +180,31 @@ static void adjustStridesForPermutation(AffineMap permMap,
   strides = applyPermutation(strides, perms64);
 }
 
-// Computes memory strides for vector transfer operations, handling both
-// static and dynamic memrefs while applying permutation transformations
-// for XeGPU lowering.
-static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
-                                         PatternRewriter &rewriter) {
+// Computes memory strides and a memref offset for vector transfer operations,
+// handling both static and dynamic memrefs while applying permutation
+// transformations for XeGPU lowering.
+static std::pair<SmallVector<Value>, Value>
+computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
   SmallVector<Value> strides;
   Value baseMemref = xferOp.getBase();
   AffineMap permMap = xferOp.getPermutationMap();
   MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
 
   Location loc = xferOp.getLoc();
+  Value offsetVal = nullptr;
   if (memrefType.hasStaticShape()) {
     int64_t offset;
     SmallVector<int64_t> intStrides;
     if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
-      return {};
+      return {{}, offsetVal};
     // Wrap static strides as MLIR values
     for (int64_t s : intStrides)
       strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
-  } else {
+    if (!ShapedType::isDynamic(offset))
+      offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
+  }
+
+  if (strides.empty() || !offsetVal) {
     // For dynamic shape memref, use memref.extract_strided_metadata to get
     // stride values
     unsigned rank = memrefType.getRank();
@@ -220,11 +225,16 @@ static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
 
     auto meta = memref::ExtractStridedMetadataOp::create(
         rewriter, loc, resultTypes, baseMemref);
-    strides.append(meta.getStrides().begin(), meta.getStrides().end());
+
+    if (strides.empty())
+      strides.append(meta.getStrides().begin(), meta.getStrides().end());
+
+    if (!offsetVal)
+      offsetVal = meta.getOffset();
   }
   // Adjust strides according to the permutation map (e.g., for transpose)
   adjustStridesForPermutation(permMap, strides);
-  return strides;
+  return {strides, offsetVal};
 }
 
 // This function compute the vectors of localOffsets for scattered load/stores.
@@ -254,10 +264,10 @@ static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
 //   %23 = arith.add %20, %21
 //   %local_offsets = arith.add %22, %23
 //   %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
-//   %offsets =  orig_offset + local_offsets
+//   %offsets =  memref_offset + orig_offset + local_offsets
 static Value computeOffsets(VectorTransferOpInterface xferOp,
-                            PatternRewriter &rewriter,
-                            ArrayRef<Value> strides) {
+                            PatternRewriter &rewriter, ArrayRef<Value> strides,
+                            Value baseOffset) {
   Location loc = xferOp.getLoc();
   VectorType vectorType = xferOp.getVectorType();
   SmallVector<Value> indices(xferOp.getIndices().begin(),
@@ -315,51 +325,30 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
         arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
 
   // Compute base offset from transfer read indices
-  Value baseOffset = nullptr;
-  if (!indices.empty()) {
-    baseOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
-    for (size_t i = 0; i < indices.size(); ++i) {
-      Value strideVal = strides[i];
-      Value offsetContrib =
-          arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
-      baseOffset =
-          arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
-    }
-    // Broadcast base offset to match vector shape
-    Value bcastBase = vector::BroadcastOp::create(
-        rewriter, loc, fullIndexVectorType, baseOffset);
-    localOffsets =
-        arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
+  for (size_t i = 0; i < indices.size(); ++i) {
+    Value strideVal = strides[i];
+    Value offsetContrib =
+        arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
+    baseOffset =
+        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
   }
+  // Broadcast base offset to match vector shape
+  Value bcastBase = vector::BroadcastOp::create(
+      rewriter, loc, fullIndexVectorType, baseOffset);
+  localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
   return localOffsets;
 }
 
-// Collapse memref shape to 1D
-static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
-                                PatternRewriter &rewriter) {
+// Convert memref to i64 base pointer
+static Value memrefToIndexPtr(VectorTransferOpInterface xferOp,
+                              PatternRewriter &rewriter) {
   Location loc = xferOp.getLoc();
-
-  Value baseMemref = xferOp.getBase();
-  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
-  Type elementType = memrefType.getElementType();
-
-  // Compute the total number of elements in the memref
-  MemRefType flatMemrefType;
-  if (memrefType.hasStaticShape()) {
-    auto totalElements = memrefType.getNumElements();
-    flatMemrefType = MemRefType::get({totalElements}, elementType);
-  } else {
-    flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType);
-  }
-
-  SmallVector<ReassociationIndices> reassociation;
-  ReassociationIndices allDims =
-      llvm::to_vector(llvm::seq<int64_t>(0, memrefType.getRank()));
-  reassociation.push_back(allDims);
-
-  auto collapseOp = memref::CollapseShapeOp::create(
-      rewriter, loc, flatMemrefType, baseMemref, reassociation);
-  return collapseOp;
+  auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
+                      rewriter, loc, xferOp.getBase())
+                      .getResult();
+  return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
+                                    indexPtr)
+      .getResult();
 }
 
 static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
@@ -372,13 +361,14 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
   if (!memrefType)
     return rewriter.notifyMatchFailure(readOp, "Expected memref source");
 
-  SmallVector<Value> strides = computeStrides(readOp, rewriter);
-  if (strides.empty())
+  auto meta = computeMemrefMeta(readOp, rewriter);
+  if (meta.first.empty())
     return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");
 
-  Value localOffsets = computeOffsets(readOp, rewriter, strides);
+  Value localOffsets =
+      computeOffsets(readOp, rewriter, meta.first, meta.second);
 
-  Value flatMemref = collapseMemrefTo1D(readOp, rewriter);
+  Value flatMemref = memrefToIndexPtr(readOp, rewriter);
 
   Value mask = vector::ConstantMaskOp::create(
       rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
@@ -405,11 +395,14 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
   if (!memrefType)
     return rewriter.notifyMatchFailure(writeOp, "Expected memref source");
 
-  SmallVector<Value> strides = computeStrides(writeOp, rewriter);
+  auto meta = computeMemrefMeta(writeOp, rewriter);
+  if (meta.first.empty())
+    return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");
 
-  Value localOffsets = computeOffsets(writeOp, rewriter, strides);
+  Value localOffsets =
+      computeOffsets(writeOp, rewriter, meta.first, meta.second);
 
-  Value flatMemref = collapseMemrefTo1D(writeOp, rewriter);
+  Value flatMemref = memrefToIndexPtr(writeOp, rewriter);
 
   Value mask = vector::ConstantMaskOp::create(
       rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index b373bdab80567..c4ca79af1bd9a 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -27,8 +27,9 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8xindex>
 // LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
-// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
 
 }
 
@@ -62,8 +63,9 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER:        %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
 // LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}}: vector<8x16xindex>
-// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 
 }
 
@@ -124,8 +126,9 @@ gpu.func @load_transposed(%source: memref<32x64xf32>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER:        %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
 // LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
-// LOAD-GATHER:        %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf32> into memref<2048xf32>
-// LOAD-GATHER:        %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER:        %[[COLLAPSE:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<32x64xf32> -> index
+// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER:        %[[LOAD:.*]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 
 }
 
@@ -164,8 +167,9 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER:        %[[BROADIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
 // LOAD-GATHER:        %[[FINALIDX:.+]] = arith.addi %[[BROADIDX]], {{.*}} : vector<8x16xindex>
-// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
-// LOAD-GATHER:        %[[RES:.+]] = xegpu.load %[[COLLAPSE]][%[[FINALIDX]]], %[[CST]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<?x?x?xf32> -> index
+// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER:        %[[RES:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[FINALIDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 // LOAD-GATHER:        gpu.return %[[RES]] : vector<8x16xf32>
 }
 
@@ -195,8 +199,9 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER-DAG:    %[[BCASTIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
 // LOAD-GATHER-DAG:    %[[OFFSETS:.+]] = arith.addi %[[BCASTIDX]], {{.*}} : vector<8x16xindex>
-// LOAD-GATHER-DAG:    %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32>
-// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> 
+// LOAD-GATHER-DAG:    %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %arg0 : memref<?x8x16xf32> -> index
+// LOAD-GATHER-DAG:    %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> 
 
 }
 
@@ -224,8 +229,9 @@ gpu.func @load_dynamic_source3(%source: memref<?x?x?x?x?xf32>,
 // LOAD-GATHER-COUNT3: arith.addi {{.*}} : vector<2x4x8x16xindex>
 // LOAD-GATHER:        %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<2x4x8x16xindex>
 // LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<2x4x8x16xindex>
-// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2, 3, 4]{{\]}} : memref<?x?x?x?x?xf32> into memref<?xf32>
-// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<?xf32>, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
+// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?x?x?xf32> -> index
+// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
 // LOAD-GATHER:        return %[[VEC]]
 }
 
@@ -254,8 +260,9 @@ gpu.func @load_high_dim_vector(%source: memref<16x32x64xf32>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : vector<8x16x32xindex>
 // LOAD-GATHER:        %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
 // LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
-// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32>
-// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>
+// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %arg0 : memref<16x32x64xf32> -> index
+// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>
 
 }
 
@@ -283,8 +290,9 @@ gpu.func @load_transpose_f16(%source: memref<32x64xf16>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER:        %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
 // LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
-// LOAD-GATHER:        %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf16> into memref<2048xf16>
-// LOAD-GATHER:        %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf16>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
+// LOAD-GATHER:        %[[COLLAPSE:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<32x64xf16> -> index
+// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER:        %[[LOAD:.*]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
 }
 
 // -----
@@ -396,3 +404,40 @@ gpu.func @no_load_unsupported_map(%source: memref<16x32x64xf32>,
 // LOAD-GATHER:        vector.transfer_read
 }
 
+// -----
+gpu.module @xevm_module {
+gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> {
+  %c0 = arith.constant 0.0 : f16
+  %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+  %0 = vector.transfer_read %subview[%off2, %off2], %c0
+    {in_bounds = [true]} : memref<256x256xf16, strided<[4096, 1], offset: ?>>, vector<8xf16>
+  gpu.return %0 : vector<8xf16>
+}
+
+// LOAD-ND-LABEL:  @load_from_subview(
+// LOAD-ND-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
+// LOAD-ND-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> 
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
+// LOAD-ND-SAME:     %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]]
+// LOAD-ND-SAME:     memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
+// LOAD-ND-SAME:     boundary_check = false
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf16>
+// LOAD-ND:        return %[[VEC]]
+
+// LOAD-GATHER-LABEL:  @load_from_subview(
+// LOAD-GATHER-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
+// LOAD-GATHER-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
+// LOAD-GATHER:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> 
+// LOAD-GATHER:        %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
+// LOAD-GATHER:        %[[STEP:.+]] = vector.step : vector<8xindex>
+// LOAD-GATHER:        arith.muli {{.*}} : index
+// LOAD-GATHER:        arith.addi %[[OFFSET]]{{.*}} : index
+// LOAD-GATHER:        arith.addi {{.*}} : index
+// LOAD-GATHER:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8xindex>
+// LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
+// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
+// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>
+}
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index b3f761a545ee1..fcfc9414da4f6 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -30,8 +30,9 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>,
 // STORE-SCATTER-COUNT2: arith.addi {{.*}} : index
 // STORE-SCATTER-DAG:    %[[BCAST:.+]] = vector.broadcast {{.*}} : index to vector<8xindex>
 // STORE-SCATTER-DAG:    %[[IDX:.+]] = arith.addi %[[BCAST]], %{{.*}} : vector<8xindex>
-// STORE-SCATTER-DAG:    %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// STORE-SCATTER:       xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8xf32>, memref<4096xf32>, vector<8xindex>, vector<8xi1>
+// STORE-SCATTER-DAG:    %[[COLLAPSE:.+]] = memref.extract_aligned_point...
[truncated]

@llvmbot
Member

llvmbot commented Sep 11, 2025

@llvm/pr-subscribers-mlir-gpu

Jianhui-Li merged commit 40e85fc into llvm:main Sep 12, 2025
13 checks passed
@github-actions

@dchigarev Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!
