From d55504607f2a3341146119b0cd893474bc5e583a Mon Sep 17 00:00:00 2001 From: nbpatel Date: Mon, 8 Sep 2025 18:05:00 +0000 Subject: [PATCH 1/8] Add pattern for reduction --- .../Transforms/XeGPUWgToSgDistribute.cpp | 67 ++++++++++++++++++- .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 28 ++++++++ 2 files changed, 92 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 5d0f1d18402f2..fab2b8773a6b8 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -757,8 +757,10 @@ struct WgToSgArithConstantOp : public OpConversionPattern { auto sgAttr = DenseElementsAttr::get(newType, singleVal); auto cstOp = arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr); - if (auto newLayout = layout.dropSgLayoutAndData()) - xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout); + if (!layout.getLaneLayoutAsInt().empty() || + !layout.getLaneDataAsInt().empty()) + xegpu::setDistributeLayoutAttr(cstOp->getResult(0), + layout.dropSgLayoutAndData()); SmallVector newConsts(count, cstOp); rewriter.replaceOpWithMultiple(op, {newConsts}); @@ -919,6 +921,59 @@ struct WgToSgStoreMatrixOp : public OpConversionPattern { } }; +// Pattern for lowering vector.multi_reduction op to subgroup level. +struct WgToSgMultiDimReductionOp + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + VectorType srcType = dyn_cast(op.getSource().getType()); + VectorType dstType = dyn_cast(op.getResult().getType()); + if (!srcType || !dstType) + return failure(); + + // Only handle [m,1]->[m] or [1,m]->[m] + // TODO: generalize it + auto srcShape = srcType.getShape(); + auto dstShape = dstType.getShape(); + if (srcShape.size() != 2 || dstShape.size() != 1) + return failure(); + + if (!((srcShape[1] == 1 && srcShape[0] == dstShape[0]) || + (srcShape[0] == 1 && srcShape[1] == dstShape[0]))) + return failure(); + + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getSource()); + if (!layout || !layout.isForWorkgroup()) + return failure(); + + SmallVector sgShape = getSgShapeAndCount(srcShape, layout).first; + VectorType newDstType; + if (op.getReductionDims() == ArrayRef({0})) + newDstType = VectorType::get({sgShape[1]}, dstType.getElementType()); + else + newDstType = VectorType::get({sgShape[0]}, dstType.getElementType()); + + SmallVector newReductions; + for (auto [sgSrc, sgAcc] : + llvm::zip(adaptor.getSource(), adaptor.getAcc())) { + auto newOp = rewriter.create( + op.getLoc(), newDstType, op.getKind(), sgSrc, sgAcc, + op.getReductionDims()); + if (!layout.getLaneLayoutAsInt().empty() || + !layout.getLaneDataAsInt().empty()) + xegpu::setDistributeLayoutAttr(newOp->getResult(0), + layout.dropSgLayoutAndData()); + newReductions.push_back(newOp.getResult()); + } + rewriter.replaceOpWithMultiple(op, {newReductions}); + return success(); + } +}; + } // namespace namespace mlir { @@ -932,7 +987,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp, WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset, WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp, - WgToSgStoreMatrixOp>(patterns.getContext()); + WgToSgStoreMatrixOp, 
WgToSgMultiDimReductionOp>( + patterns.getContext()); } } // namespace xegpu } // namespace mlir @@ -1077,6 +1133,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return isLegal(xegpu::getDistributeLayoutAttr(op.getResult())); }); + target.addDynamicallyLegalOp( + [=](vector::MultiDimReductionOp op) -> bool { + return isLegal(xegpu::getDistributeLayoutAttr(op.getResult())); + }); + target.addDynamicallyLegalOp( [=](xegpu::ConvertLayoutOp op) -> bool { return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout()); diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index afb2bf876c18f..47e6f4cfd6d08 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -365,4 +365,32 @@ gpu.module @test_distribution { xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32> gpu.return } + + // CHECK-LABEL: @vector_reduce_dim_0 + gpu.func @vector_reduce_dim_0(%src: memref<1x128xf32>) { + %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>} dense<1.0> : vector<128xf32> + %tdesc = xegpu.create_nd_tdesc %src : memref<1x128xf32> + -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout> + %load = xegpu.load_nd %tdesc[0, 0] + : !xegpu.tensor_desc<1x128xf32, #xegpu.layout> + -> vector<1x128xf32> + // CHECK: vector.multi_reduction , {{.*}}, {{.*}} [0] : vector<1x4xf32> to vector<4xf32> + %reduce = vector.multi_reduction , %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>} [0] + : vector<1x128xf32> to vector<128xf32> + gpu.return + } + + // CHECK-LABEL: @vector_reduce_dim_1 + gpu.func @vector_reduce_dim_1(%src: memref<256x1xf32>) { + %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1]>} dense<1.0> : vector<256xf32> + %tdesc = xegpu.create_nd_tdesc %src : memref<256x1xf32> + -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout> + %load = xegpu.load_nd %tdesc[0, 0] + : !xegpu.tensor_desc<256x1xf32, #xegpu.layout> + -> vector<256x1xf32> + // CHECK: vector.multi_reduction , {{.*}}, {{.*}} [1] : vector<8x1xf32> to vector<8xf32> + %reduce = vector.multi_reduction , %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1]>} [1] + : vector<256x1xf32> to vector<256xf32> + gpu.return + } } From 924a5e1c444f837ef6f808dd2d2dc56b7aa80708 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Fri, 12 Sep 2025 20:10:05 +0000 Subject: [PATCH 2/8] Fix --- .../Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index efa39cbb43255..0aafb0139a095 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -757,8 +757,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern { auto sgAttr = DenseElementsAttr::get(newType, singleVal); auto cstOp = arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr); - if (!layout.getLaneLayoutAsInt().empty() || - !layout.getLaneDataAsInt().empty()) + if (!layout.getEffectiveLaneLayoutAsInt().empty() || + !layout.getEffectiveInstDataAsInt().empty()) xegpu::setDistributeLayoutAttr(cstOp->getResult(0), layout.dropSgLayoutAndData()); SmallVector newConsts(count, cstOp); @@ -963,8 +963,8 @@ struct WgToSgMultiDimReductionOp auto newOp = 
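The sg-level types in this patch come out of getSgShapeAndCount: each subgroup owns one sg_data tile of the workgroup-level vector, and newDstType is what remains of that tile once the reduced dimension collapses. Below is a self-contained C++ sketch of that arithmetic, not the MLIR helper itself; the getSgShapeAndCount name is reused only for readability, and the sg_layout = [1, 32], sg_data = [1, 4] values are the ones assumed in the vector_reduce_dim_0 test above.

    #include <cassert>
    #include <cstdint>
    #include <iostream>
    #include <utility>
    #include <vector>

    // Each subgroup owns one sg_data tile; the sg_layout grid tiles the
    // workgroup-level shape, possibly several times (round-robin).
    static std::pair<std::vector<int64_t>, int64_t>
    getSgShapeAndCount(const std::vector<int64_t> &wgShape,
                       const std::vector<int64_t> &sgLayout,
                       const std::vector<int64_t> &sgData) {
      int64_t count = 1;
      for (std::size_t i = 0; i < wgShape.size(); ++i) {
        assert(wgShape[i] % (sgLayout[i] * sgData[i]) == 0);
        count *= wgShape[i] / (sgLayout[i] * sgData[i]);
      }
      return {sgData, count}; // per-subgroup shape is the sg_data tile
    }

    int main() {
      // vector<1x128xf32> with sg_layout = [1, 32], sg_data = [1, 4]:
      auto [sgShape, count] = getSgShapeAndCount({1, 128}, {1, 32}, {1, 4});
      std::cout << sgShape[0] << "x" << sgShape[1] << ", count = " << count
                << "\n"; // prints "1x4, count = 1"
    }

With these values every subgroup reduces its private vector<1x4xf32> tile, so the workgroup-level op maps to exactly one sg-level multi_reduction per subgroup, matching the CHECK line in the test.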
From 924a5e1c444f837ef6f808dd2d2dc56b7aa80708 Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Fri, 12 Sep 2025 20:10:05 +0000
Subject: [PATCH 2/8] Fix

---
 .../Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index efa39cbb43255..0aafb0139a095 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -757,8 +757,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     auto sgAttr = DenseElementsAttr::get(newType, singleVal);
     auto cstOp =
         arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
-    if (!layout.getLaneLayoutAsInt().empty() ||
-        !layout.getLaneDataAsInt().empty())
+    if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+        !layout.getEffectiveInstDataAsInt().empty())
       xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
                                      layout.dropSgLayoutAndData());
     SmallVector<Value> newConsts(count, cstOp);
@@ -963,8 +963,8 @@ struct WgToSgMultiDimReductionOp
       auto newOp = rewriter.create<vector::MultiDimReductionOp>(
           op.getLoc(), newDstType, op.getKind(), sgSrc, sgAcc,
           op.getReductionDims());
-      if (!layout.getLaneLayoutAsInt().empty() ||
-          !layout.getLaneDataAsInt().empty())
+      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+          !layout.getEffectiveInstDataAsInt().empty())
         xegpu::setDistributeLayoutAttr(newOp->getResult(0),
                                        layout.dropSgLayoutAndData());
       newReductions.push_back(newOp.getResult());
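The switch to the getEffective* accessors matters because the layout may be a sliced attribute that resolves these queries through its parent; the intent of the guard is unchanged: after wg-to-sg distribution the sg fields are consumed, so a layout is re-attached to the new op only if inst- or lane-level information remains. A rough stand-alone model of that decision follows; the struct and its fields are illustrative stand-ins, not the real DistributeLayoutAttr API.

    #include <iostream>
    #include <vector>

    struct Layout {
      std::vector<int> sgLayout, sgData, instData, laneLayout;
      // Stand-in for the "effective lane layout / inst data" queries.
      bool hasEffectiveLaneOrInstData() const {
        return !laneLayout.empty() || !instData.empty();
      }
    };

    int main() {
      Layout wgOnly{{8, 4}, {32, 32}, {}, {}};
      Layout withInst{{8, 4}, {32, 32}, {8, 16}, {}};
      // Only the second layout still carries information after the
      // sg_layout/sg_data fields are dropped, so only it is re-attached.
      std::cout << wgOnly.hasEffectiveLaneOrInstData() << "\n";   // 0
      std::cout << withInst.hasEffectiveLaneOrInstData() << "\n"; // 1
    }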
From b4761ded9bf5a39c67f09e19821e1c4d452ab249 Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Fri, 12 Sep 2025 21:28:58 +0000
Subject: [PATCH 3/8] sg local reduction

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 25 +++++++------
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 36 +++++++++----------
 2 files changed, 32 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0aafb0139a095..29209446c57d8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -934,28 +934,31 @@ struct WgToSgMultiDimReductionOp
     if (!srcType || !dstType)
       return failure();
 
-    // Only handle [m,1]->[m] or [1,m]->[m]
     // TODO: generalize it
     auto srcShape = srcType.getShape();
     auto dstShape = dstType.getShape();
     if (srcShape.size() != 2 || dstShape.size() != 1)
       return failure();
 
-    if (!((srcShape[1] == 1 && srcShape[0] == dstShape[0]) ||
-          (srcShape[0] == 1 && srcShape[1] == dstShape[0])))
-      return failure();
-
     xegpu::DistributeLayoutAttr layout =
-        xegpu::getDistributeLayoutAttr(op.getSource());
+        xegpu::getDistributeLayoutAttr(op.getResult());
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
+    auto reductionDims = op.getReductionDims();
+    if (reductionDims.size() != 1)
+      return failure();
+
+    SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
+                                        .getParent()
+                                        .getEffectiveSgLayoutAsInt();
+    // Check that the sgLayout in the reduced dimension is 1.
+    if (sgLayout[reductionDims[0]] != 1)
+      return failure();
+
     SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
-    VectorType newDstType;
-    if (op.getReductionDims() == ArrayRef<int64_t>({0}))
-      newDstType = VectorType::get({sgShape[1]}, dstType.getElementType());
-    else
-      newDstType = VectorType::get({sgShape[0]}, dstType.getElementType());
+
+    VectorType newDstType =
+        VectorType::get({sgShape}, dstType.getElementType());
 
     SmallVector<Value> newReductions;
     for (auto [sgSrc, sgAcc] :
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 47e6f4cfd6d08..e6db091174dcb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -367,30 +367,30 @@ gpu.module @test_distribution {
   }
 
   // CHECK-LABEL: @vector_reduce_dim_0
-  gpu.func @vector_reduce_dim_0(%src: memref<1x128xf32>) {
-    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>, dims = [0]>} dense<1.0> : vector<128xf32>
-    %tdesc = xegpu.create_nd_tdesc %src : memref<1x128xf32>
-      -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>>
+  gpu.func @vector_reduce_dim_0(%src: memref<4x128xf32>) {
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32], sg_data = [4, 4]>, dims = [0]>} dense<1.0> : vector<128xf32>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<4x128xf32>
+      -> !xegpu.tensor_desc<4x128xf32, #xegpu.layout<sg_layout = [1, 32], sg_data = [4, 4]>>
     %load = xegpu.load_nd %tdesc[0, 0]
-      : !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>>
-      -> vector<1x128xf32>
-    // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [0] : vector<1x4xf32> to vector<4xf32>
-    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>, dims = [0]>} [0]
-      : vector<1x128xf32> to vector<128xf32>
+      : !xegpu.tensor_desc<4x128xf32, #xegpu.layout<sg_layout = [1, 32], sg_data = [4, 4]>>
+      -> vector<4x128xf32>
+    // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [0] : vector<4x4xf32> to vector<4xf32>
+    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32], sg_data = [4, 4]>, dims = [0]>} [0]
+      : vector<4x128xf32> to vector<128xf32>
     gpu.return
   }
 
   // CHECK-LABEL: @vector_reduce_dim_1
-  gpu.func @vector_reduce_dim_1(%src: memref<256x1xf32>) {
-    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>, dims = [1]>} dense<1.0> : vector<256xf32>
-    %tdesc = xegpu.create_nd_tdesc %src : memref<256x1xf32>
-      -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>>
+  gpu.func @vector_reduce_dim_1(%src: memref<256x64xf32>) {
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [16, 1], sg_data = [16, 64]>, dims = [1]>} dense<1.0> : vector<256xf32>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x64xf32>
+      -> !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [16, 1], sg_data = [16, 64]>>
     %load = xegpu.load_nd %tdesc[0, 0]
-      : !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>>
-      -> vector<256x1xf32>
-    // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [1] : vector<8x1xf32> to vector<8xf32>
-    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>, dims = [1]>} [1]
-      : vector<256x1xf32> to vector<256xf32>
+      : !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [16, 1], sg_data = [16, 64]>>
+      -> vector<256x64xf32>
+    // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [1] : vector<16x64xf32> to vector<16xf32>
+    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [16, 1], sg_data = [16, 64]>, dims = [1]>} [1]
+      : vector<256x64xf32> to vector<256xf32>
     gpu.return
   }
 }
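The new sgLayout check is what makes the reduction subgroup-local: when only one subgroup sits along the reduced axis of the grid, each subgroup's tile spans the whole reduction extent, so no partial sums ever cross subgroup boundaries. A small self-contained illustration, using the sg_layout = [16, 1], sg_data = [16, 64] values assumed in the updated vector_reduce_dim_1 test:

    #include <cstdint>
    #include <iostream>

    int main() {
      const int64_t srcShape[2] = {256, 64};
      const int64_t sgLayout[2] = {16, 1};
      const int64_t sgData[2] = {16, 64};
      const int64_t reducedDim = 1;
      // Column range owned by each subgroup position along the reduced axis.
      for (int64_t j = 0; j < sgLayout[reducedDim]; ++j) {
        int64_t begin = j * sgData[reducedDim];
        int64_t end = begin + sgData[reducedDim];
        // With sgLayout[1] == 1 this runs once with [0, 64): the full extent
        // of the reduced dimension, so the reduction needs no communication.
        std::cout << "columns [" << begin << ", " << end << ") of "
                  << srcShape[reducedDim] << "\n";
      }
    }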
From 9be2284ec7938bc2bca850ae15d5ed64f24262bd Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Tue, 16 Sep 2025 18:56:08 +0000
Subject: [PATCH 4/8] Address feedback

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 25 ++++++++++++-------
 .../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir    | 15 +++++++++++
 2 files changed, 31 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 658e10d12e38b..dced2e5351b1e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1027,7 +1027,12 @@ struct WgToSgVectorShapeCastOp
   }
 };
 
-// Pattern for lowering vector.multi_reduction op to subgroup level.
+/// Pattern for lowering vector.multi_reduction op to subgroup level.
+/// Current limitation: only support 2D->1D reduction with single reduction
+/// dimension, and the sg_layout in the reduced dimension being 1
+/// so that reduction is local to subgroup & no cross-subgroup communication is
+/// needed.
+/// TODO: Add cases to handle more general situations which require SLM access.
 struct WgToSgMultiDimReductionOp
     : public OpConversionPattern<vector::MultiDimReductionOp> {
   using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
@@ -1035,14 +1040,15 @@ struct WgToSgMultiDimReductionOp
   LogicalResult
   matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType srcType = dyn_cast<VectorType>(op.getSource().getType());
+    VectorType srcType = op.getSourceVectorType();
     VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
-    if (!srcType || !dstType)
+    if (!dstType)
       return failure();
 
-    // TODO: generalize it
-    auto srcShape = srcType.getShape();
-    auto dstShape = dstType.getShape();
+    SmallVector<int64_t> srcShape(srcType.getShape().begin(),
+                                  srcType.getShape().end());
+    SmallVector<int64_t> dstShape(dstType.getShape().begin(),
+                                  dstType.getShape().end());
     if (srcShape.size() != 2 || dstShape.size() != 1)
       return failure();
 
@@ -1051,7 +1057,8 @@ struct WgToSgMultiDimReductionOp
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
-    auto reductionDims = op.getReductionDims();
+    SmallVector<int64_t> reductionDims(op.getReductionDims().begin(),
+                                       op.getReductionDims().end());
     if (reductionDims.size() != 1)
       return failure();
 
@@ -1069,8 +1076,8 @@ struct WgToSgMultiDimReductionOp
     SmallVector<Value> newReductions;
     for (auto [sgSrc, sgAcc] :
          llvm::zip(adaptor.getSource(), adaptor.getAcc())) {
-      auto newOp = rewriter.create<vector::MultiDimReductionOp>(
-          op.getLoc(), newDstType, op.getKind(), sgSrc, sgAcc,
+      auto newOp = vector::MultiDimReductionOp::create(
+          rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc, sgAcc,
           op.getReductionDims());
       if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
           !layout.getEffectiveInstDataAsInt().empty())
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 6ff7a94d678a3..b4a6aef0b4206 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -82,4 +82,19 @@ gpu.module @test_distribution {
       : vector<256x128xf16>, vector<128x256xf16> -> vector<256x256xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: vector_reduce_dim_1
+  gpu.func @vector_reduce_dim_1(%src: memref<256x64xf32>) {
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>, dims = [1]>} dense<1.0> : vector<256xf32>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x64xf32>
+      -> !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>>
+    %load = xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>>
+      -> vector<256x64xf32>
+    // CHECK-COUNT-2: vector.multi_reduction <add>, {{.*}}, {{.*}} [1] : vector<16x64xf32> to vector<16xf32>
+    // CHECK-NOT: vector.multi_reduction
+    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>, dims = [1]>} [1]
+      : vector<256x64xf32> to vector<256xf32>
+    gpu.return
+  }
 }
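The round-robin test added here is worth a quick sanity check. With sg_layout = [8, 1] and sg_data = [16, 64] (the values assumed in the test above), one sweep of the 8x1 subgroup grid covers only 128 of the 256 rows, so each subgroup receives two vector<16x64xf32> chunks and the FileCheck directive expects exactly two sg-level reductions:

    #include <cstdint>
    #include <iostream>

    int main() {
      const int64_t wgRows = 256, sgLayoutRows = 8, sgDataRows = 16;
      // Number of round-robin sweeps along the distributed (non-reduced) dim.
      const int64_t rounds = wgRows / (sgLayoutRows * sgDataRows); // 256 / 128
      std::cout << rounds << " chunks of vector<16x64xf32> per subgroup\n"; // 2
    }

This is why the CHECK-COUNT-2 / CHECK-NOT pair is the right shape for the round-robin pipeline: two reductions per subgroup, and no third one.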
From c06ca5482a22d94a65144f9d50775c0b1da3f160 Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Thu, 25 Sep 2025 04:41:22 +0000
Subject: [PATCH 5/8] CHECK

---
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index b4a6aef0b4206..dce73dee507e1 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -85,13 +85,14 @@ gpu.module @test_distribution {
 
   // CHECK-LABEL: vector_reduce_dim_1
   gpu.func @vector_reduce_dim_1(%src: memref<256x64xf32>) {
+    // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16xf32>
     %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>, dims = [1]>} dense<1.0> : vector<256xf32>
     %tdesc = xegpu.create_nd_tdesc %src : memref<256x64xf32>
       -> !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>>
     %load = xegpu.load_nd %tdesc[0, 0]
       : !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>>
      -> vector<256x64xf32>
-    // CHECK-COUNT-2: vector.multi_reduction <add>, {{.*}}, {{.*}} [1] : vector<16x64xf32> to vector<16xf32>
+    // CHECK-COUNT-2: vector.multi_reduction <add>, {{.*}}, %[[CST]] [1] : vector<16x64xf32> to vector<16xf32>
     // CHECK-NOT: vector.multi_reduction
     %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>, dims = [1]>} [1]
       : vector<256x64xf32> to vector<256xf32>
     gpu.return
From 4217fd6a1f1dd8e824735c804f939478e7991a4e Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Wed, 24 Sep 2025 17:41:58 +0000
Subject: [PATCH 6/8] restrict chained reduction

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 32 +++++++++++--------
 1 file changed, 19 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index dced2e5351b1e..ab07125c50b45 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1045,10 +1045,8 @@ struct WgToSgMultiDimReductionOp
     if (!dstType)
       return failure();
 
-    SmallVector<int64_t> srcShape(srcType.getShape().begin(),
-                                  srcType.getShape().end());
-    SmallVector<int64_t> dstShape(dstType.getShape().begin(),
-                                  dstType.getShape().end());
+    auto srcShape = srcType.getShape();
+    auto dstShape = dstType.getShape();
     if (srcShape.size() != 2 || dstShape.size() != 1)
       return failure();
 
@@ -1057,27 +1055,34 @@ struct WgToSgMultiDimReductionOp
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
-    SmallVector<int64_t> reductionDims(op.getReductionDims().begin(),
-                                       op.getReductionDims().end());
+    auto reductionDims = llvm::to_vector(op.getReductionDims());
     if (reductionDims.size() != 1)
       return failure();
 
     SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
                                         .getParent()
                                         .getEffectiveSgLayoutAsInt();
-    // Check that the sgLayout in the reduced dimension is 1.
-    if (sgLayout[reductionDims[0]] != 1)
-      return failure();
+    SmallVector<int64_t> sgData = llvm::cast<xegpu::SliceAttr>(layout)
+                                      .getParent()
+                                      .getEffectiveSgDataAsInt();
+
+    // Check that the sgLayout in the reduced dimension is 1 and
+    // each sg gets the entire slice to reduce.
+    if (sgLayout[reductionDims[0]] != 1 ||
+        sgData[reductionDims[0]] != srcShape[reductionDims[0]])
+      return rewriter.notifyMatchFailure(
+          op, "sgLayout in reduced dimension must be 1 and sgData in the "
+              "reduced dim must match srcShape in that dim");
 
     SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
 
     VectorType newDstType =
         VectorType::get({sgShape}, dstType.getElementType());
 
     SmallVector<Value> newReductions;
-    for (auto [sgSrc, sgAcc] :
-         llvm::zip(adaptor.getSource(), adaptor.getAcc())) {
-      auto newOp = vector::MultiDimReductionOp::create(
-          rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc, sgAcc,
+    for (auto sgSrc : adaptor.getSource()) {
+      auto newOp = rewriter.create<vector::MultiDimReductionOp>(
+          op.getLoc(), newDstType, op.getKind(), sgSrc, adaptor.getAcc()[0],
          op.getReductionDims());
       if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
           !layout.getEffectiveInstDataAsInt().empty())
         xegpu::setDistributeLayoutAttr(newOp->getResult(0),
                                        layout.dropSgLayoutAndData());
       newReductions.push_back(newOp.getResult());
     }
+
     rewriter.replaceOpWithMultiple(op, {newReductions});
     return success();
   }
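The extra sgData condition introduced in this patch is what rules out chained reductions: if a subgroup's tile covered only part of the reduced dimension, its multi_reduction would yield a partial result that still had to be combined across tiles, which this pattern cannot express. A stand-alone model of the tightened predicate:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    static bool isSgLocalReduction(const std::vector<int64_t> &srcShape,
                                   const std::vector<int64_t> &sgLayout,
                                   const std::vector<int64_t> &sgData,
                                   int64_t dim) {
      return sgLayout[dim] == 1 && sgData[dim] == srcShape[dim];
    }

    int main() {
      // Accepted: each subgroup owns all 64 elements of the reduced dim.
      std::cout << isSgLocalReduction({256, 64}, {8, 1}, {16, 64}, 1)
                << "\n"; // 1
      // Rejected: each subgroup sees only half of the reduced dim, so its
      // result would be partial and would still need a cross-tile combine.
      std::cout << isSgLocalReduction({256, 64}, {8, 1}, {16, 32}, 1)
                << "\n"; // 0
    }

This is also why the loop body can simply feed adaptor.getAcc()[0] into every sg-level reduction: with the whole reduced slice local to one op, the accumulator is consumed exactly once per chunk rather than threaded through a chain.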
- if (sgLayout[reductionDims[0]] != 1 || - sgData[reductionDims[0]] != srcShape[reductionDims[0]]) - return rewriter.notifyMatchFailure( - op, "sgLayout in reduced dimension must be 1 and sgData in the " - "reduced dim must match srcShape in that dim"); + for (int64_t dim : reductionDims) { + if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim]) + return rewriter.notifyMatchFailure( + op, + "sgLayout in each reduced dimension must be 1 and sgData in the " + "reduced dim must match srcShape in that dim"); + } SmallVector sgShape = getSgShapeAndCount(srcShape, layout).first; diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 7583ed833205f..48fc633974e63 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -395,6 +395,18 @@ gpu.module @test_distribution { gpu.return } + // CHECK-LABEL: @vector_reduce_4D + gpu.func @vector_reduce_4D(%src: ui64) { + %cst_acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [3]>} dense<0.0> : vector<4x2x6xf16> + %offset = arith.constant {layout_result_0 = #xegpu.layout} dense<0> : vector<4x2x6x32xindex> + %mask = arith.constant {layout_result_0 = #xegpu.layout} dense : vector<4x2x6x32xi1> + %load = xegpu.load %src[%offset], %mask {layout_result_0 = #xegpu.layout} : ui64, vector<4x2x6x32xindex>, vector<4x2x6x32xi1> -> vector<4x2x6x32xf16> + // CHECK: vector.multi_reduction , {{.*}}, {{.*}} [3] : vector<1x1x1x32xf16> to vector<1x1x1xf16> + %reduce = vector.multi_reduction , %load, %cst_acc {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [3]>} [3] + : vector<4x2x6x32xf16> to vector<4x2x6xf16> + gpu.return + } + // CHECK-LABEL: vector_step_op gpu.func @vector_step_op_slice_attr() { //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index From f84f6fd0bce2536ce7e6c333e210e1cbb6bd5117 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 25 Sep 2025 16:44:55 +0000 Subject: [PATCH 8/8] Fix comment --- mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index dce886340d1c9..9413a9296b184 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -1028,8 +1028,7 @@ struct WgToSgVectorShapeCastOp }; /// Pattern for lowering vector.multi_reduction op to subgroup level. -/// Current limitation: only support 2D->1D reduction with single reduction -/// dimension, and the sg_layout in the reduced dimension being 1 +/// Current limitation: the sg_layout in the reduced dimension being 1 /// so that reduction is local to subgroup & no cross-subgroup communication is /// needed. /// TODO: Add cases to handle more general situations which require SLM access.