
Conversation

@nbpatel (Contributor) commented Nov 13, 2025

No description provided.

@llvmbot (Member) commented Nov 13, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/167970.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+55-2)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir (+13)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (+20)
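
In brief: the patch adds a templated conversion pattern, WgToSgVectorMaskOp, that rewrites a workgroup-level vector.create_mask or vector.constant_mask into per-subgroup copies according to the op's #xegpu.layout attribute, and registers both instantiations with the pass. A minimal before/after sketch of the intended rewrite, adapted from the tests below (SSA names illustrative):

  // Before (workgroup level): one 256x128 mask, laid out over an 8x4
  // subgroup grid with a 32x32 tile per subgroup.
  %mask = vector.constant_mask [16, 16]
    {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>}
    : vector<256x128xi1>

  // After (subgroup level): each subgroup materializes its own 32x32 mask;
  // the pattern drops sg_layout/sg_data from the result layout.
  %sg_mask = vector.constant_mask [16, 16] : vector<32x32xi1>
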
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0a9ef0aa6df96..afab880d173c7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1283,6 +1283,57 @@ struct WgToSgVectorTransposeOp
   }
 };
 
+/// Pattern for lowering vector.create_mask and vector.constant_mask ops to
+/// subgroup level.
+template <typename MaskOpType>
+struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
+  using OpConversionPattern<MaskOpType>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      MaskOpType op,
+      typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResult().getType();
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    SmallVector<int64_t> sgShape;
+    int count;
+    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+
+    SmallVector<Value> newMaskOps;
+    for (int i = 0; i < count; ++i) {
+      Value newMaskOp;
+      if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
+        newMaskOp = vector::CreateMaskOp::create(
+            rewriter, op.getLoc(), newResultType, op.getOperands());
+      } else if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
+        newMaskOp = vector::ConstantMaskOp::create(
+            rewriter, op.getLoc(), newResultType, op.getMaskDimSizes());
+      } else {
+        return rewriter.notifyMatchFailure(op,
+                                           "Unsupported mask operation type");
+      }
+      xegpu::setDistributeLayoutAttr(cast<OpResult>(newMaskOp),
+                                     layout.dropSgLayoutAndData());
+
+      newMaskOps.push_back(newMaskOp);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newMaskOps});
+    return success();
+  }
+};
+
+using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
+using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
+
 } // namespace
 
 namespace mlir {
@@ -1297,7 +1348,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
            WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
            WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
            WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
-           WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>(
+           WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
+           WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
           patterns.getContext());
 }
 } // namespace xegpu
@@ -1427,7 +1479,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
 
   target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
                                vector::TransposeOp, vector::BroadcastOp,
-                               vector::MultiDimReductionOp>(
+                               vector::MultiDimReductionOp,
+                               vector::ConstantMaskOp, vector::CreateMaskOp>(
       [=](Operation *op) -> bool {
         // Check for either a SliceAttr or LayoutAttr on the result.
         auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 84ce80f477a55..b587ecc726f4d 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -130,5 +130,18 @@ gpu.module @test_distribution {
     %trans = vector.transpose %load, [1, 0] {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 32], lane_layout = [1, 16], lane_data = [1, 1], order =[1, 0]>} : vector<256x128xf32> to vector<128x256xf32>
       gpu.return
   }
+
+  // CHECK-LABEL: vector_mask_2D
+  gpu.func @vector_mask_2D() {
+    %cst16 = arith.constant 16 : index
+    // CHECK: %[[CST16:.*]] = arith.constant 16 : index
+    // CHECK-COUNT-4: vector.create_mask %[[CST16:.*]], %[[CST16]] : vector<16x16xi1>
+    // CHECK-NOT: vector.create_mask
+    // CHECK-COUNT-4: vector.constant_mask [16, 16] : vector<16x16xi1>
+    // CHECK-NOT: vector.constant_mask
+    %create_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
+    %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
+    gpu.return
+  }
 }
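
A worked check of the CHECK-COUNT-4 above, assuming the usual round-robin tiling computed by getSgShapeAndCount: the workgroup shape is 256x128, sg_data is 16x16, and sg_layout is 8x4, so each subgroup covers 256/(8*16) = 2 tiles along the first dim and 128/(4*16) = 2 along the second, giving count = 2 * 2 = 4 mask ops per subgroup.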
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 4fbb566cfbe73..f254b82c6401f 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -547,4 +547,24 @@ gpu.module @test_distribution {
     %broadcast = vector.broadcast %arg0 {layout_result_0 = #xegpu.layout<sg_layout = [4, 8, 1], sg_data = [1, 1, 1]>} : index to vector<4x1x1xindex>
     gpu.return
   }
+
+  // CHECK-LABEL: vector_mask_1D
+  gpu.func @vector_mask_1D() {
+    %cst8 = arith.constant 8 : index
+    // CHECK: vector.create_mask {{.*}} : vector<16xi1>
+    %create_mask = vector.create_mask %cst8 {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<16xi1>
+    // CHECK: vector.constant_mask [8] : vector<16xi1>
+    %constant_mask = vector.constant_mask [8] {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<32xi1>
+    gpu.return
+  }
+
+  // CHECK-LABEL: vector_mask_2D
+  gpu.func @vector_mask_2D() {
+    %cst16 = arith.constant 16 : index
+    // CHECK: vector.create_mask {{.*}}, {{.*}} : vector<32x32xi1>
+    %create_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
+    // CHECK: vector.constant_mask [16, 16] : vector<32x32xi1>
+    %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
+    gpu.return
+  }
 }

@nbpatel changed the title from "[XeGPU][WgToSg] Add distribution for vector mask operations" to "[MLIR][XeGPU][WgToSg] Add distribution for vector mask operations" Nov 13, 2025
@nbpatel (Contributor, Author) commented Nov 13, 2025

Actually, it will be more complex if the mask dim sizes are bigger than sgData... please hold off on reviewing this one until I update it.
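
A hypothetical 1-D case illustrating the complication (shapes chosen for illustration, not taken from the patch):

  // Workgroup-level mask whose bound (40) exceeds sg_data (16).
  %c40 = arith.constant 40 : index
  %m = vector.create_mask %c40
    {layout_result_0 = #xegpu.layout<sg_layout = [4], sg_data = [16]>}
    : vector<64xi1>

  // The pattern replicates the workgroup operands unchanged, so every
  // subgroup would get vector.create_mask %c40 : vector<16xi1>, i.e. a
  // fully set mask. A correct lowering needs a per-subgroup bound, roughly
  //   bound_sg = clamp(40 - sg_id * 16, 0, 16)
  // so subgroups 0-1 are fully set, subgroup 2 is half set, subgroup 3 empty.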

@nbpatel closed this Nov 14, 2025
