-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[MLIR][XeGPU] Distribute load_nd/store_nd/prefetch_nd with offsets from Wg to Sg #153432
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 14 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
fe75a08
Add create_nd_desc pattern without offset
nbpatel 14d6fb1
Merge branch 'main' into xegpu-create_nd_no_offset
nbpatel d43e55e
Add newfile
nbpatel 036bc99
Newline
nbpatel e219003
Newline
nbpatel 9e799d6
Fix builders
nbpatel fcbdb91
Add pattern for load/store/prefetch nd with offsets
nbpatel 7783591
Merge branch 'main' into xegpu-create_nd_no_offset-backup
nbpatel 639e997
Add tests
nbpatel 35bdf57
Refactor
nbpatel a1b35a4
Add more tests
nbpatel 45e56ff
Address feedback
nbpatel bbd38af
Merge branch 'main' into xegpu-load-nd-offset
nbpatel 7d3dde7
change variable name
nbpatel df3a466
feedback
nbpatel 1991a65
Merge branch 'main' into xegpu-load-nd-offset
nbpatel File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -296,6 +296,208 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> { | |
} | ||
}; | ||
|
||
// Utility function to compute global offsets for subgroup operations.
//
// For every subgroup-relative offset list in `sgOffsetsList`, each entry is
// added to the matching workgroup offset and the sums are returned as a new
// list. The two lists are right-aligned: when `wgOffsets` holds more entries
// than a subgroup list, only its trailing entries take part in the addition
// and the leading ones do not appear in the result.
//
// `op` is used only for its location. The additions (and any index constants
// needed to materialize attribute offsets) are created through `rewriter`.
static SmallVector<SmallVector<OpFoldResult>>
computeGlobalOffsets(Operation *op, ArrayRef<SmallVector<Value>> sgOffsetsList,
                     ArrayRef<OpFoldResult> wgOffsets,
                     ConversionPatternRewriter &rewriter) {
  SmallVector<SmallVector<OpFoldResult>> globalOffsets;
  globalOffsets.reserve(sgOffsetsList.size());
  Location loc = op->getLoc();
  for (const SmallVector<Value> &sgOffsets : sgOffsetsList) {
    size_t rank = sgOffsets.size();
    // `wgOffsets.size() - rank` is an unsigned subtraction; without this guard
    // it would wrap around and index far past the end of `wgOffsets`.
    assert(wgOffsets.size() >= rank &&
           "expected at least as many workgroup offsets as subgroup offsets");
    // Index of the first workgroup offset that lines up with sgOffsets[0].
    size_t firstWgIdx = wgOffsets.size() - rank;

    SmallVector<OpFoldResult> newOffsets;
    newOffsets.reserve(rank);
    for (size_t i = 0; i < rank; i++) {
      Value add = rewriter.createOrFold<index::AddOp>(
          loc, sgOffsets[i],
          getValueOrCreateConstantIndexOp(rewriter, loc,
                                          wgOffsets[firstWgIdx + i]));
      newOffsets.push_back(add);
    }
    globalOffsets.push_back(std::move(newOffsets));
  }
  return globalOffsets;
}
|
||
// Utility function to get sgShape, sgOffsetList for a given | ||
// op. | ||
template <typename OpTy, typename AdaptorTy> | ||
LogicalResult getSgOffsets(OpTy op, AdaptorTy adaptor, | ||
ConversionPatternRewriter &rewriter, | ||
SmallVector<int64_t> &sgShape, | ||
SmallVector<SmallVector<Value>> &sgOffsetList) { | ||
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size()); | ||
if (offsetSize == 0 && (!op.getConstOffsetsAttr())) | ||
return failure(); | ||
|
||
Location loc = op.getLoc(); | ||
Value tdesc = op.getTensorDesc(); | ||
auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType()); | ||
if (!tdescTy) | ||
return failure(); | ||
auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout()); | ||
if (!layout) | ||
return failure(); | ||
|
||
SmallVector<int64_t> sgLayout; | ||
if (auto sgLayoutAttr = layout.getSgLayout()) | ||
sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef()); | ||
else | ||
return rewriter.notifyMatchFailure( | ||
op, "sgLayout attribute is required in layout"); | ||
|
||
ArrayRef<int64_t> wgShape = tdescTy.getShape(); | ||
int count; | ||
std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout); | ||
|
||
// Get the subgroup ID | ||
Value linearSgId = | ||
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); | ||
|
||
int64_t startOfRange = -1, endOfRange = -1; | ||
bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange); | ||
|
||
if (sgIdRangeSpecified) { | ||
int64_t sgCount = endOfRange - startOfRange; | ||
if (computeProduct(sgLayout) != sgCount) | ||
return rewriter.notifyMatchFailure( | ||
op, "sg_layout size must match the sg_id_range"); | ||
Value startOfRangeVal = | ||
rewriter.create<arith::ConstantIndexOp>(loc, startOfRange); | ||
linearSgId = | ||
rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal); | ||
} | ||
|
||
auto sgOffsets = layout.getOffsets(rewriter, loc, linearSgId, wgShape); | ||
if (failed(sgOffsets)) | ||
return failure(); | ||
|
||
sgOffsetList = *sgOffsets; | ||
return success(); | ||
} | ||
|
||
template <typename OpTy> | ||
SmallVector<OpFoldResult> getWgOffsets(OpTy op, | ||
ConversionPatternRewriter &rewriter) { | ||
SmallVector<OpFoldResult> wgOffsets; | ||
if (auto constOffsets = op.getConstOffsetsAttr()) { | ||
for (auto attr : constOffsets.asArrayRef()) | ||
wgOffsets.push_back(rewriter.getIndexAttr(attr)); | ||
} | ||
for (auto v : op.getOffsets()) | ||
wgOffsets.push_back(v); | ||
return wgOffsets; | ||
} | ||
|
||
// This pattern transforms the LoadNdOp with explicit offsets to load
// subgroup data.
//
// One subgroup-level LoadNdOp is created per (global offset list, converted
// tensor descriptor) pair; each result is a vector of the subgroup shape with
// the descriptor's element type. The cache hints of the original op are
// forwarded, while the packed and transpose attributes are not set on the new
// ops. The original op is replaced by the list of new results.
struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    SmallVector<int64_t> sgShape;
    SmallVector<SmallVector<Value>> sgOffsetList;

    // Do the distribution from workgroup to subgroup and get subgroup offsets
    if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
      return failure();

    // Get the original workgroup offsets
    SmallVector<OpFoldResult> wgOffsets = getWgOffsets(op, rewriter);

    // Calculate the global offsets
    auto globalOffsets =
        computeGlobalOffsets(op, sgOffsetList, wgOffsets, rewriter);

    SmallVector<Value> newLoadOps;
    for (auto [offsets, tdesc] :
         llvm::zip(globalOffsets, adaptor.getTensorDesc())) {
      // The result is dereferenced unconditionally, so use cast<>: a
      // descriptor of any other type asserts here instead of yielding a null
      // type that is then used.
      auto tdescTy = cast<xegpu::TensorDescType>(tdesc.getType());
      VectorType newResTy = VectorType::get(sgShape, tdescTy.getElementType());
      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
          op.getLoc(), newResTy, tdesc, offsets,
          /*packed=*/nullptr,
          /*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
          op.getL3HintAttr());
      newLoadOps.push_back(newLoadOp);
    }
    rewriter.replaceOpWithMultiple(op, {newLoadOps});
    return success();
  }
};
|
||
// This pattern transforms the StoreNdOp with explicit offsets to store | ||
// subgroup data. | ||
struct WgToSgStoreNdOpWithOffset | ||
: public OpConversionPattern<xegpu::StoreNdOp> { | ||
using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern; | ||
LogicalResult | ||
matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
|
||
SmallVector<int64_t> sgShape; | ||
SmallVector<SmallVector<Value>> sgOffsetList; | ||
|
||
// Do the distribution from workgroup to subgroup and get subgroup offsets | ||
if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList))) | ||
return failure(); | ||
|
||
// Get the original workgroup offsets | ||
SmallVector<OpFoldResult> wgOffsets = getWgOffsets(op, rewriter); | ||
|
||
// Calculate the global offsets | ||
auto globalOffsets = | ||
computeGlobalOffsets(op, sgOffsetList, wgOffsets, rewriter); | ||
|
||
for (auto [offsets, tdesc, value] : llvm::zip( | ||
globalOffsets, adaptor.getTensorDesc(), adaptor.getValue())) { | ||
rewriter.create<xegpu::StoreNdOp>(op.getLoc(), value, tdesc, offsets, | ||
op.getL1HintAttr(), op.getL2HintAttr(), | ||
op.getL3HintAttr()); | ||
} | ||
rewriter.eraseOp(op); | ||
return success(); | ||
} | ||
}; | ||
|
||
// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch | ||
// subgroup data. | ||
struct WgToSgPrefetchNdOpWithOffset | ||
: public OpConversionPattern<xegpu::PrefetchNdOp> { | ||
using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern; | ||
LogicalResult | ||
matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
|
||
SmallVector<int64_t> sgShape; | ||
SmallVector<SmallVector<Value>> sgOffsetList; | ||
|
||
// Do the distribution from workgroup to subgroup and get subgroup offsets | ||
if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList))) | ||
return failure(); | ||
|
||
// Get the original workgroup offsets | ||
SmallVector<OpFoldResult> wgOffsets = getWgOffsets(op, rewriter); | ||
|
||
// calculate the global offsets | ||
auto globalOffsets = | ||
computeGlobalOffsets(op, sgOffsetList, wgOffsets, rewriter); | ||
|
||
for (auto [offsets, tdesc] : | ||
llvm::zip(globalOffsets, adaptor.getTensorDesc())) { | ||
rewriter.create<xegpu::PrefetchNdOp>( | ||
op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(), | ||
op.getL3HintAttr()); | ||
} | ||
rewriter.eraseOp(op); | ||
return success(); | ||
} | ||
}; | ||
|
||
/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a | ||
/// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the | ||
/// offsets of the new subgroup src tensor descriptors. | ||
|
@@ -690,12 +892,13 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { | |
namespace mlir { | ||
namespace xegpu { | ||
/// Registers the workgroup-to-subgroup distribution patterns with `patterns`.
/// Each pattern is added exactly once, including the variants that handle
/// load_nd/store_nd/prefetch_nd ops with explicit offsets.
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
  patterns
      .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
           WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
           WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
           WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
           WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
           WgToSgArithConstantOp>(patterns.getContext());
}
} // namespace xegpu | ||
} // namespace mlir | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.