-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][XeGPU] Distribute load_gather/store_scatter op from Wg To Sg #154420
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
b94a37f
4f490ef
459e98a
e3c02a6
a25c40d
bdbf14f
c93090f
a7b780d
21f1f4f
73204b7
efc211b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -763,6 +763,88 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { | |
| } | ||
| }; | ||
|
|
||
| // This pattern transforms the LoadGatherOp with explicit offsets to load | ||
| // subgroup data, similar to WgToSgLoadNdOpWithOffset. | ||
| struct WgToSgLoadGatherOpWithOffset | ||
| : public OpConversionPattern<xegpu::LoadGatherOp> { | ||
| using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern; | ||
| LogicalResult | ||
| matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
|
|
||
| if (!op.getOffsets()) | ||
| return failure(); | ||
|
|
||
| Location loc = op.getLoc(); | ||
| VectorType resultType = op.getResult().getType(); | ||
| ArrayRef<int64_t> wgShape = resultType.getShape(); | ||
|
|
||
| xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult()); | ||
| if (!layout || !layout.getSgLayout()) | ||
| return failure(); | ||
|
|
||
| SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; | ||
|
|
||
| SmallVector<Value> newLoadOps; | ||
| auto chunkSizeAttr = | ||
| rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1)); | ||
| VectorType newTy = VectorType::get(sgShape, resultType.getElementType()); | ||
| for (auto [offsets, mask] : | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: Here the code assums the offset has been distributed by its defining op. It is not always true currently,
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1, could you add a test for this?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. with the current design, its not possible to add negative tests |
||
| llvm::zip(adaptor.getOffsets(), adaptor.getMask())) { | ||
| auto newLoadOp = rewriter.create<xegpu::LoadGatherOp>( | ||
| loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr, | ||
| op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); | ||
| xegpu::setLayoutAttr(newLoadOp->getResult(0), | ||
| layout.dropSgLayoutAndData()); | ||
| newLoadOps.push_back(newLoadOp); | ||
| } | ||
| rewriter.replaceOpWithMultiple(op, {newLoadOps}); | ||
| return success(); | ||
| } | ||
| }; | ||
|
|
||
| // This pattern transforms the StoreScatterOp with explicit offsets to store | ||
| // subgroup data, similar to WgToSgStoreNdOpWithOffset. | ||
| struct WgToSgStoreScatterOpWithOffset | ||
| : public OpConversionPattern<xegpu::StoreScatterOp> { | ||
| using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern; | ||
| LogicalResult | ||
| matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
|
|
||
| if (!op.getOffsets()) | ||
| return failure(); | ||
|
|
||
| Location loc = op.getLoc(); | ||
| VectorType valueType = dyn_cast<VectorType>(op.getValue().getType()); | ||
| if (!valueType) | ||
| return failure(); | ||
|
|
||
| ArrayRef<int64_t> wgShape = valueType.getShape(); | ||
| xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getValue()); | ||
| if (!layout || !layout.getSgLayout()) | ||
nbpatel marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return failure(); | ||
|
|
||
| auto chunkSizeOpt = op.getChunkSize(); | ||
| int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1; | ||
| auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize); | ||
| for (auto [val, offs, mask] : llvm::zip( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same check for offsets as above. |
||
| adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) { | ||
| rewriter.create<xegpu::StoreScatterOp>( | ||
| loc, val, op.getDest(), offs, mask, chunkSizeAttr, op.getL1HintAttr(), | ||
| op.getL2HintAttr(), op.getL3HintAttr()); | ||
| // Update the layout_result_0 attribute to drop sg_layout and sg_data. | ||
| if (auto layoutAttr = | ||
|
||
| op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0")) { | ||
| if (auto newLayout = layoutAttr.dropSgLayoutAndData()) | ||
| op->setAttr("layout_result_0", newLayout); | ||
| } | ||
| } | ||
| rewriter.eraseOp(op); | ||
| return success(); | ||
| } | ||
| }; | ||
|
|
||
| struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> { | ||
| using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern; | ||
| LogicalResult | ||
|
|
@@ -824,8 +906,9 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { | |
| WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp, | ||
| WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern, | ||
| WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp, | ||
| WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp>( | ||
| patterns.getContext()); | ||
| WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset, | ||
| WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp, | ||
| WgToSgStoreMatrixOp>(patterns.getContext()); | ||
| } | ||
| } // namespace xegpu | ||
| } // namespace mlir | ||
|
|
@@ -950,6 +1033,21 @@ void XeGPUWgToSgDistributePass::runOnOperation() { | |
| return isLegal(xegpu::getLayoutAttr(op.getResult())); | ||
| }); | ||
|
|
||
| target.addDynamicallyLegalOp<xegpu::LoadGatherOp>( | ||
| [=](xegpu::LoadGatherOp op) -> bool { | ||
| auto layout = xegpu::getLayoutAttr(op.getResult()); | ||
| return isLegal(layout); | ||
| }); | ||
|
|
||
| target.addDynamicallyLegalOp<xegpu::StoreScatterOp>( | ||
| [=](xegpu::StoreScatterOp op) -> bool { | ||
| // Check if the layout attribute is present on the result. | ||
| auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0"); | ||
| if (!layout) | ||
| return true; | ||
| return isLegal(layout); | ||
| }); | ||
|
|
||
| target.addDynamicallyLegalOp<vector::BroadcastOp>( | ||
| [=](vector::BroadcastOp op) -> bool { | ||
| return isLegal(xegpu::getLayoutAttr(op.getResult())); | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.