-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][XeGPU] Add optional layout attribute to LoadGather StoreScatter ops #163414
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a66d2f6
a874531
5fd0ba1
784f7bb
05a71dd
3a72f6d
3da20eb
15f6907
027557d
f1208f2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -144,6 +144,11 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) { | |
| std::string layoutName = getLayoutName(result); | ||
| if (defOp->hasAttr(layoutName)) | ||
| return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName); | ||
|
|
||
| // check for "permament" layout only after "temporary" layout name lookup | ||
| // for backward compatibility | ||
| if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(defOp)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no store_matrix here?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think since the store scatter does not return a result it is not needed here. We only need load gather. |
||
| return loadGatherOp.getLayoutAttr(); | ||
| } | ||
|
|
||
| if (auto arg = dyn_cast<BlockArgument>(value)) { | ||
|
|
@@ -171,27 +176,77 @@ xegpu::getDistributeLayoutAttr(const OpOperand &opr) { | |
| std::string layoutName = xegpu::getLayoutName(opr); | ||
| if (op->hasAttr(layoutName)) | ||
| return op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName); | ||
|
|
||
| // check for "permament" layout only after "temporary" layout name lookup | ||
| if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not sure why there is only storescatter, no loadgather? Note that there are both load_matrix and store_matrix.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess we should handle the anchore ops first. So all following ops should be proceesed first here. Load/Store matrix I also think store ops can be omitted (but fine to have in code), because they don't return anything. so they can never be used as an
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
this function handles |
||
| if (auto layout = storeScatterOp.getLayoutAttr()) | ||
| return layout; | ||
|
|
||
| return getDistributeLayoutAttr(opr.get()); | ||
| } | ||
|
|
||
| // Returns the permanent layout attribute for the given result if it's | ||
| // available on the defining op. Otherwise returns the provided layout. | ||
| xegpu::DistributeLayoutAttr | ||
| maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout, | ||
| const OpResult &result, mlir::Operation *owner, | ||
| const std::string &name) { | ||
| xegpu::DistributeLayoutAttr candidate = layout; | ||
|
|
||
| if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(owner)) { | ||
| if (auto perm = loadOp.getLayoutAttr()) | ||
| candidate = perm; | ||
| } | ||
|
|
||
| return candidate; | ||
| } | ||
|
|
||
| // Returns the permanent layout attribute for the given operand if it's | ||
| // available on the defining op. Otherwise returns the provided layout. | ||
| xegpu::DistributeLayoutAttr | ||
| maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout, | ||
| const OpOperand &operand, mlir::Operation *owner, | ||
| const std::string &name) { | ||
| xegpu::DistributeLayoutAttr candidate = layout; | ||
| unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber(); | ||
|
|
||
| if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(owner)) { | ||
| if (idx == 0) { | ||
| if (auto perm = storeOp.getLayoutAttr()) | ||
| candidate = perm; | ||
| } | ||
| } | ||
|
|
||
| return candidate; | ||
| } | ||
|
|
||
| template <typename T, typename> | ||
| void xegpu::setDistributeLayoutAttr(const T &operandOrResult, | ||
| const DistributeLayoutAttr layout) { | ||
| const DistributeLayoutAttr layout, | ||
| bool respectPermLayout) { | ||
| Operation *owner = operandOrResult.getOwner(); | ||
| std::string name = xegpu::getLayoutName(operandOrResult); | ||
| if (layout && !owner->hasAttrOfType<DistributeLayoutAttr>(name)) | ||
| owner->setAttr(name, layout); | ||
|
|
||
| if (owner->hasAttrOfType<DistributeLayoutAttr>(name)) | ||
| return; | ||
|
|
||
| DistributeLayoutAttr candidate = layout; | ||
| if (respectPermLayout) | ||
| candidate = maybePickPermanentLayout(layout, operandOrResult, owner, name); | ||
|
|
||
| if (candidate) | ||
| owner->setAttr(name, candidate); | ||
| } | ||
|
|
||
| // Explicit instantiation for OpResult | ||
| template void xegpu::setDistributeLayoutAttr<mlir::OpResult>( | ||
| const mlir::OpResult &result, | ||
| const mlir::xegpu::DistributeLayoutAttr layout); | ||
| const mlir::xegpu::DistributeLayoutAttr layout, bool respectPermLayout); | ||
|
|
||
| // Explicit instantiation for OpOperand | ||
| template void xegpu::setDistributeLayoutAttr<mlir::OpOperand>( | ||
| const mlir::OpOperand &operand, | ||
| const mlir::xegpu::DistributeLayoutAttr layout); | ||
| const mlir::xegpu::DistributeLayoutAttr layout, bool respectPermLayout); | ||
|
|
||
| void xegpu::setDistributeLayoutAttrs( | ||
| Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#163071 to fill the gap