[MLIR][XeGPU][VectorToXeGPU] Fix transfer_read/write cases with non-contiguous memrefs #158126
Changes from 1 commit
```diff
@@ -183,23 +183,28 @@ static void adjustStridesForPermutation(AffineMap permMap,
 // Computes memory strides for vector transfer operations, handling both
 // static and dynamic memrefs while applying permutation transformations
 // for XeGPU lowering.
-static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
-                                         PatternRewriter &rewriter) {
+static std::pair<SmallVector<Value>, Value>
+computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
   SmallVector<Value> strides;
   Value baseMemref = xferOp.getBase();
   AffineMap permMap = xferOp.getPermutationMap();
   MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
 
   Location loc = xferOp.getLoc();
+  Value offsetVal = nullptr;
   if (memrefType.hasStaticShape()) {
     int64_t offset;
     SmallVector<int64_t> intStrides;
     if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
-      return {};
+      return {{}, offsetVal};
     // Wrap static strides as MLIR values
     for (int64_t s : intStrides)
       strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
-  } else {
+    if (!ShapedType::isDynamic(offset))
+      offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
+  }
+
+  if (strides.empty() || !offsetVal) {
     // For dynamic shape memref, use memref.extract_strided_metadata to get
     // stride values
     unsigned rank = memrefType.getRank();
```
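For statically shaped memrefs, the stride/offset pair comes straight from the layout. As a minimal illustration (shapes and values chosen for the example, not taken from the patch), a `memref.subview` can produce a non-contiguous result type whose strides and offset are compile-time constants:

```mlir
// getStridesAndOffset() on the result type below returns strides = [16, 1]
// and offset = 42 (= 2 * 16 + 10), so computeMemrefMeta can materialize
// both as arith.constant index values.
%view = memref.subview %src[2, 10] [4, 8] [1, 1]
    : memref<8x16xf32> to memref<4x8xf32, strided<[16, 1], offset: 42>>
```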
```diff
@@ -220,11 +225,16 @@ static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
     auto meta = memref::ExtractStridedMetadataOp::create(
         rewriter, loc, resultTypes, baseMemref);
-    strides.append(meta.getStrides().begin(), meta.getStrides().end());
+
+    if (strides.empty())
+      strides.append(meta.getStrides().begin(), meta.getStrides().end());
+
+    if (!offsetVal)
+      offsetVal = meta.getOffset();
   }
   // Adjust strides according to the permutation map (e.g., for transpose)
   adjustStridesForPermutation(permMap, strides);
-  return strides;
+  return {strides, offsetVal};
 }
 
 // This function compute the vectors of localOffsets for scattered load/stores.
```
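When the strides or the offset are dynamic, the helper falls back to `memref.extract_strided_metadata`, whose offset result it now consumes as well. A sketch of that query for a fully dynamic 2-D layout (SSA names hypothetical):

```mlir
// Yields the base buffer plus the runtime offset, sizes, and strides of %m.
%base, %offset, %sizes:2, %strides:2 =
    memref.extract_strided_metadata %m
    : memref<?x?xf32, strided<[?, ?], offset: ?>>
    -> memref<f32>, index, index, index, index, index
```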
```diff
@@ -256,8 +266,8 @@ static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
 // %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
 // %offsets = orig_offset + local_offsets
 static Value computeOffsets(VectorTransferOpInterface xferOp,
-                            PatternRewriter &rewriter,
-                            ArrayRef<Value> strides) {
+                            PatternRewriter &rewriter, ArrayRef<Value> strides,
+                            Value baseOffset) {
   Location loc = xferOp.getLoc();
   VectorType vectorType = xferOp.getVectorType();
   SmallVector<Value> indices(xferOp.getIndices().begin(),
```
```diff
@@ -315,51 +325,30 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
       arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
 
-  // Compute base offset from transfer read indices
-  Value baseOffset = nullptr;
-  if (!indices.empty()) {
-    baseOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
-    for (size_t i = 0; i < indices.size(); ++i) {
-      Value strideVal = strides[i];
-      Value offsetContrib =
-          arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
-      baseOffset =
-          arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
-    }
-    // Broadcast base offset to match vector shape
-    Value bcastBase = vector::BroadcastOp::create(
-        rewriter, loc, fullIndexVectorType, baseOffset);
-    localOffsets =
-        arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
-  }
+  for (size_t i = 0; i < indices.size(); ++i) {
+    Value strideVal = strides[i];
+    Value offsetContrib =
+        arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
+    baseOffset =
+        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
+  }
+  // Broadcast base offset to match vector shape
+  Value bcastBase = vector::BroadcastOp::create(
+      rewriter, loc, fullIndexVectorType, baseOffset);
+  localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
   return localOffsets;
 }
```

> **Contributor (Author)**, on lines -318 to -319: made the branch unconditional since we always want to consider the …
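With `baseOffset` seeded from the memref's layout offset instead of a fresh zero constant, the accumulation above lowers to plain index arithmetic. A rough sketch of the kind of IR this produces for a 2-D transfer (names, shapes, and op order hypothetical):

```mlir
// baseOffset starts at the memref's own offset %off, then adds
// index * stride for each dimension of the access.
%m0   = arith.muli %i, %stride0 : index
%a0   = arith.addi %off, %m0 : index
%m1   = arith.muli %j, %stride1 : index
%base = arith.addi %a0, %m1 : index
// Broadcast the scalar base and add it to the per-lane local offsets.
%bc   = vector.broadcast %base : index to vector<8x16xindex>
%offs = arith.addi %bc, %local : vector<8x16xindex>
```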
```diff
-// Collapse memref shape to 1D
-static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
-                                PatternRewriter &rewriter) {
+static Value memrefToIndexPtr(VectorTransferOpInterface xferOp,
+                              PatternRewriter &rewriter) {
   Location loc = xferOp.getLoc();
 
-  Value baseMemref = xferOp.getBase();
-  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
-  Type elementType = memrefType.getElementType();
-
-  // Compute the total number of elements in the memref
-  MemRefType flatMemrefType;
-  if (memrefType.hasStaticShape()) {
-    auto totalElements = memrefType.getNumElements();
-    flatMemrefType = MemRefType::get({totalElements}, elementType);
-  } else {
-    flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType);
-  }
-
-  SmallVector<ReassociationIndices> reassociation;
-  ReassociationIndices allDims =
-      llvm::to_vector(llvm::seq<int64_t>(0, memrefType.getRank()));
-  reassociation.push_back(allDims);
-
-  auto collapseOp = memref::CollapseShapeOp::create(
-      rewriter, loc, flatMemrefType, baseMemref, reassociation);
-  return collapseOp;
+  auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
+                      rewriter, loc, xferOp.getBase())
+                      .getResult();
+  return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
+                                    indexPtr)
+      .getResult();
 }
 
 static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
```
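`memref.collapse_shape` requires the collapsed dimensions to be contiguous, which is exactly what non-contiguous sources violate; taking the aligned base pointer sidesteps that restriction. A sketch of the two ops the new helper emits, assuming a strided 2-D source:

```mlir
// Valid for any strided layout; no contiguity requirement.
%p   = memref.extract_aligned_pointer_as_index %m
    : memref<?x?xf32, strided<[?, ?], offset: ?>> -> index
%p64 = arith.index_cast %p : index to i64
```

Note that the aligned pointer does not account for the layout offset, which is why the offset returned by `computeMemrefMeta` has to be folded into the computed offsets above.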
```diff
@@ -372,13 +361,14 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
   if (!memrefType)
     return rewriter.notifyMatchFailure(readOp, "Expected memref source");
 
-  SmallVector<Value> strides = computeStrides(readOp, rewriter);
-  if (strides.empty())
+  auto meta = computeMemrefMeta(readOp, rewriter);
+  if (meta.first.empty())
     return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");
 
-  Value localOffsets = computeOffsets(readOp, rewriter, strides);
+  Value localOffsets =
+      computeOffsets(readOp, rewriter, meta.first, meta.second);
 
-  Value flatMemref = collapseMemrefTo1D(readOp, rewriter);
+  Value flatMemref = memrefToIndexPtr(readOp, rewriter);
 
   Value mask = vector::ConstantMaskOp::create(
       rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
```
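Taken together, the read path can now handle sources like the following, where a subview leaves the memref non-contiguous (a hedged illustration; shapes and names made up, not from the patch's tests):

```mlir
%c0  = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
// Row stride 32 but only 16 elements per row in view: not collapsible to 1-D.
%sv = memref.subview %src[0, 0] [8, 16] [1, 1]
    : memref<32x32xf32> to memref<8x16xf32, strided<[32, 1]>>
%v = vector.transfer_read %sv[%c0, %c0], %pad {in_bounds = [true, true]}
    : memref<8x16xf32, strided<[32, 1]>>, vector<8x16xf32>
```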
```diff
@@ -405,11 +395,14 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
   if (!memrefType)
     return rewriter.notifyMatchFailure(writeOp, "Expected memref source");
 
-  SmallVector<Value> strides = computeStrides(writeOp, rewriter);
+  auto meta = computeMemrefMeta(writeOp, rewriter);
+  if (meta.first.empty())
+    return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");
 
-  Value localOffsets = computeOffsets(writeOp, rewriter, strides);
+  Value localOffsets =
+      computeOffsets(writeOp, rewriter, meta.first, meta.second);
 
-  Value flatMemref = collapseMemrefTo1D(writeOp, rewriter);
+  Value flatMemref = memrefToIndexPtr(writeOp, rewriter);
 
   Value mask = vector::ConstantMaskOp::create(
       rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
```
> A function that calls `memref.extract_strided_metadata` now also returns the memref's offset together with the strides.