[mlir][XeGPU][VectorToXeGPU] Propagate vector layouts to xegpu ops #163071
```diff
@@ -97,6 +97,21 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   return success();
 }
 
+// Extract cache hints from the op attributes if available.
+static SmallVector<xegpu::CachePolicyAttr, 3> getOpCacheHints(Operation *op) {
+  SmallVector<xegpu::CachePolicyAttr, 3> cacheHints{xegpu::CachePolicyAttr{},
+                                                    xegpu::CachePolicyAttr{},
+                                                    xegpu::CachePolicyAttr{}};
+  // get l1, l2, l3 hints from attributes if available.
+  if (auto l1Attr = op->getAttrOfType<xegpu::CachePolicyAttr>("l1_hint"))
+    cacheHints[0] = l1Attr;
+  if (auto l2Attr = op->getAttrOfType<xegpu::CachePolicyAttr>("l2_hint"))
+    cacheHints[1] = l2Attr;
+  if (auto l3Attr = op->getAttrOfType<xegpu::CachePolicyAttr>("l3_hint"))
+    cacheHints[2] = l3Attr;
+  return cacheHints;
+}
+
 static xegpu::CreateNdDescOp
 createNdDescriptor(PatternRewriter &rewriter, Location loc,
                    xegpu::TensorDescType descType, TypedValue<MemRefType> src,
```
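For context, a minimal sketch (not part of this PR) of how a producer pass could attach the cache-hint attributes that `getOpCacheHints` queries. The helper name is hypothetical; the sketch assumes the usual generated enum-attribute getter `xegpu::CachePolicyAttr::get` and the `CachePolicy` enum cases, and uses the same `"l1_hint"`/`"l2_hint"`/`"l3_hint"` keys read above.

```cpp
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"

using namespace mlir;

// Hypothetical annotator: attach cache hints to a vector op so that
// getOpCacheHints() above picks them up during VectorToXeGPU conversion.
static void annotateCacheHints(Operation *op) {
  MLIRContext *ctx = op->getContext();
  auto cached = xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
  auto uncached = xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::UNCACHED);
  op->setAttr("l1_hint", cached);   // cache in L1
  op->setAttr("l2_hint", cached);   // cache in L2
  op->setAttr("l3_hint", uncached); // bypass L3
}
```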
```diff
@@ -430,12 +445,16 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
   Value mask = vector::ConstantMaskOp::create(
       rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
       vectorShape);
-  auto gatherOp = xegpu::LoadGatherOp::create(
-      rewriter, loc, vectorType, flatMemref, localOffsets, mask,
-      /*chunk_size=*/IntegerAttr{},
-      /*l1_hint=*/xegpu::CachePolicyAttr{},
-      /*l2_hint=*/xegpu::CachePolicyAttr{},
-      /*l3_hint=*/xegpu::CachePolicyAttr{});
+  SmallVector<xegpu::CachePolicyAttr, 3> cacheHints = getOpCacheHints(readOp);
+  auto gatherOp = xegpu::LoadGatherOp::create(rewriter, loc, vectorType,
+                                              flatMemref, localOffsets, mask,
+                                              /*chunk_size=*/IntegerAttr{},
+                                              /*l1_hint=*/cacheHints[0],
+                                              /*l2_hint=*/cacheHints[1],
+                                              /*l3_hint=*/cacheHints[2]);
+  auto resLayout = xegpu::getDistributeLayoutAttr(readOp.getResult());
+  xegpu::setDistributeLayoutAttrs(gatherOp,
+                                  [&](Value val) { return resLayout; });
 
   rewriter.replaceOp(readOp, gatherOp.getResult());
   return success();
```
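A side note on the bulk setter used in this hunk: the lambda suggests a callback of shape `Value -> DistributeLayoutAttr`, queried per value attached to the op. Under that assumption, a non-uniform assignment is also expressible; the map and its population below are hypothetical.

```cpp
// Hypothetical non-uniform variant of the setDistributeLayoutAttrs call above.
// Assumes the callback is invoked once per operand/result value of the op.
llvm::DenseMap<mlir::Value, mlir::xegpu::DistributeLayoutAttr> layouts;
// ... populated by some layout analysis (hypothetical) ...
mlir::xegpu::setDistributeLayoutAttrs(gatherOp, [&](mlir::Value val) {
  auto it = layouts.find(val);
  return it != layouts.end() ? it->second
                             : mlir::xegpu::DistributeLayoutAttr{};
});
```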
```diff
@@ -464,12 +483,16 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
   Value mask = vector::ConstantMaskOp::create(
       rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
       vectorShape);
-  xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
-                                localOffsets, mask,
-                                /*chunk_size=*/IntegerAttr{},
-                                /*l1_hint=*/xegpu::CachePolicyAttr{},
-                                /*l2_hint=*/xegpu::CachePolicyAttr{},
-                                /*l3_hint=*/xegpu::CachePolicyAttr{});
+  auto cacheHints = getOpCacheHints(writeOp);
+  auto storeOp = xegpu::StoreScatterOp::create(
+      rewriter, loc, writeOp.getVector(), flatMemref, localOffsets, mask,
+      /*chunk_size=*/IntegerAttr{},
+      /*l1_hint=*/cacheHints[0],
+      /*l2_hint=*/cacheHints[1],
+      /*l3_hint=*/cacheHints[2]);
+  auto valueLayout = xegpu::getDistributeLayoutAttr(writeOp->getOpOperand(0));
+  xegpu::setDistributeLayoutAttrs(storeOp,
+                                  [&](Value val) { return valueLayout; });
   rewriter.eraseOp(writeOp);
   return success();
 }
```
```diff
@@ -519,9 +542,11 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
     SmallVector<int64_t> descShape(vecTy.getShape());
     if (isTransposeLoad)
       std::reverse(descShape.begin(), descShape.end());
-    auto descType = xegpu::TensorDescType::get(
-        descShape, elementType, /*array_length=*/1,
-        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
+    auto resLayout = xegpu::getDistributeLayoutAttr(readOp.getResult());
+    auto descType =
+        xegpu::TensorDescType::get(descShape, elementType, /*array_length=*/1,
+                                   /*boundary_check=*/isOutOfBounds,
+                                   xegpu::MemorySpace::Global, resLayout);
 
     xegpu::CreateNdDescOp ndDesc =
         createNdDescriptor(rewriter, loc, descType,
```
```diff
@@ -532,12 +557,12 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
         !isTransposeLoad ? nullptr
                          : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                   ArrayRef<int64_t>{1, 0});
-    // By default, no specific caching policy is assigned.
-    xegpu::CachePolicyAttr hint = nullptr;
+    auto cacheHints = getOpCacheHints(readOp);
     auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
                                           /*packed=*/nullptr, transposeAttr,
-                                          /*l1_hint=*/hint,
-                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                          /*l1_hint=*/cacheHints[0],
+                                          /*l2_hint=*/cacheHints[1],
+                                          /*l3_hint=*/cacheHints[2]);
     rewriter.replaceOp(readOp, loadOp);
 
     return success();
```
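One property the two hunks above rely on is worth spelling out: a value may carry no layout annotation at all (for instance, when the layout propagation pass has not run). A minimal sketch of that path, with the null-tolerance assumptions flagged in comments:

```cpp
// resLayout may be a null DistributeLayoutAttr if readOp's result has no
// layout annotation. Assumption (implied by this diff, not verified here):
// TensorDescType::get and LoadNdOp::create accept null layout/hint
// attributes, reproducing the pre-patch behavior in that case.
auto resLayout = xegpu::getDistributeLayoutAttr(readOp.getResult());
auto descType = xegpu::TensorDescType::get(
    descShape, elementType, /*array_length=*/1,
    /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global, resLayout);
```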
```diff
@@ -575,21 +600,24 @@ struct TransferWriteLowering
     if (!map.isMinorIdentity())
       return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
 
+    auto valLayout = xegpu::getDistributeLayoutAttr(writeOp->getOpOperand(0));
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
-        xegpu::MemorySpace::Global);
+        xegpu::MemorySpace::Global, valLayout);
     xegpu::CreateNdDescOp ndDesc =
         createNdDescriptor(rewriter, loc, descType,
                            dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
                            writeOp.getIndices());
 
-    // By default, no specific caching policy is assigned.
-    xegpu::CachePolicyAttr hint = nullptr;
+    auto cacheHints = getOpCacheHints(writeOp);
     auto storeOp =
         xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
-                                 /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                 /*l1_hint=*/cacheHints[0],
+                                 /*l2_hint=*/cacheHints[1],
+                                 /*l3_hint=*/cacheHints[2]);
     rewriter.replaceOp(writeOp, storeOp);
 
     return success();
```
```diff
@@ -616,16 +644,33 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
         computeOffsets(rewriter, gatherOp, meta.first, meta.second);
     Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);
 
+    auto layoutRes = xegpu::getDistributeLayoutAttr(gatherOp.getResult());
+    auto layoutIndices =
+        xegpu::getDistributeLayoutAttr(gatherOp.getIndicesMutable());
+    auto layoutMask = xegpu::getDistributeLayoutAttr(gatherOp.getMaskMutable());
+    auto layoutPassThru =
+        xegpu::getDistributeLayoutAttr(gatherOp.getPassThruMutable());
```

Review thread on the mutable getters:

Reviewer: Any reason to use mutable getters?

Author: otherwise it won't get …

```diff
+    SmallVector<xegpu::CachePolicyAttr, 3> cacheHints =
+        getOpCacheHints(gatherOp);
     auto xeGatherOp = xegpu::LoadGatherOp::create(
         rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
         /*chunk_size=*/IntegerAttr{},
-        /*l1_hint=*/xegpu::CachePolicyAttr{},
-        /*l2_hint=*/xegpu::CachePolicyAttr{},
-        /*l3_hint=*/xegpu::CachePolicyAttr{});
+        /*l1_hint=*/cacheHints[0],
+        /*l2_hint=*/cacheHints[1],
+        /*l3_hint=*/cacheHints[2]);
+    xegpu::setDistributeLayoutAttr(xeGatherOp->getOpResult(0), layoutRes);
```

Review thread on the result-layout annotation:

Reviewer: We should expand xegpu.load/store with an optional DistributeLayoutAttr attribute (like the load_matrix op definition), and set only the operation's attribute. The propagation process will propagate it to other ops.

```diff
+    xegpu::setDistributeLayoutAttr(xeGatherOp.getOffsetsMutable()[0],
+                                   layoutIndices);
+    xegpu::setDistributeLayoutAttr(xeGatherOp.getMaskMutable(), layoutMask);
 
     auto selectOp =
         arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
                                 xeGatherOp.getResult(), gatherOp.getPassThru());
```

Review thread on the select-op operands:

Reviewer: Just to double-check, the layout isn't assigned to the second operand (LoadGather) as it's already in the producer's result.

Author: Right, I thought that's enough. There seem to be no drawbacks, though, from assigning the layout in both places. Applied the layout in both places in the last commit, just in case.

Reviewer: Setting the operand layout is done at the client of the layout propagation result. So I think there is no need to update …

Author: Is it always the case? Say, the value to store comes as a function's argument (our func is not inlined yet), and it's impossible to determine the producer, but it's possible to determine the layout from …

```diff
+    xegpu::setDistributeLayoutAttr(selectOp.getConditionMutable(), layoutMask);
+    xegpu::setDistributeLayoutAttr(selectOp.getFalseValueMutable(),
+                                   layoutPassThru);
+    xegpu::setDistributeLayoutAttr(selectOp->getOpResult(0), layoutRes);
 
     rewriter.replaceOp(gatherOp, selectOp.getResult());
     return success();
   }
```
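To make the last question in that thread concrete, a hypothetical illustration (the walk and surrounding names are mine) of why the OpOperand-based getters can matter when the producer is not an op at all:

```cpp
// Hypothetical: the stored value is a function/block argument, so there is no
// producing op whose result could carry the layout. The OpOperand overload of
// getDistributeLayoutAttr can still read a use-site annotation
// (layout_operand_N, per the attribute naming discussed below) if present.
func.walk([](xegpu::StoreScatterOp store) {
  mlir::OpOperand &value = store->getOpOperand(0);
  if (llvm::isa<mlir::BlockArgument>(value.get())) {
    auto layout = xegpu::getDistributeLayoutAttr(value); // use-site lookup
    (void)layout;
  }
});
```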
```diff
@@ -650,12 +695,25 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
         computeOffsets(rewriter, scatterOp, meta.first, meta.second);
     Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);
 
-    xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
-                                  flatMemref, localOffsets, scatterOp.getMask(),
-                                  /*chunk_size=*/IntegerAttr{},
-                                  /*l1_hint=*/xegpu::CachePolicyAttr{},
-                                  /*l2_hint=*/xegpu::CachePolicyAttr{},
-                                  /*l3_hint=*/xegpu::CachePolicyAttr{});
+    auto layoutIndices =
+        xegpu::getDistributeLayoutAttr(scatterOp.getIndicesMutable());
+    auto layoutMask =
+        xegpu::getDistributeLayoutAttr(scatterOp.getMaskMutable());
+    auto layoutVal =
+        xegpu::getDistributeLayoutAttr(scatterOp.getValueToStoreMutable());
+    SmallVector<xegpu::CachePolicyAttr, 3> cacheHints =
+        getOpCacheHints(scatterOp);
+    auto storeOp = xegpu::StoreScatterOp::create(
+        rewriter, loc, scatterOp.getValueToStore(), flatMemref, localOffsets,
+        scatterOp.getMask(),
+        /*chunk_size=*/IntegerAttr{},
+        /*l1_hint=*/cacheHints[0],
+        /*l2_hint=*/cacheHints[1],
+        /*l3_hint=*/cacheHints[2]);
+    xegpu::setDistributeLayoutAttr(storeOp.getValueMutable(), layoutVal);
+    xegpu::setDistributeLayoutAttr(storeOp.getOffsetsMutable()[0],
+                                   layoutIndices);
+    xegpu::setDistributeLayoutAttr(storeOp.getMaskMutable(), layoutMask);
     rewriter.eraseOp(scatterOp);
     return success();
   }
```
```diff
@@ -675,18 +733,20 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
     // Boundary check is available only for block instructions.
     bool boundaryCheck = vecTy.getRank() > 1;
 
+    auto resLayout = xegpu::getDistributeLayoutAttr(loadOp.getResult());
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
-        boundaryCheck, xegpu::MemorySpace::Global);
+        boundaryCheck, xegpu::MemorySpace::Global, resLayout);
     xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
         rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
 
-    // By default, no specific caching policy is assigned.
-    xegpu::CachePolicyAttr hint = nullptr;
+    auto cacheHints = getOpCacheHints(loadOp);
     auto loadNdOp = xegpu::LoadNdOp::create(
         rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
-        /*l1_hint=*/hint,
-        /*l2_hint=*/hint, /*l3_hint=*/hint);
+        /*l1_hint=*/cacheHints[0],
+        /*l2_hint=*/cacheHints[1], /*l3_hint=*/cacheHints[2]);
     rewriter.replaceOp(loadOp, loadNdOp);
 
     return success();
```
```diff
@@ -708,18 +768,21 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
     // Boundary check is available only for block instructions.
     bool boundaryCheck = vecTy.getRank() > 1;
 
-    auto descType = xegpu::TensorDescType::get(
-        vecTy.getShape(), vecTy.getElementType(),
-        /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
+    auto valLayout = xegpu::getDistributeLayoutAttr(storeOp->getOpOperand(0));
+    auto descType =
+        xegpu::TensorDescType::get(vecTy.getShape(), vecTy.getElementType(),
+                                   /*array_length=*/1, boundaryCheck,
+                                   xegpu::MemorySpace::Global, valLayout);
     xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
         rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
 
-    // By default, no specific caching policy is assigned.
-    xegpu::CachePolicyAttr hint = nullptr;
-    auto storeNdOp =
-        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
-                                 /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+    auto cacheHints = getOpCacheHints(storeOp);
+    auto storeNdOp = xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
+                                              /*l1_hint=*/cacheHints[0],
+                                              /*l2_hint=*/cacheHints[1],
+                                              /*l3_hint=*/cacheHints[2]);
     rewriter.replaceOp(storeOp, storeNdOp);
 
     return success();
```
Review discussion on the overall layout-attribute design:

Reviewer: All three layouts should be exactly the same. We can take the layout from the result vector and set it as the store op's layout attribute (we need to expand the op definition to make the attribute persistent, so it is not lost after optimization passes).

Author: I found at least one scenario in our test cases where not all operands of xegpu.store have the same layout. It also looks quite confusing, in the case of xegpu.load, to have mixed layout_operand_* and layout attributes. Given that, I don't think we need a separate layout attribute for xegpu load/store ops (or we have to rethink it once again and make it unlike the load_matrix/store_matrix layout attribute).

Reviewer: At the entry XeGPU IR, we need to set the layout on anchor ops and rely on the propagation pass to propagate the layout to all values. The anchor ops include load/store, load_matrix/store_matrix, load_nd/store_nd, dpas, and convert_layout. The layout describes the layout of the tensor tile: for a memory operation, it describes the memory operand (like memref, mem_desc, nd_tdesc); for dpas, it describes the vector operands.

The propagated layouts are temporary attributes and can be lost during IR transformation or lowering, but the layout attributes defined on anchor ops are permanent. We rely on those attributes being set up properly so the layouts can be recovered in case users compose XeGPU passes with their own passes.

The example you show above actually looks fine to me, as long as the propagation rules decide to drop one dimension for the mask and index and keep the tensor tile layout as 2D. It does look a bit complex, but I don't expect this kind of layout to be exposed to users, since chunked load should be a lower-level representation for memory coalescing or load_matrix lowering.

So I still suggest changing this PR to set the layout attribute according to the tensor tile's layout for anchor ops only, and to remove the attribute setting for other ops and operands.
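To make the suggested alternative concrete, a minimal sketch of the "anchor op only" variant, reusing only helpers that appear in this diff (whether LoadGatherOp additionally grows a persistent layout attribute is exactly the open question above):

```cpp
// Sketch of the reviewer's suggestion (not what this PR currently does):
// annotate only the anchor op's result and let the layout propagation pass
// derive the operand layouts (offsets, mask) from it.
auto layoutRes = xegpu::getDistributeLayoutAttr(gatherOp.getResult());
auto xeGatherOp = xegpu::LoadGatherOp::create(
    rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
    /*chunk_size=*/IntegerAttr{},
    /*l1_hint=*/cacheHints[0],
    /*l2_hint=*/cacheHints[1],
    /*l3_hint=*/cacheHints[2]);
// Anchor annotation only; no setDistributeLayoutAttr calls on individual
// operands or on the downstream arith.select.
xegpu::setDistributeLayoutAttr(xeGatherOp->getOpResult(0), layoutRes);
```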