Conversation

nbpatel (Contributor) commented Sep 30, 2025

This PR distributes non-splat constants from wg to sg. The current pattern has limitations and avoids cases that require SLM access.
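For illustration, a minimal sketch of the rewrite on a small shape (the tests below use vector<32x1xindex>; the offset computation is simplified here and omits the affine.apply/index.remu bounds handling the pass actually emits):

  // Workgroup-level non-splat constant with a uniform stride of 16, distributed over 4 subgroups.
  %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [1, 1]>} dense<[[0], [16], [32], [48]]> : vector<4x1xindex>

  // Conceptually, each subgroup then materializes: base tile + (its row offset * stride).
  %base  = arith.constant dense<0> : vector<1xindex>
  %sgid  = gpu.subgroup_id : index
  %c16   = arith.constant 16 : index
  %off   = arith.muli %sgid, %c16 : index
  %splat = vector.splat %off : vector<1xindex>
  %tile  = arith.addi %base, %splat : vector<1xindex>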

llvmbot (Member) commented Sep 30, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/161416.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+109-15)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir (+27)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (+20)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9413a9296b184..be03e6e050c43 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -720,7 +720,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
     auto vecType = dyn_cast<VectorType>(op.getType());
-    if (!vecAttr || !vecAttr.isSplat() || !vecType)
+    if (!vecAttr || !vecType)
       return failure();
 
     xegpu::DistributeLayoutAttr layout =
@@ -733,22 +733,116 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     int count;
     std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
 
-    // Current limitation: constant of vector with single value.
-    // TODO: support more complex cases, e.g., vector with multiple values.
-    Attribute singleVal = vecAttr.getSplatValue<Attribute>();
-
     auto newType = VectorType::get(sgShape, vecType.getElementType());
-    auto sgAttr = DenseElementsAttr::get(newType, singleVal);
-    auto cstOp =
-        arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
-    if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
-        !layout.getEffectiveInstDataAsInt().empty())
-      xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
-                                     layout.dropSgLayoutAndData());
-    SmallVector<Value> newConsts(count, cstOp);
+    Location loc = op.getLoc();
+    auto eltType = vecType.getElementType();
 
-    rewriter.replaceOpWithMultiple(op, {newConsts});
-    return success();
+    auto setLayoutIfNeeded = [&](Value val) {
+      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+          !layout.getEffectiveInstDataAsInt().empty()) {
+        xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
+                                       layout.dropSgLayoutAndData());
+      }
+    };
+
+    if (vecAttr.isSplat()) {
+      // Splat: single value for all subgroups
+      Attribute singleVal = vecAttr.getSplatValue<Attribute>();
+      auto sgAttr = DenseElementsAttr::get(newType, singleVal);
+      auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
+      setLayoutIfNeeded(cstOp->getResult(0));
+      rewriter.replaceOp(op, cstOp);
+      return success();
+    } else if (sgShape == wgShape) { // if the entire vector is shared by all
+                                     // subgroups, don't distribute
+      auto newConstOp =
+          arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
+      setLayoutIfNeeded(newConstOp->getResult(0));
+      rewriter.replaceOp(op, newConstOp);
+      return success();
+    } else {
+      // Non-splat constant
+      // Only supports 1D & 2D (with one unit dim)
+      // TODO: support other cases that require SLM access
+      if (!eltType.isIndex())
+        return rewriter.notifyMatchFailure(
+            op, "Unsupported element type for non-splat constant op.");
+
+      SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
+      if (wgShape.size() > 2)
+        return rewriter.notifyMatchFailure(
+            op, "Only 1D & 2D vector constant supported");
+
+      // allow 2D vector/distributions with one unit dim
+      auto hasTwoNonUnitDims = [](ArrayRef<int64_t> dims) {
+        return dims.size() == 2 && dims[0] != 1 && dims[1] != 1;
+      };
+      if (hasTwoNonUnitDims(wgShape) || hasTwoNonUnitDims(sgLayout))
+        return rewriter.notifyMatchFailure(
+            op, "2D vector/distribution only supported with 1 unit dim");
+
+      int64_t nonUnitDim = 0;
+      if (wgShape.size() == 2)
+        nonUnitDim = wgShape[0] != 1 ? 0 : 1;
+
+      SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
+      int64_t stride = 0;
+      if (values.size() > 1) {
+        stride = cast<IntegerAttr>(values[1]).getInt() -
+                 cast<IntegerAttr>(values[0]).getInt();
+        for (size_t i = 2; i < values.size(); ++i) {
+          int64_t diff = cast<IntegerAttr>(values[i]).getInt() -
+                         cast<IntegerAttr>(values[i - 1]).getInt();
+          if (diff != stride)
+            return rewriter.notifyMatchFailure(
+                op, "Non-constant stride in non-splat constant op.");
+        }
+      }
+
+      int sgData = 1;
+      if (sgShape.size() == 1) {
+        sgData = static_cast<int>(sgShape[0]);
+      } else if (sgShape.size() == 2) {
+        sgData = static_cast<int>(sgShape[0] != 1 ? sgShape[0] : sgShape[1]);
+      } else {
+        return rewriter.notifyMatchFailure(
+            op, "Only 1D or 2D vector constant supported");
+      }
+
+      // Create a constant for the base tile
+      SmallVector<Attribute> baseTileValues;
+      for (int i = 0; i < sgData; ++i)
+        baseTileValues.push_back(values[i]);
+      auto tileAttr = DenseElementsAttr::get(VectorType::get({sgData}, eltType),
+                                             baseTileValues);
+      auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr);
+
+      // Get subgroup id
+      Value sgId =
+          gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+
+      auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+      if (failed(sgOffsets))
+        return failure();
+
+      auto strideConst = rewriter.create<arith::ConstantIndexOp>(loc, stride);
+      SmallVector<Value> newConstOps;
+      for (auto offsets : *sgOffsets) {
+        // Multiply offset with stride, broadcast it and add to baseConstVec
+        Value mulOffset = rewriter.create<arith::MulIOp>(
+            loc, rewriter.getIndexType(), offsets[nonUnitDim], strideConst);
+        auto bcastOffset = rewriter.create<vector::SplatOp>(
+            loc, VectorType::get({sgData}, rewriter.getIndexType()), mulOffset);
+        auto finalConst =
+            arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
+        setLayoutIfNeeded(baseConstVec);
+        setLayoutIfNeeded(bcastOffset);
+        setLayoutIfNeeded(finalConst);
+        newConstOps.push_back(finalConst);
+      }
+      rewriter.replaceOpWithMultiple(op, {newConstOps});
+      return success();
+    }
   }
 };
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index dce73dee507e1..f3e2e41ae4b65 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -98,4 +98,31 @@ gpu.module @test_distribution {
       : vector<256x64xf32> to vector<256xf32>
     gpu.return
   }
+
+  gpu.func @non_splat_constant() {
+    // CHECK-DAG: %[[CST:.*]] = arith.constant dense<[0, 16]> : vector<2xindex>
+    // CHECK-DAG: %[[SG_ID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[AFF1:.*]] = affine.apply #map4()[%[[SG_ID]]]
+    // CHECK-DAG: %[[AFF2:.*]] = affine.apply #map5()[%[[SG_ID]]]
+    // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+    // CHECK-DAG: %[[MUL:.*]] = index.mul %[[AFF1]], %[[C2]]
+    // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+    // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+    // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[MUL]], %[[C0]] : index
+    // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[AFF2]], %[[C0_0]] : index
+    // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+    // CHECK-DAG: %[[REM:.*]] = index.remu %[[ADD1]], %[[C32]]
+    // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+    // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+    // CHECK-DAG: %[[C16_0:.*]] = arith.constant 16 : index
+    // CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[MUL]], %[[C16]] : index
+    // CHECK-DAG: %[[REM2:.*]] = index.remu %[[ADD3]], %[[C32]]
+    // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM]], %[[C16_0]] : index
+    // CHECK-DAG: %[[SPLAT:.*]] = vector.splat %[[MUL2]] : vector<2xindex>
+    // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[REM2]], %[[C16_0]] : index
+    // CHECK-DAG: %[[SPLAT2:.*]] = vector.splat %[[MUL3]] : vector<2xindex>
+    // CHECK-DAG: %[[ADD4:.*]] = arith.addi %[[CST]], %[[SPLAT2]] : vector<2xindex>
+    %cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [2, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
+    gpu.return
+  }
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 48fc633974e63..07b1e0f9ba8db 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -464,4 +464,24 @@ gpu.module @test_distribution {
     %broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : index to vector<4x2x6x32xindex>
     gpu.return
   }
+
+  // CHECK-LABEL: non_splat_constant
+  gpu.func @non_splat_constant() {
+    // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[IDY:.*]] = affine.apply #map4()[%[[SGID]]]
+    // CHECK-DAG: %[[IDX:.*]] = affine.apply #map5()[%[[SGID]]]
+    // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+    // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+    // CHECK-DAG: %[[ADDY:.*]] = arith.addi %[[IDY]], %[[C0]] : index
+    // CHECK-DAG: %[[ADDX:.*]] = arith.addi %[[IDX]], %[[C0_0]] : index
+    // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+    // CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[ADDY]], %[[C32]]
+    // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+    // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+    // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU_Y]], %[[C16]] : index
+    // CHECK-DAG: %[[SPLAT:.*]] = vector.splat %[[MUL]] : vector<1xindex>
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
+    gpu.return
+  }
 }

        // Multiply offset with stride, broadcast it and add to baseConstVec
        Value mulOffset = rewriter.create<arith::MulIOp>(
            loc, rewriter.getIndexType(), offsets[nonUnitDim], strideConst);
        auto bcastOffset = rewriter.create<vector::SplatOp>(

Note that Splat op is deprecated and planned to be removed
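At the IR level, the same broadcast could be emitted with vector.broadcast instead (a sketch; %mul_offset and the vector<2xindex> shape are illustrative):

  %bcast_offset = vector.broadcast %mul_offset : index to vector<2xindex>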

@nbpatel nbpatel requested a review from Jianhui-Li October 1, 2025 01:14
      for (auto offsets : *sgOffsets) {
        // Multiply offset with stride, broadcast it and add to baseConstVec
        Value mulOffset = rewriter.create<arith::MulIOp>(
            loc, rewriter.getIndexType(), offsets[nonUnitDim], strideConst);
Reviewer (Contributor) commented:

The code is written specifically for one condition (2D with one unit dim, or 1D). If we have to relax that condition, the code will need a total rewrite.
Can we make it more generic, e.g., have two strides for the 2D case and just compute the linear offset here before adding it to baseConstVec?

nbpatel (Author) replied:

I changed it to support 2D vectors. The high-level logic is to have two strides, rowStride and colStride, compute the value as rowOffset*rowStride + colOffset*colStride, then broadcast it to the baseConstVec size and add it to baseConstVec.
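A minimal MLIR sketch of that computation (value names, strides, and the vector<2x8xindex> shape are illustrative; %row_off/%col_off stand for the per-subgroup offsets returned by layout.getOffsets):

  %row_term = arith.muli %row_off, %row_stride : index
  %col_term = arith.muli %col_off, %col_stride : index
  %linear   = arith.addi %row_term, %col_term : index
  %bcast    = vector.broadcast %linear : index to vector<2x8xindex>
  %tile     = arith.addi %base_tile, %bcast : vector<2x8xindex>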

@nbpatel nbpatel requested a review from Jianhui-Li October 6, 2025 20:21
      }

      // Determine the shape of the base tile for each subgroup.
      SmallVector<int64_t> baseTileShape;


Can you just use sgShape directly instead of a new variable?

nbpatel (Author) replied:

Yes, cleaning it up further.

        // 1D: offset[0] * strideConst
        mulOffset = rewriter.create<arith::MulIOp>(
            loc, rewriter.getIndexType(), offsets[0], strideConst);
      } else if (wgShape.size() == 2) {


just else?

@Garra1980 commented:

Please also update the PR description accordingly.

nbpatel force-pushed the xegpu-wg-sg-constant branch from 3c147c7 to 2c81dee on October 7, 2025 00:52
nbpatel (Author) commented Oct 7, 2025

@Jianhui-Li Cleaned it up further, please take a look again
