Skip to content

Conversation

@nbpatel
Copy link
Contributor

@nbpatel nbpatel commented Aug 4, 2025

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Aug 4, 2025

@llvm/pr-subscribers-mlir-gpu

Author: Nishant Patel (nbpatel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/151977.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+56-2)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir (+7)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir (+7)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 850f70cca288f..878638061db5c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -649,6 +649,52 @@ struct UnrealizedConversionCastOpPattern
   }
 };
 
+// Distributes a splat vector arith.constant that carries an sg_layout into
+// subgroup-level splat constants of the subgroup shape; every new result is
+// tagged with layout.dropSgLayoutAndData() when that layout is non-null.
+struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
+    auto vecType = dyn_cast<VectorType>(op.getType());
+    if (!vecAttr || !vecType)
+      return failure();
+
+    // Current limitation: only splat (single-value) vector constants.
+    // TODO: support more complex cases, e.g., vector with multiple values.
+    if (!vecAttr.isSplat())
+      return failure();
+
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+    if (!layout || !layout.getSgLayout())
+      return failure();
+
+    SmallVector<int64_t> sgShape;
+    int count;
+    std::tie(sgShape, count) = getSgShapeAndCount(vecType.getShape(), layout);
+
+    // Loop-invariant: type, splat attribute and layout are shared by all ops.
+    auto newType = VectorType::get(sgShape, vecType.getElementType());
+    auto sgAttr =
+        DenseElementsAttr::get(newType, vecAttr.getSplatValue<Attribute>());
+    auto newLayout = layout.dropSgLayoutAndData();
+    SmallVector<Value> newConsts;
+    for (int i = 0; i < count; ++i) {
+      auto cstOp =
+          rewriter.create<arith::ConstantOp>(op.getLoc(), newType, sgAttr);
+      if (newLayout)
+        xegpu::setLayoutAttr(cstOp->getResult(0), newLayout);
+      newConsts.push_back(cstOp);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newConsts});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -657,8 +703,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
                WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
                UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
-               WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
-      patterns.getContext());
+               WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
+               WgToSgArithConstantOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -770,6 +816,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(xegpu::getLayoutAttr(op.getResult()));
       });
 
+  target.addDynamicallyLegalOp<arith::ConstantOp>(
+      [=](arith::ConstantOp op) -> bool {
+        auto vecType = dyn_cast<VectorType>(op.getType());
+        if (!vecType)
+          return true;
+        return isLegal(xegpu::getLayoutAttr(op.getResult()));
+      });
+
   target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
       [=](xegpu::ConvertLayoutOp op) -> bool {
         return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index d67bdb487d8bf..65f4b46ad6d26 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -225,4 +225,11 @@ gpu.module @test_round_robin_assignment {
                                    target_layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>}> : vector<32x64xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: distribute_constant
+  gpu.func @distribute_constant() {
+    // CHECK-COUNT-4: arith.constant dense<1.000000e+00> : vector<16x16xf32>
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} dense<1.0> : vector<256x128xf32>
+    gpu.return
+  }
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index d51122417fb61..415753a652092 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -393,4 +393,11 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
   } {sg_id_range = #xegpu.range<[3, 19]>}
   gpu.return
   }
+
+  // CHECK-LABEL: distribute_constant
+  gpu.func @distribute_constant() {
+    // CHECK: arith.constant dense<1.000000e+00> : vector<32x32xf32>
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} dense<1.0> : vector<256x128xf32>
+    gpu.return
+  }
 }

@llvmbot
Copy link
Member

llvmbot commented Aug 4, 2025

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/151977.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+56-2)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir (+7)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir (+7)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 850f70cca288f..878638061db5c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -649,6 +649,52 @@ struct UnrealizedConversionCastOpPattern
   }
 };
 
+// Distributes a splat vector arith.constant that carries an sg_layout into
+// subgroup-level splat constants of the subgroup shape; every new result is
+// tagged with layout.dropSgLayoutAndData() when that layout is non-null.
+struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
+    auto vecType = dyn_cast<VectorType>(op.getType());
+    if (!vecAttr || !vecType)
+      return failure();
+
+    // Current limitation: only splat (single-value) vector constants.
+    // TODO: support more complex cases, e.g., vector with multiple values.
+    if (!vecAttr.isSplat())
+      return failure();
+
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+    if (!layout || !layout.getSgLayout())
+      return failure();
+
+    SmallVector<int64_t> sgShape;
+    int count;
+    std::tie(sgShape, count) = getSgShapeAndCount(vecType.getShape(), layout);
+
+    // Loop-invariant: type, splat attribute and layout are shared by all ops.
+    auto newType = VectorType::get(sgShape, vecType.getElementType());
+    auto sgAttr =
+        DenseElementsAttr::get(newType, vecAttr.getSplatValue<Attribute>());
+    auto newLayout = layout.dropSgLayoutAndData();
+    SmallVector<Value> newConsts;
+    for (int i = 0; i < count; ++i) {
+      auto cstOp =
+          rewriter.create<arith::ConstantOp>(op.getLoc(), newType, sgAttr);
+      if (newLayout)
+        xegpu::setLayoutAttr(cstOp->getResult(0), newLayout);
+      newConsts.push_back(cstOp);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newConsts});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -657,8 +703,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
                WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
                UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
-               WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
-      patterns.getContext());
+               WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
+               WgToSgArithConstantOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -770,6 +816,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(xegpu::getLayoutAttr(op.getResult()));
       });
 
+  target.addDynamicallyLegalOp<arith::ConstantOp>(
+      [=](arith::ConstantOp op) -> bool {
+        auto vecType = dyn_cast<VectorType>(op.getType());
+        if (!vecType)
+          return true;
+        return isLegal(xegpu::getLayoutAttr(op.getResult()));
+      });
+
   target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
       [=](xegpu::ConvertLayoutOp op) -> bool {
         return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index d67bdb487d8bf..65f4b46ad6d26 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -225,4 +225,11 @@ gpu.module @test_round_robin_assignment {
                                    target_layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>}> : vector<32x64xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: distribute_constant
+  gpu.func @distribute_constant() {
+    // CHECK-COUNT-4: arith.constant dense<1.000000e+00> : vector<16x16xf32>
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} dense<1.0> : vector<256x128xf32>
+    gpu.return
+  }
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index d51122417fb61..415753a652092 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -393,4 +393,11 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
   } {sg_id_range = #xegpu.range<[3, 19]>}
   gpu.return
   }
+
+  // CHECK-LABEL: distribute_constant
+  gpu.func @distribute_constant() {
+    // CHECK: arith.constant dense<1.000000e+00> : vector<32x32xf32>
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} dense<1.0> : vector<256x128xf32>
+    gpu.return
+  }
 }

Copy link
Contributor

@adam-smnk adam-smnk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@chencha3 chencha3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@nbpatel nbpatel merged commit af87214 into llvm:main Aug 13, 2025
9 checks passed
tkarna pushed a commit to tkarna/llvm-project that referenced this pull request Aug 26, 2025
@nbpatel nbpatel deleted the xegpu-arith-constant branch September 25, 2025 20:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants