Conversation

@nbpatel (Contributor) commented Nov 25, 2025

No description provided.

@llvmbot (Member) commented Nov 25, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Nishant Patel (nbpatel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/169571.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+29-20)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir (+8)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (+37)
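
For reference, the change below templates the existing WgToSgVectorConstantMaskOp pattern over the mask op type, so one pattern serves both vector.constant_mask (static mask sizes) and vector.create_mask (SSA-value mask sizes). In each dimension d, every subgroup clamps the workgroup mask size to its own tile: min(max(wgMaskSize[d] - offset[d], 0), sgShape[d]). A minimal standalone sketch of that arithmetic (plain C++; localMaskSizes is our name for illustration, not a helper in the patch):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Clamp a workgroup-level mask to one subgroup's tile: in each
// dimension d, the subgroup-local mask size is
// min(max(wgMaskSizes[d] - offsets[d], 0), sgShape[d]).
std::vector<int64_t> localMaskSizes(const std::vector<int64_t> &wgMaskSizes,
                                    const std::vector<int64_t> &offsets,
                                    const std::vector<int64_t> &sgShape) {
  std::vector<int64_t> out;
  for (size_t d = 0; d < wgMaskSizes.size(); ++d)
    out.push_back(std::min(std::max(wgMaskSizes[d] - offsets[d], int64_t(0)),
                           sgShape[d]));
  return out;
}

int main() {
  // Mirrors the 1D test added below: wg mask [8] over vector<32xi1>,
  // sg_layout = [2], sg_data = [16].
  assert(localMaskSizes({8}, {0}, {16}) == std::vector<int64_t>{8});  // subgroup 0
  assert(localMaskSizes({8}, {16}, {16}) == std::vector<int64_t>{0}); // subgroup 1
}

The asserts follow the 1D test case: subgroup 0 (offset 0) keeps a mask of 8, while subgroup 1 (offset 16) clamps to 0.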
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index beb9b60aa9d7a..95c20b1fabe58 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1270,15 +1270,15 @@ struct WgToSgVectorTransposeOp
   }
 };
 
-// This pattern distributes the vector.constant_mask ops to work at subgroup
-// level.
-struct WgToSgVectorConstantMaskOp
-    : public OpConversionPattern<vector::ConstantMaskOp> {
-  using OpConversionPattern<vector::ConstantMaskOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::ConstantMaskOp op, OneToNOpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
+// Distribute vector mask ops to work at subgroup level.
+template <typename MaskOpType>
+struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
+  using OpConversionPattern<MaskOpType>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      MaskOpType op,
+      typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
     xegpu::DistributeLayoutAttr layout =
         xegpu::getDistributeLayoutAttr(op.getResult());
     if (!layout || !layout.isForWorkgroup())
@@ -1288,9 +1288,16 @@ struct WgToSgVectorConstantMaskOp
     VectorType type = op.getResult().getType();
     auto wgShape = type.getShape();
 
-    ArrayRef<int64_t> wgMaskDimSizes = op.getMaskDimSizes();
+    SmallVector<Value> wgMaskDimSizes;
+    if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
+      for (int64_t maskSize : op.getMaskDimSizes()) {
+        wgMaskDimSizes.push_back(
+            arith::ConstantIndexOp::create(rewriter, loc, maskSize));
+      }
+    } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
+      wgMaskDimSizes = llvm::to_vector(op.getOperands());
+    }
 
-    // Get subgroup ID.
     Value sgId =
         gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
     auto sgOffsets =
@@ -1302,19 +1309,17 @@ struct WgToSgVectorConstantMaskOp
     VectorType resultType = VectorType::get(sgShape, type.getElementType());
 
     // In each dimension, each subgroup computes its local mask size as:
-    // min(max(wgMaskSize[d] - offset[d], 0), sgDimSize[d])
+    // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
     SmallVector<Value> newCreateMaskOps;
     for (auto offsetSet : *sgOffsets) {
       SmallVector<Value> maskOperands;
 
-      for (auto [i, wgMaskSize] : llvm::enumerate(wgMaskDimSizes)) {
-        Value wgMaskSizeVal =
-            arith::ConstantIndexOp::create(rewriter, loc, wgMaskSize);
+      for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
         Value dimSizeVal =
             arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
         Value offset = offsetSet[i];
         Value adjustedMaskSize =
-            arith::SubIOp::create(rewriter, loc, wgMaskSizeVal, offset);
+            arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
         Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
         Value nonNegative =
             arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
@@ -1335,6 +1340,8 @@ struct WgToSgVectorConstantMaskOp
   }
 };
 
+using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
+using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
 } // namespace
 
 namespace mlir {
@@ -1350,7 +1357,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
            WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
            WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
            WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
-           WgToSgVectorConstantMaskOp>(patterns.getContext());
+           WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
+          patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -1477,9 +1485,10 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(layout);
       });
 
-  target.addDynamicallyLegalOp<
-      vector::ShapeCastOp, vector::StepOp, vector::TransposeOp,
-      vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp>(
+  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
+                               vector::TransposeOp, vector::BroadcastOp,
+                               vector::MultiDimReductionOp,
+                               vector::ConstantMaskOp, vector::CreateMaskOp>(
       [=](Operation *op) -> bool {
         // Check for either a SliceAttr or LayoutAttr on the result.
         auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 1cddccb5fbbd1..4fb50b3b28534 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -138,5 +138,13 @@ gpu.module @test_distribution {
     %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
     gpu.return
   }
+
+  gpu.func @vector_create_mask_2D() {
+    // CHECK-COUNT-4: vector.create_mask {{.*}}, {{.*}} : vector<16x16xi1>
+    // CHECK-NOT: vector.create_mask
+    %cst16 = arith.constant 16 : index
+    %constant_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
+    gpu.return
+  }
 }
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 574b365443a0a..48e93320093fd 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -583,6 +583,43 @@ gpu.module @test_distribution {
     gpu.return
   }
 
+  // CHECK-LABEL: vector_create_mask_1D
+  gpu.func @vector_create_mask_1D() {
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %[[C2:.*]]
+    // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C16:.*]]
+    // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[MUL]], %[[C32:.*]]
+    // CHECK-DAG: %[[SUB:.*]] = arith.subi %[[C8:.*]], %[[REMU2]] : index
+    // CHECK-DAG: %[[MAX:.*]] = arith.maxsi %[[SUB]], %[[C0:.*]] : index
+    // CHECK-DAG: %[[MIN:.*]] = arith.minsi %[[MAX]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MIN]] : vector<16xi1>
+    %cst8 = arith.constant 8 : index
+    %constant_mask = vector.create_mask %cst8 {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<32xi1>
+    gpu.return
+  }
+
+  // CHECK-LABEL: vector_create_mask_2D
+  gpu.func @vector_create_mask_2D() {
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[SGIDX:.*]] = index.remu %[[SGID]], %[[C4:.*]]
+    // CHECK-DAG: %[[SGIDY_TMP:.*]] = index.divu %[[SGID]], %[[C4:.*]]
+    // CHECK-DAG: %[[SGIDY:.*]] = index.remu %[[SGIDY_TMP]], %[[C8:.*]]
+    // CHECK-DAG: %[[ROW:.*]] = index.mul %[[SGIDY]], %[[C32:.*]]
+    // CHECK-DAG: %[[COL:.*]] = index.mul %[[SGIDX]], %[[C32:.*]]
+    // CHECK-DAG: %[[MODROW:.*]] = index.remu %[[ROW]], %[[C256:.*]]
+    // CHECK-DAG: %[[MODCOL:.*]] = index.remu %[[COL]], %[[C128:.*]]
+    // CHECK-DAG: %[[SUBROW:.*]] = arith.subi %[[C16:.*]], %[[MODROW]] : index
+    // CHECK-DAG: %[[MAXROW:.*]] = arith.maxsi %[[SUBROW]], %[[C0:.*]] : index
+    // CHECK-DAG: %[[MINROW:.*]] = arith.minsi %[[MAXROW]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[SUBCOL:.*]] = arith.subi %[[C16:.*]], %[[MODCOL]] : index
+    // CHECK-DAG: %[[MAXCOL:.*]] = arith.maxsi %[[SUBCOL]], %[[C0:.*]] : index
+    // CHECK-DAG: %[[MINCOL:.*]] = arith.minsi %[[MAXCOL]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MINROW]], %[[MINCOL]] : vector<32x32xi1>
+    %cst16 = arith.constant 16 : index
+    %constant_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
+    gpu.return
+  }
+
   // CHECK-LABEL: distribute_load_slice_attr
   gpu.func @distribute_load_slice_attr() {
     %2 = memref.alloca() {alignment = 1024} : memref<4096xf32>

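To see why the round-robin test in xegpu-wg-to-sg-unify-ops-rr.mlir checks for exactly four vector.create_mask ops, note that with sg_layout = [8, 4] and sg_data = [16, 16] over a 256x128 mask, each subgroup owns (256/(8*16)) * (128/(4*16)) = 2 * 2 = 4 tiles. A hedged sketch of the tile offsets and clamped mask sizes for subgroup (0, 0), assuming round-robin offsets of the form (sgId_d + round * sgLayout_d) * sgData_d (our reading of the layout, not code from the patch):

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Round-robin test: 256x128 mask, sg_layout = [8, 4],
  // sg_data = [16, 16], wg mask sizes [16, 16]. Offsets assume the
  // round-robin form (sgId_d + round * sgLayout_d) * sgData_d; this
  // illustrates the CHECK-COUNT-4 expectation, it is not pass code.
  const int64_t wgShape[2] = {256, 128}, sgLayout[2] = {8, 4},
                sgData[2] = {16, 16}, wgMask[2] = {16, 16};
  const int64_t sgId[2] = {0, 0}; // subgroup at (row, col) = (0, 0)
  for (int64_t r0 = 0;
       (sgId[0] + r0 * sgLayout[0]) * sgData[0] < wgShape[0]; ++r0)
    for (int64_t r1 = 0;
         (sgId[1] + r1 * sgLayout[1]) * sgData[1] < wgShape[1]; ++r1) {
      int64_t off0 = (sgId[0] + r0 * sgLayout[0]) * sgData[0];
      int64_t off1 = (sgId[1] + r1 * sgLayout[1]) * sgData[1];
      int64_t m0 = std::min(std::max(wgMask[0] - off0, int64_t(0)), sgData[0]);
      int64_t m1 = std::min(std::max(wgMask[1] - off1, int64_t(0)), sgData[1]);
      std::printf("tile (%lld, %lld) -> create_mask %lld, %lld\n",
                  (long long)off0, (long long)off1, (long long)m0,
                  (long long)m1);
    }
}

This prints four tiles, at offsets (0, 0), (0, 64), (128, 0), and (128, 64); only the (0, 0) tile keeps a nonzero 16x16 mask, matching the four create_mask ops the rr test expects.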