Skip to content

Conversation

nbpatel
Copy link
Contributor

@nbpatel nbpatel commented Sep 8, 2025

This PR adds pattern for lowering vector.multi_reduction from workgroup to subgroup IR. It currently only supports sg local reductions

@llvmbot
Copy link
Member

llvmbot commented Sep 8, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

This PR adds pattern for lowering vector.multi_reduction from workgroup to subgroup IR. It currently only supports simple reductions of form

  1. <mx1> to <m>
  2. <1xm> to <m>

Full diff: https://github.com/llvm/llvm-project/pull/157554.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+64-3)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (+28)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 5d0f1d18402f2..fab2b8773a6b8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -757,8 +757,10 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     auto sgAttr = DenseElementsAttr::get(newType, singleVal);
     auto cstOp =
         arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
-    if (auto newLayout = layout.dropSgLayoutAndData())
-      xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+    if (!layout.getLaneLayoutAsInt().empty() ||
+        !layout.getLaneDataAsInt().empty())
+      xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
+                                     layout.dropSgLayoutAndData());
     SmallVector<Value> newConsts(count, cstOp);
 
     rewriter.replaceOpWithMultiple(op, {newConsts});
@@ -919,6 +921,59 @@ struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
   }
 };
 
+// Pattern for lowering vector.multi_reduction op to subgroup level.
+struct WgToSgMultiDimReductionOp
+    : public OpConversionPattern<vector::MultiDimReductionOp> {
+  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType srcType = dyn_cast<VectorType>(op.getSource().getType());
+    VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!srcType || !dstType)
+      return failure();
+
+    // Only handle [m,1]->[m] or [1,m]->[m]
+    // TODO: generalize it
+    auto srcShape = srcType.getShape();
+    auto dstShape = dstType.getShape();
+    if (srcShape.size() != 2 || dstShape.size() != 1)
+      return failure();
+
+    if (!((srcShape[1] == 1 && srcShape[0] == dstShape[0]) ||
+          (srcShape[0] == 1 && srcShape[1] == dstShape[0])))
+      return failure();
+
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getSource());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
+    VectorType newDstType;
+    if (op.getReductionDims() == ArrayRef<int64_t>({0}))
+      newDstType = VectorType::get({sgShape[1]}, dstType.getElementType());
+    else
+      newDstType = VectorType::get({sgShape[0]}, dstType.getElementType());
+
+    SmallVector<Value> newReductions;
+    for (auto [sgSrc, sgAcc] :
+         llvm::zip(adaptor.getSource(), adaptor.getAcc())) {
+      auto newOp = rewriter.create<vector::MultiDimReductionOp>(
+          op.getLoc(), newDstType, op.getKind(), sgSrc, sgAcc,
+          op.getReductionDims());
+      if (!layout.getLaneLayoutAsInt().empty() ||
+          !layout.getLaneDataAsInt().empty())
+        xegpu::setDistributeLayoutAttr(newOp->getResult(0),
+                                       layout.dropSgLayoutAndData());
+      newReductions.push_back(newOp.getResult());
+    }
+    rewriter.replaceOpWithMultiple(op, {newReductions});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -932,7 +987,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
            WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
            WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
            WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
-           WgToSgStoreMatrixOp>(patterns.getContext());
+           WgToSgStoreMatrixOp, WgToSgMultiDimReductionOp>(
+          patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -1077,6 +1133,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
       });
 
+  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
+      [=](vector::MultiDimReductionOp op) -> bool {
+        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
+      });
+
   target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
       [=](xegpu::ConvertLayoutOp op) -> bool {
         return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index afb2bf876c18f..47e6f4cfd6d08 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -365,4 +365,32 @@ gpu.module @test_distribution {
     xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: @vector_reduce_dim_0
+  gpu.func @vector_reduce_dim_0(%src: memref<1x128xf32>) {
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>, dims = [0]>} dense<1.0> : vector<128xf32>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<1x128xf32>
+      -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>>
+    %load =  xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>>
+      -> vector<1x128xf32>
+    // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [0] : vector<1x4xf32> to vector<4xf32>
+    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>, dims = [0]>} [0]
+      : vector<1x128xf32> to vector<128xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: @vector_reduce_dim_1
+  gpu.func @vector_reduce_dim_1(%src: memref<256x1xf32>) {
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>, dims = [1]>} dense<1.0> : vector<256xf32>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x1xf32>
+      -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>>
+    %load =  xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>>
+      -> vector<256x1xf32>
+    // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [1] : vector<8x1xf32> to vector<8xf32>
+    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>, dims = [1]>} [1]
+      : vector<256x1xf32> to vector<256xf32>
+    gpu.return
+  }
 }

@Garra1980
Copy link

Please modify the PR title so we know it is for the simplest case. It can be e.g. [1/N] tag or explicit mentioning of the case supported

@nbpatel nbpatel changed the title [MLIR][XeGPU] Add support for vector.multi_reduction in wg to sg pass [MLIR][XeGPU] Add support for vector.multi_reduction in wg to sg pass [1/N] Sep 8, 2025
@nbpatel
Copy link
Contributor Author

nbpatel commented Sep 12, 2025

Addressed feedback. Please take a look

@charithaintc charithaintc self-requested a review September 15, 2025 22:14
@nbpatel
Copy link
Contributor Author

nbpatel commented Sep 16, 2025

Addressed Feedback

@charithaintc charithaintc self-requested a review September 18, 2025 23:35
Copy link
Contributor

@charithaintc charithaintc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Jianhui-Li Jianhui-Li self-requested a review September 23, 2025 23:47
@nbpatel nbpatel force-pushed the xegpu-reduction-unit-dim branch from a402217 to 4217fd6 Compare September 24, 2025 18:17
Copy link
Contributor

@Jianhui-Li Jianhui-Li left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@nbpatel nbpatel merged commit 50a7eb6 into llvm:main Sep 25, 2025
9 checks passed
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Oct 3, 2025
… [1/N] (llvm#157554)

This PR adds pattern for lowering vector.multi_reduction from workgroup
to subgroup IR. It currently only supports sg local reductions
@nbpatel nbpatel deleted the xegpu-reduction-unit-dim branch October 16, 2025 18:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants