-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[MLIR][XeGPU] Add support for subgroup_id_range #148661
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1d18b89
b4e3068
70fe19c
e6528ef
07b9eff
09fdbfc
1cecfbe
3cde920
343d630
56ad954
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -211,6 +211,21 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, | |
| return success(); | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // XeGPU_RangeAttr | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| LogicalResult | ||
| RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add one invalid test case?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not sure if its possible to add a negative test case with this pass...because it will always give legalization error for the create_nd_desc op if the pattern returns a failure in this case |
||
| IntegerAttr startOfRange, IntegerAttr endOfRange) { | ||
| if (startOfRange.getInt() >= endOfRange.getInt()) | ||
| return emitError() << "'end' : " << endOfRange.getInt() | ||
| << " must be greater than 'start' : " | ||
| << startOfRange.getInt(); | ||
|
|
||
| return success(); | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // XeGPU_TensorDescType | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,6 +34,26 @@ using namespace mlir; | |
|
|
||
| namespace { | ||
|
|
||
| // Check if there is sg id range attached to the scf.if op. | ||
| static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange, | ||
| int64_t &endOfRange) { | ||
| Operation *parent = op->getParentOp(); | ||
| // Find the outermost scf::IfOp with xegpu.sg_id_range. | ||
| while (parent) { | ||
| if (auto ifOp = dyn_cast<scf::IfOp>(parent)) { | ||
| if (auto attr = llvm::dyn_cast_or_null<xegpu::RangeAttr>( | ||
| ifOp->getAttr("sg_id_range"))) { | ||
| startOfRange = attr.getStart().getInt(); | ||
| endOfRange = attr.getEnd().getInt(); | ||
|
Comment on lines
+46
to
+47
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. General suggestion, non-blocker here: getting int value directly would make for a nice attribute helper method |
||
| break; | ||
| } | ||
| } | ||
| parent = parent->getParentOp(); | ||
| } | ||
| // Return false if startOfRange is 0 | ||
nbpatel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return (startOfRange > 0 && endOfRange > startOfRange); | ||
| } | ||
|
|
||
| static std::pair<SmallVector<int64_t>, int> | ||
| getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) { | ||
| int count = 1; | ||
|
|
@@ -174,8 +194,26 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> { | |
| sgDataDim[i] = arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]); | ||
| } | ||
|
|
||
| int64_t startOfRange = -1, endOfRange = -1; | ||
| bool sgIdRangeSpecified = | ||
| isSgIdRangeSpecified(op, startOfRange, endOfRange); | ||
|
|
||
| Value adjustedSgId = linearSgId; | ||
| if (sgIdRangeSpecified) { | ||
| int64_t sgCount = endOfRange - startOfRange; | ||
| if (computeProduct(sgLayout) != sgCount) | ||
| return rewriter.notifyMatchFailure( | ||
| op, "sg_layout size must match the sg_id_range"); | ||
| // Subtract startOfRange from the original subgroup id to get the adjusted | ||
| // sg id | ||
| Value startOfRangeVal = | ||
| rewriter.create<arith::ConstantIndexOp>(loc, startOfRange); | ||
| adjustedSgId = | ||
| rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal); | ||
| } | ||
|
|
||
| auto deLinearizeSgId = | ||
| affine::delinearizeIndex(rewriter, loc, linearSgId, sgLayoutDim); | ||
| affine::delinearizeIndex(rewriter, loc, adjustedSgId, sgLayoutDim); | ||
| if (failed(deLinearizeSgId)) | ||
| return failure(); | ||
| SmallVector<Value> sgIds = *deLinearizeSgId; | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.