-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][XeGPU] Add optional layout attribute to LoadGather StoreScatter ops #163414
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
a66d2f6
a874531
5fd0ba1
784f7bb
05a71dd
3a72f6d
3da20eb
15f6907
027557d
f1208f2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -435,7 +435,8 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp, | |
| /*chunk_size=*/IntegerAttr{}, | ||
| /*l1_hint=*/xegpu::CachePolicyAttr{}, | ||
| /*l2_hint=*/xegpu::CachePolicyAttr{}, | ||
| /*l3_hint=*/xegpu::CachePolicyAttr{}); | ||
| /*l3_hint=*/xegpu::CachePolicyAttr{}, | ||
| /*layout=*/nullptr); | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. #163071 to fill the gap |
||
|
|
||
| rewriter.replaceOp(readOp, gatherOp.getResult()); | ||
| return success(); | ||
|
|
@@ -469,7 +470,8 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp, | |
| /*chunk_size=*/IntegerAttr{}, | ||
| /*l1_hint=*/xegpu::CachePolicyAttr{}, | ||
| /*l2_hint=*/xegpu::CachePolicyAttr{}, | ||
| /*l3_hint=*/xegpu::CachePolicyAttr{}); | ||
| /*l3_hint=*/xegpu::CachePolicyAttr{}, | ||
| /*layout=*/nullptr); | ||
| rewriter.eraseOp(writeOp); | ||
| return success(); | ||
| } | ||
|
|
@@ -621,7 +623,8 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> { | |
| /*chunk_size=*/IntegerAttr{}, | ||
| /*l1_hint=*/xegpu::CachePolicyAttr{}, | ||
| /*l2_hint=*/xegpu::CachePolicyAttr{}, | ||
| /*l3_hint=*/xegpu::CachePolicyAttr{}); | ||
| /*l3_hint=*/xegpu::CachePolicyAttr{}, | ||
| /*layout=*/nullptr); | ||
|
|
||
| auto selectOp = | ||
| arith::SelectOp::create(rewriter, loc, gatherOp.getMask(), | ||
|
|
@@ -655,7 +658,8 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> { | |
| /*chunk_size=*/IntegerAttr{}, | ||
| /*l1_hint=*/xegpu::CachePolicyAttr{}, | ||
| /*l2_hint=*/xegpu::CachePolicyAttr{}, | ||
| /*l3_hint=*/xegpu::CachePolicyAttr{}); | ||
| /*l3_hint=*/xegpu::CachePolicyAttr{}, | ||
| /*layout=*/nullptr); | ||
| rewriter.eraseOp(scatterOp); | ||
| return success(); | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -105,12 +105,22 @@ mlir::xegpu::getDistributedVectorType(VectorType originalType, | |
| std::string xegpu::getLayoutName(const OpOperand &operand) { | ||
| const StringRef prefix("layout_operand_"); | ||
| unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber(); | ||
| return llvm::formatv("{0}{1}", prefix, idx).str(); | ||
| auto owner = operand.getOwner(); | ||
| auto tempLayout = llvm::formatv("{0}{1}", prefix, idx).str(); | ||
| if (isa<StoreScatterOp>(operand.getOwner()) && idx == 0 && | ||
dchigarev marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| !owner->hasAttr(tempLayout)) | ||
| return "layout"; | ||
| return tempLayout; | ||
| } | ||
|
|
||
| std::string xegpu::getLayoutName(const OpResult result) { | ||
| const StringRef prefix = "layout_result_"; | ||
| return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str(); | ||
| auto owner = result.getOwner(); | ||
| auto tempLayout = | ||
dchigarev marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str(); | ||
| if (isa<LoadGatherOp>(owner) && !owner->hasAttr(tempLayout)) | ||
| return "layout"; | ||
| return tempLayout; | ||
| } | ||
|
|
||
| xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) { | ||
|
|
@@ -144,6 +154,11 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) { | |
| std::string layoutName = getLayoutName(result); | ||
| if (defOp->hasAttr(layoutName)) | ||
| return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName); | ||
|
|
||
| // check for "permament" layout only after "temporary" layout name lookup | ||
| // for backward compatibility | ||
| if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(defOp)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no store_matrix here?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think since the store scatter does not return a result it is not needed here. We only need load gather. |
||
| return loadGatherOp.getLayoutAttr(); | ||
| } | ||
|
|
||
| if (auto arg = dyn_cast<BlockArgument>(value)) { | ||
|
|
@@ -171,6 +186,13 @@ xegpu::getDistributeLayoutAttr(const OpOperand &opr) { | |
| std::string layoutName = xegpu::getLayoutName(opr); | ||
| if (op->hasAttr(layoutName)) | ||
| return op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName); | ||
|
|
||
| // check for "permament" layout only after "temporary" layout name lookup | ||
| // for backward compatibility | ||
| if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not sure why there is only storescatter, no loadgather? Note that there are both load_matrix and store_matrix.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess we should handle the anchore ops first. So all following ops should be proceesed first here. Load/Store matrix I also think store ops can be omitted (but fine to have in code), because they don't return anything. so they can never be used as an
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
this function handles |
||
| if (auto layout = storeScatterOp.getLayoutAttr()) | ||
| return layout; | ||
|
|
||
| return getDistributeLayoutAttr(opr.get()); | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: ideally this needs to be
xegpu::LayoutAttr.DistributedLayoutAttris derivative layout because it could either be pure LayoutAttr or SliceAttr.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed,
btw
Load/StoreMatrixOps takeDistributeLayoutAttr, I believe we should adjust them as well eventually