Conversation

@nbpatel (Contributor) commented Jun 16, 2025

This PR adds a transformation pattern for the vector.broadcast op in the xegpu-wg-to-sg-distribute pass.

@nbpatel nbpatel requested a review from chencha3 June 16, 2025 19:41
        VectorType::get(sgShape, resultType.getElementType());

    SmallVector<Value> newBroadcastOps;
    for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {

How about using a range-based for loop?
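
For reference, a minimal sketch of the range-based alternative, reusing the locals (newResultType, layout, rewriter) from the snippet above:

    SmallVector<Value> newBroadcastOps;
    for (Value operand : adaptor.getOperands().front()) {
      // One subgroup-level broadcast per value produced by the 1:N mapping.
      auto newBroadcast = rewriter.create<vector::BroadcastOp>(
          op.getLoc(), newResultType, operand);
      // Keep only the lane-level layout on the subgroup-level op.
      xegpu::setLayoutAttr(newBroadcast->getResult(0),
                           layout.dropSgLayoutAndData());
      newBroadcastOps.push_back(newBroadcast.getResult());
    }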

    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
    if (!layout || !layout.getSgLayout())
      return failure();


It looks to me that the current implementation assumes the rank of the source is the same as the rank of the result, which is a subset of the semantics supported by vector.broadcast. I believe this is partially due to a limitation of LayoutAttr. It would be better to add a check.
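
A minimal sketch of such a guard, assuming the pattern's resultType local; vector.broadcast also accepts scalar sources, hence the dyn_cast:

    // Sketch: bail out unless the source is a vector with the same rank as
    // the result, since LayoutAttr cannot express the rank-expanding cases.
    auto srcType = dyn_cast<VectorType>(op.getSourceType());
    if (!srcType || srcType.getRank() != resultType.getRank())
      return failure();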

@nbpatel nbpatel marked this pull request as ready for review June 16, 2025 20:33
@llvmbot (Member) commented Jun 16, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

This PR adds a transformation pattern for the vector.broadcast op in the xegpu-wg-to-sg-distribute pass.


Full diff: https://github.com/llvm/llvm-project/pull/144417.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+40-1)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir (+18-1)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir (+17-2)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index a26c6b52f0ddc..96c7032d6b812 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -328,6 +328,39 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+/// This pattern transforms vector.broadcast ops to work at subgroup level.
+struct WgToSgVectorBroadcastOp
+    : public OpConversionPattern<vector::BroadcastOp> {
+  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResult().getType();
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+    if (!layout || !layout.getSgLayout())
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+
+    SmallVector<Value> newBroadcastOps;
+    for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
+      auto newBroadcast = rewriter.create<vector::BroadcastOp>(
+          op.getLoc(), newResultType, adaptor.getOperands().front()[i]);
+      xegpu::setLayoutAttr(newBroadcast->getResult(0),
+                           layout.dropSgLayoutAndData());
+      newBroadcastOps.push_back(newBroadcast.getResult());
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
+    return success();
+  }
+};
+
 // Handles UnrealizedConversionCastOp generated during
 // SCFStructuralTypeConversions (step 1). This op may appear as either a
 // target or source materialization for Vector values, e.g.:
@@ -411,7 +444,8 @@ namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
                WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
-               UnrealizedConversionCastOpPattern>(patterns.getContext());
+               WgToSgVectorBroadcastOp, UnrealizedConversionCastOpPattern>(
+      patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -518,6 +552,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
     return isLegal(layout);
   });
 
+  target.addDynamicallyLegalOp<vector::BroadcastOp>(
+      [=](vector::BroadcastOp op) -> bool {
+        return isLegal(xegpu::getLayoutAttr(op.getResult()));
+      });
+
   target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
       [=](UnrealizedConversionCastOp op) {
         return llvm::is_contained(existingCastOps, op.getOperation());
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 35ad16d8cd9a9..60ac266b0f112 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -103,6 +103,24 @@ gpu.module @test_round_robin_assignment {
     gpu.return
   }
 
+  // CHECK-LABEL: test_broadcast
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
+  gpu.func @test_broadcast(%src: memref<24x1xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
+      -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+      -> vector<24x1xf32>
+    // CHECK-COUNT-3: vector.broadcast {{.*}}
+    // CHECK-SAME-COUNT-3: {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
+    // CHECK-SAME-COUNT-3: : vector<2x1xf32> to vector<2x4xf32>
+    // CHECK-NOT: vector.broadcast
+    %broadcast = vector.broadcast %load 
+      {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
+      : vector<24x1xf32> to vector<24x8xf32>
+    gpu.return
+  }
+
   gpu.func @test_scf_for(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
     %c1 = arith.constant 1 : index
     %c10 = arith.constant 10 : index
@@ -197,5 +215,4 @@ gpu.module @test_round_robin_assignment {
     xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     gpu.return
   }
-
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 466842c968448..125bab349b4cb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -170,6 +170,22 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     gpu.return
   }
 
+  // CHECK-LABEL: test_broadcast
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
+  gpu.func @test_broadcast(%src: memref<24x1xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
+      -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+      -> vector<24x1xf32>
+    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32>
+    %broadcast = vector.broadcast %load 
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
+      : vector<24x1xf32> to vector<24x8xf32>
+    gpu.return
+  }
+
   gpu.func @test_scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
     //CHECK: [[c0:%.+]] = arith.constant 0 : index
     //CHECK: [[c128:%.+]] = arith.constant 128 : index
@@ -295,6 +311,5 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
     gpu.return
   }
-
-
 }
+

@nbpatel nbpatel requested a review from adam-smnk June 20, 2025 17:05
@adam-smnk (Contributor) left a comment

Looks in line with other distributions

    // TODO: Currently only supports cases where the source and result ranks
    // are the same.
    auto srcType =
        dyn_cast<VectorType>(adaptor.getOperands().front()[0].getType());
@chencha3 (Contributor) commented Jul 21, 2025

Can adaptor.getSource() be used here and later instead of adaptor.getOperands().front()?
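
A sketch of that suggestion, assuming the generated one-to-N adaptor exposes getSource() as a ValueRange of the remapped source values:

    // The loop body stays the same; only the range expression changes from
    // adaptor.getOperands().front() to the named accessor.
    for (Value src : adaptor.getSource()) {
      auto newBroadcast = rewriter.create<vector::BroadcastOp>(
          op.getLoc(), newResultType, src);
      xegpu::setLayoutAttr(newBroadcast->getResult(0),
                           layout.dropSgLayoutAndData());
      newBroadcastOps.push_back(newBroadcast.getResult());
    }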

    // and the other dimensions are the same as the destination type
    // TODO: Generalize it
    auto srcShape = srcType.getShape();
    for (size_t i = 0; i < srcShape.size(); ++i) {

It seems this check duplicates the one in the broadcast verifier, unless there are cases where a source vector such as vector<32x1x1xf32> can be distributed to a vector such as <8x2x1>.

@nbpatel nbpatel merged commit 56b263b into llvm:main Jul 23, 2025
9 checks passed
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Jul 28, 2025
…o Sg pass (llvm#144417)

This PR adds transformation pattern for vector.broadcast op in
xegpu-wg-to-sg-distribute pass
@nbpatel nbpatel deleted the xegpu_wg_sg_broadcast branch September 25, 2025 20:34