
Conversation

@akroviakov
Contributor

This PR enables sg-to-wi distribution of xegpu matrix load/store ops.

@llvmbot
Member

llvmbot commented Oct 24, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Artem Kroviakov (akroviakov)

Changes

This PR enables sg-to-wi distribution of xegpu matrix load/store ops.


Full diff: https://github.com/llvm/llvm-project/pull/165008.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp (+116-8)
  • (modified) mlir/test/Dialect/XeGPU/subgroup-distribute.mlir (+15)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d09dc196c0bf7..fe059bb86eba2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -906,6 +906,110 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+template <class MatrixOp>
+struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    gpu::YieldOp yield = warpOp.getTerminator();
+    Operation *lastNode = yield->getPrevNode();
+    auto matrixOp = dyn_cast_or_null<MatrixOp>(lastNode);
+    if (!matrixOp)
+      return failure();
+    constexpr bool isLoad{std::is_same_v<MatrixOp, xegpu::LoadMatrixOp>};
+    int operandIdx{-1};
+
+    VectorType payloadTy;
+    VectorType warpResultTy;
+    if constexpr (isLoad) {
+      OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+        return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
+      });
+      if (!producedByLastLoad)
+        return rewriter.notifyMatchFailure(
+            warpOp, "The last op is not xegpu::LoadMatrixOp");
+      operandIdx = producedByLastLoad->getOperandNumber();
+      payloadTy = dyn_cast<VectorType>(matrixOp.getResult().getType());
+      warpResultTy = cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    } else {
+      payloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
+    }
+    if (!payloadTy)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix op payload must be a vector type");
+
+    auto loc = matrixOp.getLoc();
+    auto offsets = matrixOp.getMixedOffsets();
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(matrixOp,
+                                         "the load op must have offsets");
+    SmallVector<Value> offsetsAsValues =
+        vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+    auto layout = matrixOp.getLayoutAttr();
+    if (!layout)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix operation lacks layout attribute");
+
+    FailureOr<VectorType> distPayloadByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layout, payloadTy);
+    if (failed(distPayloadByWarpOpOrFailure))
+      return rewriter.notifyMatchFailure(
+          matrixOp,
+          "The matrix op payload has no layouts, using defaults instead.");
+
+    SmallVector<Value> operands;
+    if constexpr (isLoad)
+      operands = {matrixOp.getMemDesc()};
+    else
+      operands = {matrixOp.getData(), matrixOp.getMemDesc()};
+    const unsigned offsetsStartIdx = operands.size();
+    operands.append(offsetsAsValues);
+
+    SmallVector<Type> operandTypes = llvm::to_vector(
+        llvm::map_range(operands, [](Value v) { return v.getType(); }));
+    if constexpr (!isLoad)
+      operandTypes[0] = *distPayloadByWarpOpOrFailure;
+
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypes, newRetIndices);
+    SmallVector<Value> newOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    rewriter.setInsertionPointAfter(newWarpOp);
+    unsigned operandIdxToModify = offsetsStartIdx + offsetsAsValues.size() - 1;
+    newOperands[operandIdxToModify] = arith::AddIOp::create(
+        rewriter, loc, rewriter.getIndexType(), newOperands[operandIdxToModify],
+        newWarpOp.getLaneid());
+
+    SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()};
+    std::fill(newConstOffsets.begin(), newConstOffsets.end(),
+              ShapedType::kDynamic);
+    DenseI64ArrayAttr newConstOffsetsAttr =
+        rewriter.getDenseI64ArrayAttr(newConstOffsets);
+    ValueRange newOffsets = ValueRange(newOperands).drop_front(offsetsStartIdx);
+
+    if constexpr (isLoad) {
+      xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
+          rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
+          newOperands[0], newOffsets, newConstOffsetsAttr,
+          matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+      // Resolve the output type and replace all uses.
+      rewriter.replaceAllUsesWith(
+          newWarpOp.getResult(operandIdx),
+          resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
+    } else {
+      xegpu::StoreMatrixOp::create(
+          rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
+          newOffsets, newConstOffsetsAttr, matrixOp.getSubgroupBlockIoAttr(),
+          xegpu::DistributeLayoutAttr{});
+      rewriter.eraseOp(matrixOp);
+    }
+    return success();
+  }
+};
+
 /// Distribute a scattered load op. The logic and requirements are the same as
 /// for the scattered store distribution. The warpOp's payload vector is
 /// expected to be distributed by the load's result consumer.
@@ -1433,14 +1537,16 @@ struct XeGPUSubgroupDistributePass final
 
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
-               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
-               GpuBarrierDistribution, VectorMultiReductionDistribution,
-               LoadDistribution, StoreDistribution, VectorTransposeDistribution,
-               VectorBitcastDistribution,
-               MemrefExtractAlignedPointerAsIndexDistribution>(
-      patterns.getContext(),
-      /*pattern benefit=*/regularPatternBenefit);
+  patterns
+      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
+           DpasDistribution, PrefetchNdDistribution, GpuBarrierDistribution,
+           VectorMultiReductionDistribution, LoadDistribution,
+           StoreDistribution, VectorTransposeDistribution,
+           VectorBitcastDistribution, MatrixOpDistribution<xegpu::LoadMatrixOp>,
+           MatrixOpDistribution<xegpu::StoreMatrixOp>,
+           MemrefExtractAlignedPointerAsIndexDistribution>(
+          patterns.getContext(),
+          /*pattern benefit=*/regularPatternBenefit);
   patterns.add<VectorShapeCastDistribution>(
       patterns.getContext(),
       /*pattern benefit=*/highPatternBenefit);
@@ -1462,6 +1568,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       // Layouts are needed for vector type only.
       if (!isa<VectorType>(operand.get().getType()))
         continue;
+      if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(op))
+        continue;
 
       auto layout = xegpu::getDistributeLayoutAttr(operand.get());
       if (!layout) {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 27a3dc373c739..3fcc747217c9d 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -265,3 +265,18 @@ gpu.module @xevm_module{
     gpu.return
   }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @load_store_matrix_1({{.*}}) {
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[{{.*}}] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[{{.*}}] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+gpu.module @xevm_module{
+  gpu.func @load_store_matrix_1(%arg0: !xegpu.mem_desc<32x32xf32>) {
+    %c0 = arith.constant 0 : index
+    %1 = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x8xf32>
+
+    xegpu.store_matrix %1, %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : vector<2x8xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+
+    gpu.return 
+  }
+}

@akroviakov
Contributor Author

akroviakov commented Oct 24, 2025

This PR is currently WIP until there is more clarity on the offset adjustment. Is the logic going to be the same as wg-to-sg (i.e., delinearization based on lane_layout + lane_data)? Is lane_data element- or packing-based?

Update: offset calculation clarified.

Contributor

@charithaintc charithaintc left a comment

Looks good. Will approve after the feedback is addressed.

VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
if (!valOrResVecTy)
  valOrResVecTy = VectorType::get(1, data.getType());
if (valOrResVecTy.getShape().size() != 1)
Contributor

Can we add a verification here for load/store_matrix with the @subgroup_block_io attribute: the payload must be contiguous in memory.

Both of the IRs in the tests added in this PR are actually incorrect, since the payload data is not contiguous between lanes. They would be correct if you changed vector<2x16xf32> to vector<16x2xf32> (lane_layout/lane_data would need to change accordingly, but that is out of the IR verifier's scope).

    %1 = xegpu.load_matrix %arg0[%c0, %c0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
      !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<2x16xf32>

    xegpu.store_matrix %1, %arg0[%c0, %c0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
      vector<2x16xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
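To make the objection concrete, here is a small standalone sketch (illustrative only, not code from the patch) that prints the linear element offsets each lane touches for the quoted vector<2x16xf32> payload, assuming stride = [1, 32], lane_layout = [1, 16], and lane_data = [2, 1]:

    #include <cstdio>

    int main() {
      const int stride[2] = {1, 32}; // from #xegpu.mem_layout<stride = [1, 32]>
      // lane_layout = [1, 16]: lanes are spread along dim 1 (one column per lane).
      // lane_data = [2, 1]: each lane owns a 2x1 fragment (two rows of its column).
      for (int lane = 0; lane < 3; ++lane)
        for (int row = 0; row < 2; ++row)
          std::printf("lane %d, element %d -> linear offset %d\n", lane, row,
                      row * stride[0] + lane * stride[1]);
      // Prints: lane 0 -> {0, 1}, lane 1 -> {32, 33}, lane 2 -> {64, 65}.
      // Each lane's fragment is contiguous, but neighboring lanes' fragments are
      // 32 elements apart, so the subgroup does not access a contiguous block.
      return 0;
    }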

Contributor Author

Added verification. However, regarding:

> change the vector<2x16xf32> to <16x2xf32>

I understand the logical reasoning for this in the matrix ops case, but the current distribution does not allow it, given the "correct" lane layout the block load requires.

We have

  for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
    if (i < distributionStart)
      continue;
    // Check if the dimension can be distributed evenly.
    if (dim % effectiveLaneLayout[i - distributionStart] != 0)
      return failure();
    distributedShape[i] = dim / effectiveLaneLayout[i - distributionStart];
  }

This means that, given lane_layout = [1, 16], lane_data = [1, 1], and a 16x2 data shape, we get

shape[0] % layout[0] = 16 % 1 = 0 // good
shape[1] % layout[1] = 2 % 16 = 2 // fail

We can change the layout to be [16, 1], which would allow the pattern to complete and the distributed code to still be correct, since the lane layout is not used in further coordinate calculations. But [16, 1] may be harder for users to reason about by simply looking at the xevm block load description and the sg-level subgroup_block_io matrix op.
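For completeness, a minimal standalone re-run of the quoted loop (variable names here are illustrative, not the actual MLIR utility) with lane_layout = [16, 1] over the 16x2 shape, showing that each lane then receives a 1x2 fragment:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<int64_t> shape = {16, 2};
      std::vector<int64_t> laneLayout = {16, 1}; // the alternative discussed above
      std::vector<int64_t> distributedShape = shape;
      for (size_t i = 0; i < shape.size(); ++i) {
        // Same divisibility check as the quoted snippet (distributionStart = 0).
        if (shape[i] % laneLayout[i] != 0) {
          std::printf("dim %zu cannot be distributed evenly\n", i);
          return 1;
        }
        distributedShape[i] = shape[i] / laneLayout[i];
      }
      // Prints "distributed shape: 1 x 2", i.e. a vector<1x2xf32> per lane.
      std::printf("distributed shape: %lld x %lld\n",
                  (long long)distributedShape[0], (long long)distributedShape[1]);
      return 0;
    }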

Contributor

If the user uses stride = [1, 32] in the memory layout, then the user should be able to reason about sg_layout = [16, 1].

Contributor

If the user uses lane_layout = [1, 16], they should not use a strided memory layout; the example above should just use a block layout. The matrix op with subgroup_block_io is a subgroup operation, and all lanes collectively access a contiguous memory buffer.

Contributor Author

There is now a test with a 16x2xf32 result using the proper stride.
Short snippet:

    %1 = xegpu.load_matrix %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} :
     !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>, index, index -> vector<16x2xf32>

It distributes to vector<1x2xf32>.

genOffsets(OpBuilder &builder, Location loc, SmallVector<Value> delinearizedId,
           ArrayRef<int64_t> sizePerSg, ArrayRef<int64_t> sizePerWg) {
Contributor

I think the function should be called genCoords(). Inside the function, when we compute each individual value, it is fine to still use "offset".

A coordinate (coord) in an n-d tensor is a vector of logical offsets.

Contributor

Actually, reading the code below, almost all of the "offsets" variables (vectors of values) can be renamed to "coord".

Contributor Author

Renamed.
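As a terminology note, here is a minimal sketch (a hypothetical helper, not code from this patch) of what a coordinate means here: the vector of per-dimension logical offsets. In the distribution pattern in the diff above, the per-lane coordinate is derived from the subgroup-level one by adding the lane id to the innermost offset:

    #include <cstdint>
    #include <vector>

    // Hypothetical helper: bump only the innermost (fastest-varying) offset by
    // the lane id, mirroring the arith.addi on the last offset in the diff.
    std::vector<int64_t> laneCoord(std::vector<int64_t> sgCoord, int64_t laneId) {
      sgCoord.back() += laneId;
      return sgCoord;
    }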


auto layout = matrixOp.getLayoutAttr();
if (!layout)
  return rewriter.notifyMatchFailure(
Contributor

@Jianhui-Li Jianhui-Li Oct 28, 2025

Can you add a verifier here? For operations without subgroup_block_io, the lane_data must be physically contiguous in the SLM.

// this is correct
 %1 = xegpu.load_matrix %arg0[%c0, %c0] {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
      !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32]>>, index, index -> vector<2x16xf32>

// this is not correct. 
 %1 = xegpu.load_matrix %arg0[%c0, %c0] {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
      !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x16xf32>

For operations with subgroup_block_io, the lane_data must be [1, 1].
See https://github.com/llvm/llvm-project/pull/158118/files#diff-151653efd45c964d5c1af1ceb208f2eb55a344a985e874b922495bf3e4256c81R213

Contributor Author

Added a verifier check for both linear and coalesced data.

@akroviakov akroviakov requested a review from Jianhui-Li October 29, 2025 15:28
@charithaintc charithaintc self-requested a review October 30, 2025 15:31
Contributor

@charithaintc charithaintc left a comment

LGTM.

Contributor

@Jianhui-Li Jianhui-Li left a comment

Some more comments about the verifier.

    [](int x) { return x == 1; });
if (!isLaneDataLinear)
  return emitError()
         << "With subgroup_block_io, lane data must be linear.";
Contributor

I think we are checking that the tensor tile being accessed by the matrix op (with subgroup_block_io) is contiguous. That means we need to check the following:

  1. lane_data = [1, ..., 1] // 1 for each dim
  2. for every dim x where lane_layout[x] != 1, stride[x] must be 1, and block[x] must equal lane_layout[x] if the layout is blocked.

Why call it "linear"? Can the error message just be "With subgroup_block_io, the tensor tile accessed by the subgroup must be contiguous"?
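A rough sketch of that rule as a standalone predicate (an assumed helper, not the committed verifier):

    #include <cstdint>
    #include "llvm/ADT/ArrayRef.h"

    // Sketch of the contiguity rule above: every lane_data entry is 1, and every
    // distributed dimension (lane_layout[x] != 1) must be unit-stride; for a
    // blocked layout, that dimension's block size must also equal lane_layout[x].
    static bool isContiguousSubgroupTile(llvm::ArrayRef<int64_t> laneLayout,
                                         llvm::ArrayRef<int64_t> laneData,
                                         llvm::ArrayRef<int64_t> strides,
                                         llvm::ArrayRef<int64_t> blockShape) {
      for (int64_t d : laneData)
        if (d != 1)
          return false;
      for (size_t x = 0, e = laneLayout.size(); x < e; ++x) {
        if (laneLayout[x] == 1)
          continue;
        if (strides[x] != 1)
          return false;
        if (!blockShape.empty() && blockShape[x] != laneLayout[x])
          return false;
      }
      return true;
    }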

Contributor Author

Addressed. I use literal strides from the attribute, not the getStrideShape.

@akroviakov akroviakov force-pushed the akroviak/xegpu-matrix-ops-sg-lowering branch 2 times, most recently from 6344048 to 10448e1 Compare October 31, 2025 12:37
@akroviakov akroviakov force-pushed the akroviak/xegpu-matrix-ops-sg-lowering branch from 10448e1 to c99294a Compare October 31, 2025 13:13
@akroviakov akroviakov requested a review from Jianhui-Li October 31, 2025 13:46
Contributor

@Jianhui-Li Jianhui-Li left a comment

LGTM

@akroviakov akroviakov merged commit 68c4c83 into llvm:main Nov 3, 2025
10 checks passed