Conversation

Contributor

@nbpatel commented Nov 14, 2025

No description provided.

@nbpatel marked this pull request as ready for review November 17, 2025 16:45
Member

llvmbot commented Nov 17, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Nishant Patel (nbpatel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/168118.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+73-5)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir (+9)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (+37)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0a9ef0aa6df96..81fd25a155129 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1283,6 +1283,74 @@ struct WgToSgVectorTransposeOp
   }
 };
 
+// This pattern distributes vector.constant_mask ops to subgroup level by
+// rewriting them as vector.create_mask ops with per-subgroup mask sizes.
+struct WgToSgVectorConstantMaskOp
+    : public OpConversionPattern<vector::ConstantMaskOp> {
+  using OpConversionPattern<vector::ConstantMaskOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ConstantMaskOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType type = op.getResult().getType();
+    auto wgShape = type.getShape();
+
+    ArrayRef<int64_t> originalMaskDimSizes = op.getMaskDimSizes();
+
+    // Get subgroup ID.
+    Value sgId =
+        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
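+    // Compute this subgroup's tile coordinates. With round-robin distribution
+    // a subgroup may own several tiles, so one coordinate set is returned per
+    // owned tile.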
+    auto sgOffsets =
+        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
+    if (failed(sgOffsets))
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType resultType = VectorType::get(sgShape, type.getElementType());
+
+    SmallVector<Value> newCreateMaskOps;
+    for (auto offsetSet : *sgOffsets) {
+      SmallVector<Value> maskOperands;
+
+      for (auto [i, originalMaskSize] : llvm::enumerate(originalMaskDimSizes)) {
+        Value originalMaskSizeVal =
+            arith::ConstantIndexOp::create(rewriter, loc, originalMaskSize);
+        Value dimSizeVal =
+            arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
+        Value offset = offsetSet[i];
+        // Compute: originalMaskSize - offset.
+        Value adjustedMaskSize =
+            arith::SubIOp::create(rewriter, loc, originalMaskSizeVal, offset);
+        // Clamp to [0, dimSize]: min(max(adjustedMaskSize, 0), dimSize).
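+        // E.g., in the 1D test below (constant_mask [8], sg_data = [16]),
+        // the subgroup at offset 16 computes 8 - 16 = -8, which clamps to 0
+        // (an all-false mask), while the subgroup at offset 0 keeps 8.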
+        Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+        Value clampedLow =
+            arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
+        Value clampedHigh =
+            arith::MinSIOp::create(rewriter, loc, clampedLow, dimSizeVal);
+        maskOperands.push_back(clampedHigh);
+      }
+
+      auto newCreateMaskOp =
+          vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
+      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+          !layout.getEffectiveInstDataAsInt().empty())
+        xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0),
+                                       layout.dropSgLayoutAndData());
+      newCreateMaskOps.push_back(newCreateMaskOp.getResult());
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -1297,8 +1365,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
            WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
            WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
            WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
-           WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>(
-          patterns.getContext());
+           WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
+           WgToSgVectorConstantMaskOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -1425,9 +1493,9 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(layout);
       });
 
-  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
-                               vector::TransposeOp, vector::BroadcastOp,
-                               vector::MultiDimReductionOp>(
+  target.addDynamicallyLegalOp<
+      vector::ShapeCastOp, vector::StepOp, vector::TransposeOp,
+      vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp>(
       [=](Operation *op) -> bool {
         // Check for either a SliceAttr or LayoutAttr on the result.
         auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 84ce80f477a55..a752d0aa5c541 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -130,5 +130,14 @@ gpu.module @test_distribution {
     %trans = vector.transpose %load, [1, 0] {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 32], lane_layout = [1, 16], lane_data = [1, 1], order =[1, 0]>} : vector<256x128xf32> to vector<128x256xf32>
       gpu.return
   }
+
+  // CHECK-LABEL: vector_mask_2D
+  gpu.func @vector_mask_2D() {
+    %cst16 = arith.constant 16 : index
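+    // 256x128 over sg_layout [8, 4] with sg_data [16, 16] gives
+    // (256/16)*(128/16) = 128 tiles round-robined across 32 subgroups,
+    // i.e. 4 create_mask ops per subgroup.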
+    // CHECK-COUNT-4: vector.create_mask {{.*}}, {{.*}} : vector<16x16xi1>
+    // CHECK-NOT: vector.create_mask
+    %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
+    gpu.return
+  }
 }
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 4fbb566cfbe73..fa08ed1623501 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -547,4 +547,41 @@ gpu.module @test_distribution {
     %broadcast = vector.broadcast %arg0 {layout_result_0 = #xegpu.layout<sg_layout = [4, 8, 1], sg_data = [1, 1, 1]>} : index to vector<4x1x1xindex>
     gpu.return
   }
+
+  // CHECK-LABEL: vector_mask_1D
+  gpu.func @vector_mask_1D() {
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %[[C2:.*]]
+    // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C16:.*]]
+    // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[MUL]], %[[C32:.*]]
+    // CHECK-DAG: %[[SUB:.*]] = arith.subi %[[C8:.*]], %[[REMU2]] : index
+    // CHECK-DAG: %[[MAX:.*]] = arith.maxsi %[[SUB]], %[[C0:.*]] : index
+    // CHECK-DAG: %[[MIN:.*]] = arith.minsi %[[MAX]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MIN]] : vector<16xi1>
+    %cst8 = arith.constant 8 : index
+    %constant_mask = vector.constant_mask [8] {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<32xi1>
+    gpu.return
+  }
+
+  // CHECK-LABEL: vector_mask_2D
+  gpu.func @vector_mask_2D() {
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[SGIDX:.*]] = index.remu %[[SGID]], %[[C4:.*]]
+    // CHECK-DAG: %[[SGIDY_TMP:.*]] = index.divu %[[SGID]], %[[C4:.*]]
+    // CHECK-DAG: %[[SGIDY:.*]] = index.remu %[[SGIDY_TMP]], %[[C8:.*]]
+    // CHECK-DAG: %[[ROW:.*]] = index.mul %[[SGIDY]], %[[C32:.*]]
+    // CHECK-DAG: %[[COL:.*]] = index.mul %[[SGIDX]], %[[C32:.*]]
+    // CHECK-DAG: %[[MODROW:.*]] = index.remu %[[ROW]], %[[C256:.*]]
+    // CHECK-DAG: %[[MODCOL:.*]] = index.remu %[[COL]], %[[C128:.*]]
+    // CHECK-DAG: %[[SUBROW:.*]] = arith.subi %[[C16:.*]], %[[MODROW]] : index
+    // CHECK-DAG: %[[MAXROW:.*]] = arith.maxsi %[[SUBROW]], %[[C0:.*]] : index
+    // CHECK-DAG: %[[MINROW:.*]] = arith.minsi %[[MAXROW]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[SUBCOL:.*]] = arith.subi %[[C16:.*]], %[[MODCOL]] : index
+    // CHECK-DAG: %[[MAXCOL:.*]] = arith.maxsi %[[SUBCOL]], %[[C0:.*]] : index
+    // CHECK-DAG: %[[MINCOL:.*]] = arith.minsi %[[MAXCOL]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MINROW]], %[[MINCOL]] : vector<32x32xi1>
+    %cst16 = arith.constant 16 : index
+    %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
+    gpu.return
+  }
 }
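
To make the new lowering concrete, here is a minimal sketch, assuming the 1D case above (vector.constant_mask [8] with sg_layout = [2], sg_data = [16] over vector<32xi1>), of the subgroup-level IR the pattern produces. SSA names are illustrative, not the pass's actual output:

  %c0   = arith.constant 0 : index
  %c2   = arith.constant 2 : index
  %c8   = arith.constant 8 : index
  %c16  = arith.constant 16 : index
  %c32  = arith.constant 32 : index
  %sgid = gpu.subgroup_id : index
  %id   = index.remu %sgid, %c2          // subgroup index along the distributed dim
  %base = index.mul %id, %c16            // subgroup offset: id * sg_data
  %off  = index.remu %base, %c32         // wrap within the workgroup shape
  %rem  = arith.subi %c8, %off : index   // originalMaskSize - offset
  %lo   = arith.maxsi %rem, %c0 : index  // clamp below at 0
  %sz   = arith.minsi %lo, %c16 : index  // clamp above at sg_data
  %mask = vector.create_mask %sz : vector<16xi1>

Subgroup 0 (offset 0) therefore materializes create_mask(8), and subgroup 1 (offset 16) materializes create_mask(0), i.e. an all-false mask.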

@nbpatel requested a review from silee2 November 17, 2025 17:35
github-actions bot commented Nov 18, 2025

🐧 Linux x64 Test Results

  • 7104 tests passed
  • 594 tests skipped

@nbpatel force-pushed the xegpu-wg-sg-constant-mask branch from ade9234 to e162350 on November 18, 2025 22:12
Contributor

@silee2 left a comment

LGTM

Contributor

@charithaintc left a comment

LGTM

Contributor

@Jianhui-Li left a comment

LGTM

@nbpatel merged commit 310abe0 into llvm:main Nov 20, 2025
8 checks passed