-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[MLIR][XeGPU] Add support for subgroup_id_range #148661
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
1d18b89
b4e3068
70fe19c
e6528ef
07b9eff
09fdbfc
1cecfbe
3cde920
343d630
56ad954
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -315,4 +315,32 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> { | |||||
| let genVerifyDecl = 1; | ||||||
| } | ||||||
|
|
||||||
| def XeGPU_RangeAttr : XeGPUAttr<"Range", "range"> { | ||||||
| let summary = [{Specifies a half-open range}]; | ||||||
| let description = [{ | ||||||
| `RangeAttr` is an attribute that defines a half-open range [start, end). | ||||||
| The range is inclusive of the start value and exclusive of the end value. | ||||||
| One usage of this attribute can be to specify the subgroup id range. | ||||||
| The subgroup id range can be specified using this attribute, | ||||||
| and it can be attached to a scf.if op like | ||||||
| ```mlir | ||||||
| scf.if %cond { | ||||||
| // some operations | ||||||
| }{sg_id_range = #xegpu.range<[2, 4]>} | ||||||
|
||||||
| }{sg_id_range = #xegpu.range<[2, 4]>} | |
| } {sg_id_range = #xegpu.range<[2, 4]>} |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| let assemblyFormat = "`<` `[`$start ```,` $end `]``>`"; | |
| let assemblyFormat = "`<` `[`$start `,` $end `]` `>`"; |
nit: minor cleanup
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,6 +34,26 @@ using namespace mlir; | |
|
|
||
| namespace { | ||
|
|
||
| // Check if there is sg id range attached to the scf.if op. | ||
| static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange, | ||
| int64_t &endOfRange) { | ||
| Operation *parent = op->getParentOp(); | ||
| // Find the outermost scf::IfOp with xegpu.sg_id_range. | ||
| while (parent) { | ||
| if (auto ifOp = dyn_cast<scf::IfOp>(parent)) { | ||
| if (auto attr = llvm::dyn_cast_or_null<xegpu::RangeAttr>( | ||
| ifOp->getAttr("sg_id_range"))) { | ||
| startOfRange = attr.getStart().getInt(); | ||
| endOfRange = attr.getEnd().getInt(); | ||
|
Comment on lines
+46
to
+47
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. General suggestion, non-blocker here: getting int value directly would make for a nice attribute helper method |
||
| break; | ||
| } | ||
| } | ||
| parent = parent->getParentOp(); | ||
| } | ||
| // Return false if startOfRange is 0 | ||
nbpatel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return (startOfRange > 0 && endOfRange > startOfRange); | ||
| } | ||
|
|
||
| static std::pair<SmallVector<int64_t>, int> | ||
| getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) { | ||
| int count = 1; | ||
|
|
@@ -174,8 +194,27 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> { | |
| sgDataDim[i] = rewriter.create<arith::ConstantIndexOp>(loc, sgShape[i]); | ||
| } | ||
|
|
||
| int64_t startOfRange = -1, endOfRange = -1; | ||
| bool sgIdRangeSpecified = | ||
| isSgIdRangeSpecified(op, startOfRange, endOfRange); | ||
|
|
||
| Value adjustedSgId = linearSgId; | ||
| if (sgIdRangeSpecified) { | ||
| int64_t sgCount = endOfRange - startOfRange; | ||
| if (computeProduct(sgLayout) != sgCount) { | ||
| return rewriter.notifyMatchFailure( | ||
| op, "sg_layout size must match the sg_id_range"); | ||
| } | ||
nbpatel marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| // Subtract startOfRange from the original subgroup id to get the adjusted | ||
| // sg id | ||
| Value startOfRangeVal = | ||
| rewriter.create<arith::ConstantIndexOp>(loc, startOfRange); | ||
| adjustedSgId = | ||
| rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal); | ||
| } | ||
|
|
||
| auto deLinearizeSgId = | ||
| affine::delinearizeIndex(rewriter, loc, linearSgId, sgLayoutDim); | ||
| affine::delinearizeIndex(rewriter, loc, adjustedSgId, sgLayoutDim); | ||
| if (failed(deLinearizeSgId)) | ||
| return failure(); | ||
| SmallVector<Value> sgIds = *deLinearizeSgId; | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.