-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[MLIR][XeGPU] Add transformation pattern for vector.broadcast in Wg to Sg pass #144417
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f1509d2
c5cd274
2b23906
803a565
2c97ee7
692ae9e
717664f
9d71167
1d17537
425d677
00ffa57
8467c29
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -331,6 +331,65 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> { | |
| } | ||
| }; | ||
|
|
||
| /// This pattern transforms vector.broadcast ops to work at subgroup level. | ||
| struct WgToSgVectorBroadcastOp | ||
| : public OpConversionPattern<vector::BroadcastOp> { | ||
| using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern; | ||
|
|
||
| LogicalResult | ||
| matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
| VectorType resultType = op.getResult().getType(); | ||
| ArrayRef<int64_t> wgShape = resultType.getShape(); | ||
|
|
||
| xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult()); | ||
| if (!layout || !layout.getSgLayout()) | ||
| return failure(); | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It looks to me that the current implementation assumes the rank of the source is the same as the rank of the result, which is only a subset of the supported semantics of `vector.broadcast`. |
||
| // TODO: Currently only supports cases where the source and result ranks | ||
nbpatel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| // are the same. | ||
| auto srcType = | ||
| dyn_cast<VectorType>(adaptor.getOperands().front()[0].getType()); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. can |
||
| if (!srcType || srcType.getRank() != resultType.getRank()) | ||
| return failure(); | ||
|
|
||
| SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; | ||
| VectorType newResultType = | ||
| VectorType::get(sgShape, resultType.getElementType()); | ||
|
|
||
| // Check if the output layout is distributable | ||
| SmallVector<int64_t> sgLayout; | ||
| if (auto sgLayoutAttr = layout.getSgLayout()) | ||
| sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef()); | ||
| else | ||
| return failure(); | ||
|
|
||
| if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout)) | ||
| return failure(); | ||
|
|
||
| // Check if the srcShape has unit dim in dimensions being broadcasted, | ||
| // and the other dimensions are the same as the destination type | ||
| // TODO: Generalize it | ||
| auto srcShape = srcType.getShape(); | ||
| for (size_t i = 0; i < srcShape.size(); ++i) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It seems this check duplicates the check in the broadcast verifier, unless there are cases where the source vector, e.g., vector<32x1x1xf32>, can be distributed to a vector such as <8x2x1>. |
||
| if (srcShape[i] != 1 && srcShape[i] != sgShape[i]) | ||
| return failure(); | ||
| } | ||
|
|
||
| SmallVector<Value> newBroadcastOps; | ||
| for (auto operand : adaptor.getOperands().front()) { | ||
| auto newBroadcast = rewriter.create<vector::BroadcastOp>( | ||
| op.getLoc(), newResultType, operand); | ||
| xegpu::setLayoutAttr(newBroadcast->getResult(0), | ||
| layout.dropSgLayoutAndData()); | ||
| newBroadcastOps.push_back(newBroadcast.getResult()); | ||
| } | ||
|
|
||
| rewriter.replaceOpWithMultiple(op, {newBroadcastOps}); | ||
| return success(); | ||
| } | ||
| }; | ||
|
|
||
| // This pattern transforms elementwise ops to work at subgroup level. | ||
| struct WgToSgElementwiseOp : public ConversionPattern { | ||
| WgToSgElementwiseOp(MLIRContext *ctx) | ||
|
|
@@ -473,8 +532,8 @@ namespace xegpu { | |
| void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { | ||
| patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp, | ||
| WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp, | ||
| UnrealizedConversionCastOpPattern, WgToSgElementwiseOp>( | ||
| patterns.getContext()); | ||
| UnrealizedConversionCastOpPattern, WgToSgElementwiseOp, | ||
| WgToSgVectorBroadcastOp>(patterns.getContext()); | ||
| } | ||
| } // namespace xegpu | ||
| } // namespace mlir | ||
|
|
@@ -581,6 +640,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() { | |
| return isLegal(layout); | ||
| }); | ||
|
|
||
| target.addDynamicallyLegalOp<vector::BroadcastOp>( | ||
| [=](vector::BroadcastOp op) -> bool { | ||
| return isLegal(xegpu::getLayoutAttr(op.getResult())); | ||
| }); | ||
|
|
||
| target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>( | ||
| [=](Operation *op) -> std::optional<bool> { | ||
| // Only handle elementwise mappable ops | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.