
Commit 887f978

[MLIR][XeGPU] Matrix load/store subgroup distribution
1 parent 986e0fe commit 887f978

File tree

2 files changed: +131 / -8 lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 116 additions & 8 deletions
@@ -906,6 +906,110 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+template <class MatrixOp>
+struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    gpu::YieldOp yield = warpOp.getTerminator();
+    Operation *lastNode = yield->getPrevNode();
+    auto matrixOp = dyn_cast_or_null<MatrixOp>(lastNode);
+    if (!matrixOp)
+      return failure();
+    constexpr bool isLoad{std::is_same_v<MatrixOp, xegpu::LoadMatrixOp>};
+    int operandIdx{-1};
+
+    VectorType payloadTy;
+    VectorType warpResultTy;
+    if constexpr (isLoad) {
+      OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+        return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
+      });
+      if (!producedByLastLoad)
+        return rewriter.notifyMatchFailure(
+            warpOp, "The last op is not xegpu::LoadMatrixOp");
+      operandIdx = producedByLastLoad->getOperandNumber();
+      payloadTy = dyn_cast<VectorType>(matrixOp.getResult().getType());
+      warpResultTy = cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    } else {
+      payloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
+    }
+    if (!payloadTy)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix op payload must be a vector type");
+
+    auto loc = matrixOp.getLoc();
+    auto offsets = matrixOp.getMixedOffsets();
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(matrixOp,
+                                         "the load op must have offsets");
+    SmallVector<Value> offsetsAsValues =
+        vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+    auto layout = matrixOp.getLayoutAttr();
+    if (!layout)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix operation lacks layout attribute");
+
+    FailureOr<VectorType> distPayloadByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layout, payloadTy);
+    if (failed(distPayloadByWarpOpOrFailure))
+      return rewriter.notifyMatchFailure(
+          matrixOp,
+          "The matrix op payload has no layouts, using defaults instead.");
+
+    SmallVector<Value> operands;
+    if constexpr (isLoad)
+      operands = {matrixOp.getMemDesc()};
+    else
+      operands = {matrixOp.getData(), matrixOp.getMemDesc()};
+    const unsigned offsetsStartIdx = operands.size();
+    operands.append(offsetsAsValues);
+
+    SmallVector<Type> operandTypes = llvm::to_vector(
+        llvm::map_range(operands, [](Value v) { return v.getType(); }));
+    if constexpr (!isLoad)
+      operandTypes[0] = *distPayloadByWarpOpOrFailure;
+
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypes, newRetIndices);
+    SmallVector<Value> newOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    rewriter.setInsertionPointAfter(newWarpOp);
+    unsigned operandIdxToModify = offsetsStartIdx + offsetsAsValues.size() - 1;
+    newOperands[operandIdxToModify] = arith::AddIOp::create(
+        rewriter, loc, rewriter.getIndexType(), newOperands[operandIdxToModify],
+        newWarpOp.getLaneid());
+
+    SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()};
+    std::fill(newConstOffsets.begin(), newConstOffsets.end(),
+              ShapedType::kDynamic);
+    DenseI64ArrayAttr newConstOffsetsAttr =
+        rewriter.getDenseI64ArrayAttr(newConstOffsets);
+    ValueRange newOffsets = ValueRange(newOperands).drop_front(offsetsStartIdx);
+
+    if constexpr (isLoad) {
+      xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
+          rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
+          newOperands[0], newOffsets, newConstOffsetsAttr,
+          matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+      // Resolve the output type and replace all uses.
+      rewriter.replaceAllUsesWith(
+          newWarpOp.getResult(operandIdx),
+          resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
+    } else {
+      xegpu::StoreMatrixOp::create(
+          rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
+          newOffsets, newConstOffsetsAttr, matrixOp.getSubgroupBlockIoAttr(),
+          xegpu::DistributeLayoutAttr{});
+      rewriter.eraseOp(matrixOp);
+    }
+    return success();
+  }
+};
+
 /// Distribute a scattered load op. The logic and requirements are the same as
 /// for the scattered store distribution. The warpOp's payload vector is
 /// expected to be distributed by the load's result consumer.
@@ -1433,14 +1537,16 @@ struct XeGPUSubgroupDistributePass final
 
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
-               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
-               GpuBarrierDistribution, VectorMultiReductionDistribution,
-               LoadDistribution, StoreDistribution, VectorTransposeDistribution,
-               VectorBitcastDistribution,
-               MemrefExtractAlignedPointerAsIndexDistribution>(
-      patterns.getContext(),
-      /*pattern benefit=*/regularPatternBenefit);
+  patterns
+      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
+           DpasDistribution, PrefetchNdDistribution, GpuBarrierDistribution,
+           VectorMultiReductionDistribution, LoadDistribution,
+           StoreDistribution, VectorTransposeDistribution,
+           VectorBitcastDistribution, MatrixOpDistribution<xegpu::LoadMatrixOp>,
+           MatrixOpDistribution<xegpu::StoreMatrixOp>,
+           MemrefExtractAlignedPointerAsIndexDistribution>(
+          patterns.getContext(),
+          /*pattern benefit=*/regularPatternBenefit);
   patterns.add<VectorShapeCastDistribution>(
       patterns.getContext(),
       /*pattern benefit=*/highPatternBenefit);
@@ -1462,6 +1568,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       // Layouts are needed for vector type only.
       if (!isa<VectorType>(operand.get().getType()))
         continue;
+      if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(op))
+        continue;
 
       auto layout = xegpu::getDistributeLayoutAttr(operand.get());
       if (!layout) {
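
For orientation, here is a minimal before/after sketch of what the new MatrixOpDistribution pattern does to a load that terminates a warp region. It assumes a 16-lane subgroup and the lane_layout = [2, 8], lane_data = [1, 1] layout used in the test below; the value names (%laneid, %m, %r, %off) are illustrative, and the rewritten gpu.warp_execute_on_lane_0 op that forwards the mem_desc and offsets out of the region is omitted, so this is a sketch rather than the literal IR the pass produces.

// Before: the load inside the warp region yields the full subgroup tile,
// while the warp op result carries the per-lane type.
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
  %m = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}>
      : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x8xf32>
  gpu.yield %m : vector<2x8xf32>
}

// After: the load is recreated after the warp op, the lane id is added to the
// trailing offset, the layout attribute is dropped, and each lane loads only
// its own fragment.
%off = arith.addi %c0, %laneid : index
%r = xegpu.load_matrix %arg0[%c0, %off] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>

The store side is handled symmetrically: the payload operand is yielded through the new warp op in its distributed per-lane type, the same trailing-offset adjustment is applied, and the original xegpu.store_matrix inside the region is erased.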

mlir/test/Dialect/XeGPU/subgroup-distribute.mlir

Lines changed: 15 additions & 0 deletions
@@ -265,3 +265,18 @@ gpu.module @xevm_module{
     gpu.return
   }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @load_store_matrix_1({{.*}}) {
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[{{.*}}] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[{{.*}}] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+gpu.module @xevm_module{
+  gpu.func @load_store_matrix_1(%arg0: !xegpu.mem_desc<32x32xf32>) {
+    %c0 = arith.constant 0 : index
+    %1 = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x8xf32>
+
+    xegpu.store_matrix %1, %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : vector<2x8xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+
+    gpu.return
+  }
+}
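
A quick sanity check on the CHECK lines: with lane_layout = [2, 8] and lane_data = [1, 1], the vector<2x8xf32> subgroup payload is split evenly over 16 lanes, so each lane keeps a (2/2) x (8/8) = vector<1x1xf32> fragment, which is what both the distributed load and the distributed store are expected to produce. The RUN line of subgroup-distribute.mlir is not part of this diff; it is presumably along the lines of the following (pass flag and options assumed here, not taken from this commit):

// RUN: mlir-opt -xegpu-subgroup-distribute -split-input-file %s | FileCheck %s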
