[MLIR][XeGPU][VectorToXeGPU] Fix transfer_read/write cases with non-contiguous memrefs #158126
Conversation
…ontiguous memrefs Signed-off-by: dchigarev <[email protected]>
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository, in which case you can instead tag reviewers by name in a comment. If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR with a comment saying "Ping". The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.
 static Value computeOffsets(VectorTransferOpInterface xferOp,
-                            PatternRewriter &rewriter,
-                            ArrayRef<Value> strides) {
+                            PatternRewriter &rewriter, ArrayRef<Value> strides,
+                            Value baseOffset) {
computeOffsets now takes a baseOffset obtained from memref.extract_strided_metadata in computeMemrefMeta (formerly computeStrides).
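In IR terms, the base offset now starts from the memref's own offset instead of from zero before the per-index contributions are added. A minimal sketch of the emitted arithmetic, with illustrative SSA names rather than the verbatim pass output:

%contrib    = arith.muli %idx0, %stride0 : index
%baseOffset = arith.addi %memrefOffset, %contrib : index
%bcast      = vector.broadcast %baseOffset : index to vector<8xindex>
%offsets    = arith.addi %bcast, %localOffsets : vector<8xindex>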
-  Value baseOffset = nullptr;
-  if (!indices.empty()) {
made the branch unconditional since we always want to consider the baseOffset
Signed-off-by: dchigarev <[email protected]>
+// handling both static and dynamic memrefs while applying permutation
+// transformations for XeGPU lowering.
+static std::pair<SmallVector<Value>, Value>
+computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
The function that calls memref.extract_strided_metadata now also returns the memref's offset together with the strides.
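For a strided subview this is the metadata query being relied on; a minimal sketch with illustrative names, matching the shape checked in the new test:

%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview
    : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index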
+static Value memrefToIndexPtr(VectorTransferOpInterface xferOp,
+                              PatternRewriter &rewriter) {
Replaced memref.collapse_shape with memref.extract_aligned_pointer_as_index + arith.index_cast (index -> i64).
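That is, roughly the following sequence (illustrative names; the types mirror the updated test checks):

%ptr   = memref.extract_aligned_pointer_as_index %src : memref<8x16x32xf32> -> index
%ptr64 = arith.index_cast %ptr : index to i64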
@Jianhui-Li for review

overall looks good. Thanks!
@llvm/pr-subscribers-mlir

Author: Dmitry Chigarev (dchigarev)

Changes

This PR fixes a case where a source memref in vector.transfer_read/write is not contiguous, which violates the memref.collapse_shape semantic that is used in the lowering.

An example of a failing test:

gpu.module @xevm_module {
gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> {
  %c0 = arith.constant 0.0 : f16
  %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
  %0 = vector.transfer_read %subview[%off2, %off2], %c0
    {in_bounds = [true]} : memref<256x256xf16, strided<[4096, 1], offset: ?>>, vector<8xf16>
  gpu.return %0 : vector<8xf16>
}
}

Fails with:

A suggestion was to replace memref.collapse_shape with memref.extract_aligned_pointer_as_index, which is done in this PR. Since extract_aligned_pointer applied to a subview returns the original pointer without subview offsets, this PR also adds logic to use the offset obtained from memref.extract_strided_metadata in the baseOffset calculation in computeOffsets.

Patch is 27.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/158126.diff

3 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 819c2e5973ffd..852c322cc6467 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -180,26 +180,31 @@ static void adjustStridesForPermutation(AffineMap permMap,
strides = applyPermutation(strides, perms64);
}
-// Computes memory strides for vector transfer operations, handling both
-// static and dynamic memrefs while applying permutation transformations
-// for XeGPU lowering.
-static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
- PatternRewriter &rewriter) {
+// Computes memory strides and a memref offset for vector transfer operations,
+// handling both static and dynamic memrefs while applying permutation
+// transformations for XeGPU lowering.
+static std::pair<SmallVector<Value>, Value>
+computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
SmallVector<Value> strides;
Value baseMemref = xferOp.getBase();
AffineMap permMap = xferOp.getPermutationMap();
MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
Location loc = xferOp.getLoc();
+ Value offsetVal = nullptr;
if (memrefType.hasStaticShape()) {
int64_t offset;
SmallVector<int64_t> intStrides;
if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
- return {};
+ return {{}, offsetVal};
// Wrap static strides as MLIR values
for (int64_t s : intStrides)
strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
- } else {
+ if (!ShapedType::isDynamic(offset))
+ offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
+ }
+
+ if (strides.empty() || !offsetVal) {
// For dynamic shape memref, use memref.extract_strided_metadata to get
// stride values
unsigned rank = memrefType.getRank();
@@ -220,11 +225,16 @@ static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
auto meta = memref::ExtractStridedMetadataOp::create(
rewriter, loc, resultTypes, baseMemref);
- strides.append(meta.getStrides().begin(), meta.getStrides().end());
+
+ if (strides.empty())
+ strides.append(meta.getStrides().begin(), meta.getStrides().end());
+
+ if (!offsetVal)
+ offsetVal = meta.getOffset();
}
// Adjust strides according to the permutation map (e.g., for transpose)
adjustStridesForPermutation(permMap, strides);
- return strides;
+ return {strides, offsetVal};
}
// This function compute the vectors of localOffsets for scattered load/stores.
@@ -254,10 +264,10 @@ static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
// %23 = arith.add %20, %21
// %local_offsets = arith.add %22, %23
// %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
-// %offsets = orig_offset + local_offsets
+// %offsets = memref_offset + orig_offset + local_offsets
static Value computeOffsets(VectorTransferOpInterface xferOp,
- PatternRewriter &rewriter,
- ArrayRef<Value> strides) {
+ PatternRewriter &rewriter, ArrayRef<Value> strides,
+ Value baseOffset) {
Location loc = xferOp.getLoc();
VectorType vectorType = xferOp.getVectorType();
SmallVector<Value> indices(xferOp.getIndices().begin(),
@@ -315,51 +325,30 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
// Compute base offset from transfer read indices
- Value baseOffset = nullptr;
- if (!indices.empty()) {
- baseOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
- for (size_t i = 0; i < indices.size(); ++i) {
- Value strideVal = strides[i];
- Value offsetContrib =
- arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
- baseOffset =
- arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
- }
- // Broadcast base offset to match vector shape
- Value bcastBase = vector::BroadcastOp::create(
- rewriter, loc, fullIndexVectorType, baseOffset);
- localOffsets =
- arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
+ for (size_t i = 0; i < indices.size(); ++i) {
+ Value strideVal = strides[i];
+ Value offsetContrib =
+ arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
+ baseOffset =
+ arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
}
+ // Broadcast base offset to match vector shape
+ Value bcastBase = vector::BroadcastOp::create(
+ rewriter, loc, fullIndexVectorType, baseOffset);
+ localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
return localOffsets;
}
-// Collapse memref shape to 1D
-static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
- PatternRewriter &rewriter) {
+// Convert memref to i64 base pointer
+static Value memrefToIndexPtr(VectorTransferOpInterface xferOp,
+ PatternRewriter &rewriter) {
Location loc = xferOp.getLoc();
-
- Value baseMemref = xferOp.getBase();
- MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
- Type elementType = memrefType.getElementType();
-
- // Compute the total number of elements in the memref
- MemRefType flatMemrefType;
- if (memrefType.hasStaticShape()) {
- auto totalElements = memrefType.getNumElements();
- flatMemrefType = MemRefType::get({totalElements}, elementType);
- } else {
- flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType);
- }
-
- SmallVector<ReassociationIndices> reassociation;
- ReassociationIndices allDims =
- llvm::to_vector(llvm::seq<int64_t>(0, memrefType.getRank()));
- reassociation.push_back(allDims);
-
- auto collapseOp = memref::CollapseShapeOp::create(
- rewriter, loc, flatMemrefType, baseMemref, reassociation);
- return collapseOp;
+ auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
+ rewriter, loc, xferOp.getBase())
+ .getResult();
+ return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
+ indexPtr)
+ .getResult();
}
static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
@@ -372,13 +361,14 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
if (!memrefType)
return rewriter.notifyMatchFailure(readOp, "Expected memref source");
- SmallVector<Value> strides = computeStrides(readOp, rewriter);
- if (strides.empty())
+ auto meta = computeMemrefMeta(readOp, rewriter);
+ if (meta.first.empty())
return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");
- Value localOffsets = computeOffsets(readOp, rewriter, strides);
+ Value localOffsets =
+ computeOffsets(readOp, rewriter, meta.first, meta.second);
- Value flatMemref = collapseMemrefTo1D(readOp, rewriter);
+ Value flatMemref = memrefToIndexPtr(readOp, rewriter);
Value mask = vector::ConstantMaskOp::create(
rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
@@ -405,11 +395,14 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
if (!memrefType)
return rewriter.notifyMatchFailure(writeOp, "Expected memref source");
- SmallVector<Value> strides = computeStrides(writeOp, rewriter);
+ auto meta = computeMemrefMeta(writeOp, rewriter);
+ if (meta.first.empty())
+ return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");
- Value localOffsets = computeOffsets(writeOp, rewriter, strides);
+ Value localOffsets =
+ computeOffsets(writeOp, rewriter, meta.first, meta.second);
- Value flatMemref = collapseMemrefTo1D(writeOp, rewriter);
+ Value flatMemref = memrefToIndexPtr(writeOp, rewriter);
Value mask = vector::ConstantMaskOp::create(
rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index b373bdab80567..c4ca79af1bd9a 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -27,8 +27,9 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
-// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
}
@@ -62,8 +63,9 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}}: vector<8x16xindex>
-// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
}
@@ -124,8 +126,9 @@ gpu.func @load_transposed(%source: memref<32x64xf32>,
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
-// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf32> into memref<2048xf32>
-// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<32x64xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
}
@@ -164,8 +167,9 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER: %[[BROADIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// LOAD-GATHER: %[[FINALIDX:.+]] = arith.addi %[[BROADIDX]], {{.*}} : vector<8x16xindex>
-// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
-// LOAD-GATHER: %[[RES:.+]] = xegpu.load %[[COLLAPSE]][%[[FINALIDX]]], %[[CST]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<?x?x?xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[RES:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[FINALIDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
// LOAD-GATHER: gpu.return %[[RES]] : vector<8x16xf32>
}
@@ -195,8 +199,9 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER-DAG: %[[BCASTIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// LOAD-GATHER-DAG: %[[OFFSETS:.+]] = arith.addi %[[BCASTIDX]], {{.*}} : vector<8x16xindex>
-// LOAD-GATHER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32>
-// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER-DAG: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %arg0 : memref<?x8x16xf32> -> index
+// LOAD-GATHER-DAG: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
}
@@ -224,8 +229,9 @@ gpu.func @load_dynamic_source3(%source: memref<?x?x?x?x?xf32>,
// LOAD-GATHER-COUNT3: arith.addi {{.*}} : vector<2x4x8x16xindex>
// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<2x4x8x16xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<2x4x8x16xindex>
-// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2, 3, 4]{{\]}} : memref<?x?x?x?x?xf32> into memref<?xf32>
-// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<?xf32>, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?x?x?xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
// LOAD-GATHER: return %[[VEC]]
}
@@ -254,8 +260,9 @@ gpu.func @load_high_dim_vector(%source: memref<16x32x64xf32>,
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : vector<8x16x32xindex>
// LOAD-GATHER: %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
-// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32>
-// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %arg0 : memref<16x32x64xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>
}
@@ -283,8 +290,9 @@ gpu.func @load_transpose_f16(%source: memref<32x64xf16>,
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
-// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf16> into memref<2048xf16>
-// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf16>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
+// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<32x64xf16> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
}
// -----
@@ -396,3 +404,40 @@ gpu.func @no_load_unsupported_map(%source: memref<16x32x64xf32>,
// LOAD-GATHER: vector.transfer_read
}
+// -----
+gpu.module @xevm_module {
+gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> {
+ %c0 = arith.constant 0.0 : f16
+ %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+ %0 = vector.transfer_read %subview[%off2, %off2], %c0
+ {in_bounds = [true]} : memref<256x256xf16, strided<[4096, 1], offset: ?>>, vector<8xf16>
+ gpu.return %0 : vector<8xf16>
+}
+
+// LOAD-ND-LABEL: @load_from_subview(
+// LOAD-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
+// LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
+// LOAD-ND-SAME: %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]]
+// LOAD-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
+// LOAD-ND-SAME: boundary_check = false
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf16>
+// LOAD-ND: return %[[VEC]]
+
+// LOAD-GATHER-LABEL: @load_from_subview(
+// LOAD-GATHER-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
+// LOAD-GATHER-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
+// LOAD-GATHER: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// LOAD-GATHER: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
+// LOAD-GATHER: %[[STEP:.+]] = vector.step : vector<8xindex>
+// LOAD-GATHER: arith.muli {{.*}} : index
+// LOAD-GATHER: arith.addi %[[OFFSET]]{{.*}} : index
+// LOAD-GATHER: arith.addi {{.*}} : index
+// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
+// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>
+}
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index b3f761a545ee1..fcfc9414da4f6 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -30,8 +30,9 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>,
// STORE-SCATTER-COUNT2: arith.addi {{.*}} : index
// STORE-SCATTER-DAG: %[[BCAST:.+]] = vector.broadcast {{.*}} : index to vector<8xindex>
// STORE-SCATTER-DAG: %[[IDX:.+]] = arith.addi %[[BCAST]], %{{.*}} : vector<8xindex>
-// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8xf32>, memref<4096xf32>, vector<8xindex>, vector<8xi1>
+// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.extract_aligned_point...
[truncated]
@dchigarev Congratulations on having your first Pull Request (PR) merged into the LLVM Project! Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR. Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues. How to do this, and the rest of the post-merge process, is covered in detail here. If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again. If you don't get any reports, no action is required from you. Your changes are working as expected, well done!
This PR fixes a case where a source memref in vector.transfer_read/write is not contiguous, which violates the memref.collapse_shape semantic that is used in the lowering.

An example of a failing test:

gpu.module @xevm_module {
gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> {
  %c0 = arith.constant 0.0 : f16
  %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
  %0 = vector.transfer_read %subview[%off2, %off2], %c0
    {in_bounds = [true]} : memref<256x256xf16, strided<[4096, 1], offset: ?>>, vector<8xf16>
  gpu.return %0 : vector<8xf16>
}
}

Fails with:

A suggestion was to replace memref.collapse_shape with memref.extract_aligned_pointer_as_index, which is done in this PR. Since extract_aligned_pointer applied to a subview returns the original pointer without subview offsets, this PR also adds logic to use the offset obtained from memref.extract_strided_metadata in the baseOffset calculation in computeOffsets.
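For the subview example above, the scattered (LOAD-GATHER) lowering now produces roughly the sequence below. This is a sketch assembled from the test's CHECK lines; SSA names, the constant folding of the static strides, and the exact order of the scalar arithmetic are illustrative rather than the verbatim pass output.

%mask  = arith.constant dense<true> : vector<8xi1>
%c4096 = arith.constant 4096 : index
%subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1]
    : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
// The subview's dynamic offset is recovered from the strided metadata.
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview
    : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
%step  = vector.step : vector<8xindex>
// baseOffset = memref offset + transfer indices * strides (the unit column stride is elided here) ...
%mul   = arith.muli %off2, %c4096 : index
%tmp   = arith.addi %offset, %mul : index
%lin   = arith.addi %tmp, %off2 : index
// ... then broadcast and added to the per-lane offsets.
%bcast = vector.broadcast %lin : index to vector<8xindex>
%idx   = arith.addi %bcast, %step : vector<8xindex>
// The flat base is now an i64 pointer rather than a collapsed 1-D memref.
%ptr   = memref.extract_aligned_pointer_as_index %subview
    : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
%ptr64 = arith.index_cast %ptr : index to i64
%vec   = xegpu.load %ptr64[%idx], %mask : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>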