[MLIR][XeGPU] Add support for vector.multi_reduction in wg to sg pass [1/N] #157554
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes: This PR adds a pattern for lowering vector.multi_reduction from workgroup to subgroup IR. It currently only supports simple reductions of the form [m, 1] -> [m] or [1, m] -> [m].
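For a concrete picture of the transformation, here is a sketch (shapes and layout attributes chosen to mirror the tests added below, not a fixed interface): a workgroup-level reduction over a unit dimension becomes the same reduction on each subgroup's local fragment.

```mlir
// Workgroup-level IR: a 32x1 grid of subgroups, each owning an 8x1 slice.
%r = vector.multi_reduction <add>, %src, %acc
  {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>, dims = [1]>} [1]
  : vector<256x1xf32> to vector<256xf32>

// Subgroup-level IR after the pass: the reduced dimension has extent 1,
// so each subgroup reduces its own 8x1 fragment and no cross-subgroup
// accumulation is needed.
%r_sg = vector.multi_reduction <add>, %src_sg, %acc_sg [1]
  : vector<8x1xf32> to vector<8xf32>
```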
Full diff: https://github.com/llvm/llvm-project/pull/157554.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 5d0f1d18402f2..fab2b8773a6b8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -757,8 +757,10 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
auto cstOp =
arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
- if (auto newLayout = layout.dropSgLayoutAndData())
- xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+ if (!layout.getLaneLayoutAsInt().empty() ||
+ !layout.getLaneDataAsInt().empty())
+ xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
+ layout.dropSgLayoutAndData());
SmallVector<Value> newConsts(count, cstOp);
rewriter.replaceOpWithMultiple(op, {newConsts});
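This first hunk adjusts when the distributed constant keeps a layout attribute: the subgroup-level fields are always dropped, and a layout is re-attached only if lane-level fields remain. A rough before/after sketch (the attribute values here are illustrative, not taken from this patch):

```mlir
// Workgroup level: constant carrying both sg-level and lane-level fields.
%c = arith.constant {layout_result_0 = #xegpu.layout<
       sg_layout = [2, 2], sg_data = [16, 16],
       lane_layout = [1, 16], lane_data = [1, 1]>}
     dense<1.0> : vector<32x32xf32>

// Subgroup level: sg_layout/sg_data dropped, lane fields preserved.
%c_sg = arith.constant {layout_result_0 = #xegpu.layout<
          lane_layout = [1, 16], lane_data = [1, 1]>}
        dense<1.0> : vector<16x16xf32>
```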
@@ -919,6 +921,59 @@ struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
}
};
+// Pattern for lowering vector.multi_reduction op to subgroup level.
+struct WgToSgMultiDimReductionOp
+ : public OpConversionPattern<vector::MultiDimReductionOp> {
+ using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType srcType = dyn_cast<VectorType>(op.getSource().getType());
+ VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
+ if (!srcType || !dstType)
+ return failure();
+
+ // Only handle [m,1]->[m] or [1,m]->[m]
+ // TODO: generalize it
+ auto srcShape = srcType.getShape();
+ auto dstShape = dstType.getShape();
+ if (srcShape.size() != 2 || dstShape.size() != 1)
+ return failure();
+
+ if (!((srcShape[1] == 1 && srcShape[0] == dstShape[0]) ||
+ (srcShape[0] == 1 && srcShape[1] == dstShape[0])))
+ return failure();
+
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getSource());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
+ VectorType newDstType;
+ if (op.getReductionDims() == ArrayRef<int64_t>({0}))
+ newDstType = VectorType::get({sgShape[1]}, dstType.getElementType());
+ else
+ newDstType = VectorType::get({sgShape[0]}, dstType.getElementType());
+
+ SmallVector<Value> newReductions;
+ for (auto [sgSrc, sgAcc] :
+ llvm::zip(adaptor.getSource(), adaptor.getAcc())) {
+ auto newOp = rewriter.create<vector::MultiDimReductionOp>(
+ op.getLoc(), newDstType, op.getKind(), sgSrc, sgAcc,
+ op.getReductionDims());
+ if (!layout.getLaneLayoutAsInt().empty() ||
+ !layout.getLaneDataAsInt().empty())
+ xegpu::setDistributeLayoutAttr(newOp->getResult(0),
+ layout.dropSgLayoutAndData());
+ newReductions.push_back(newOp.getResult());
+ }
+ rewriter.replaceOpWithMultiple(op, {newReductions});
+ return success();
+ }
+};
+
} // namespace
namespace mlir {
@@ -932,7 +987,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
- WgToSgStoreMatrixOp>(patterns.getContext());
+ WgToSgStoreMatrixOp, WgToSgMultiDimReductionOp>(
+ patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -1077,6 +1133,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
});
+ target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
+ [=](vector::MultiDimReductionOp op) -> bool {
+ return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
+ });
+
target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
[=](xegpu::ConvertLayoutOp op) -> bool {
return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
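Note the shape guard in WgToSgMultiDimReductionOp: only reductions that fold a unit dimension are matched, because each subgroup can then reduce its fragment independently. A hypothetical case like the following folds a non-unit dimension, would require combining partial results across subgroups, and is left unmatched for a later patch in the series:

```mlir
// Not matched: srcShape[1] == 128, so reducing dim 1 requires
// cross-subgroup accumulation (see the TODO in the pattern).
%r = vector.multi_reduction <add>, %src, %acc [1]
  : vector<256x128xf32> to vector<256xf32>
```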
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index afb2bf876c18f..47e6f4cfd6d08 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -365,4 +365,32 @@ gpu.module @test_distribution {
xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32>
gpu.return
}
+
+ // CHECK-LABEL: @vector_reduce_dim_0
+ gpu.func @vector_reduce_dim_0(%src: memref<1x128xf32>) {
+ %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>, dims = [0]>} dense<1.0> : vector<128xf32>
+ %tdesc = xegpu.create_nd_tdesc %src : memref<1x128xf32>
+ -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>>
+ %load = xegpu.load_nd %tdesc[0, 0]
+ : !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>>
+ -> vector<1x128xf32>
+ // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [0] : vector<1x4xf32> to vector<4xf32>
+ %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>, dims = [0]>} [0]
+ : vector<1x128xf32> to vector<128xf32>
+ gpu.return
+ }
+
+ // CHECK-LABEL: @vector_reduce_dim_1
+ gpu.func @vector_reduce_dim_1(%src: memref<256x1xf32>) {
+ %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>, dims = [1]>} dense<1.0> : vector<256xf32>
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x1xf32>
+ -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>>
+ %load = xegpu.load_nd %tdesc[0, 0]
+ : !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>>
+ -> vector<256x1xf32>
+ // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [1] : vector<8x1xf32> to vector<8xf32>
+ %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>, dims = [1]>} [1]
+ : vector<256x1xf32> to vector<256xf32>
+ gpu.return
+ }
}
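To exercise the new tests locally, the file's RUN line follows the usual pattern; the pass flag below assumes the pass's standard registration name, so check the file header if it differs:

```mlir
// RUN: mlir-opt -xegpu-wg-to-sg-distribute %s | FileCheck %s
```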
Please modify the PR title so we know it is for the simplest case. This could be, e.g., a [1/N] tag or an explicit mention of the supported case.
Addressed feedback. Please take a look.
Addressed feedback.
LGTM
Force-pushed from a402217 to 4217fd6.
LGTM
… [1/N] (llvm#157554): This PR adds a pattern for lowering vector.multi_reduction from workgroup to subgroup IR. It currently only supports sg-local reductions.