
Commit 50a7eb6

[MLIR][XeGPU] Add support for vector.multi_reduction in wg to sg pass [1/N] (#157554)
This PR adds a pattern for lowering vector.multi_reduction from workgroup IR to subgroup IR. It currently only supports subgroup-local reductions.
1 parent a5569b4 commit 50a7eb6
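
For orientation, a minimal before/after sketch of what the pattern does, using shapes from the tests below (SSA value names are illustrative, not taken from the patch):

    // Workgroup level: 256 rows distributed over sg_layout = [8, 1], sg_data = [16, 64].
    %r = vector.multi_reduction <add>, %src, %acc [1]
      : vector<256x64xf32> to vector<256xf32>

    // Subgroup level after the pass: each subgroup reduces its own 16x64 slice,
    // so no cross-subgroup communication is needed.
    %r_sg = vector.multi_reduction <add>, %src_sg, %acc_sg [1]
      : vector<16x64xf32> to vector<16xf32>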

3 files changed: +127 -2 lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 71 additions & 2 deletions
@@ -1027,6 +1027,70 @@ struct WgToSgVectorShapeCastOp
   }
 };
 
+/// Pattern for lowering a vector.multi_reduction op to subgroup level.
+/// Current limitation: the sg_layout in each reduced dimension must be 1,
+/// so that the reduction is local to a subgroup and no cross-subgroup
+/// communication is needed.
+/// TODO: Add cases to handle more general situations which require SLM access.
+struct WgToSgMultiDimReductionOp
+    : public OpConversionPattern<vector::MultiDimReductionOp> {
+  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType srcType = op.getSourceVectorType();
+    VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!dstType)
+      return failure();
+
+    auto srcShape = srcType.getShape();
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    auto reductionDims = llvm::to_vector(op.getReductionDims());
+
+    SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
+                                        .getParent()
+                                        .getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> sgData = llvm::cast<xegpu::SliceAttr>(layout)
+                                      .getParent()
+                                      .getEffectiveSgDataAsInt();
+
+    // Check that the sgLayout in each reduced dimension is 1 and that
+    // each subgroup gets the entire slice to reduce.
+    for (int64_t dim : reductionDims) {
+      if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
+        return rewriter.notifyMatchFailure(
+            op,
+            "sgLayout in each reduced dimension must be 1 and sgData in the "
+            "reduced dim must match srcShape in that dim");
+    }
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
+
+    VectorType newDstType =
+        VectorType::get({sgShape}, dstType.getElementType());
+
+    SmallVector<Value> newReductions;
+    for (auto sgSrc : adaptor.getSource()) {
+      auto newOp = rewriter.create<vector::MultiDimReductionOp>(
+          op.getLoc(), newDstType, op.getKind(), sgSrc, adaptor.getAcc()[0],
+          op.getReductionDims());
+      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+          !layout.getEffectiveInstDataAsInt().empty())
+        xegpu::setDistributeLayoutAttr(newOp->getResult(0),
+                                       layout.dropSgLayoutAndData());
+      newReductions.push_back(newOp.getResult());
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newReductions});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -1040,8 +1104,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
       WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
       WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
       WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
-      WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp>(
-      patterns.getContext());
+      WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
+      WgToSgMultiDimReductionOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -1195,6 +1259,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
     return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
   });
 
+  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
+      [=](vector::MultiDimReductionOp op) -> bool {
+        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
+      });
+
   target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
       [=](xegpu::ConvertLayoutOp op) -> bool {
         return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir

Lines changed: 16 additions & 0 deletions
@@ -82,4 +82,20 @@ gpu.module @test_distribution {
       : vector<256x128xf16>, vector<128x256xf16> -> vector<256x256xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: vector_reduce_dim_1
+  gpu.func @vector_reduce_dim_1(%src: memref<256x64xf32>) {
+    // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16xf32>
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>, dims = [1]>} dense<1.0> : vector<256xf32>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x64xf32>
+      -> !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>>
+    %load = xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>>
+      -> vector<256x64xf32>
+    // CHECK-COUNT-2: vector.multi_reduction <add>, {{.*}}, %[[CST]] [1] : vector<16x64xf32> to vector<16xf32>
+    // CHECK-NOT: vector.multi_reduction
+    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>, dims = [1]>} [1]
+      : vector<256x64xf32> to vector<256xf32>
+    gpu.return
+  }
 }
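
Why this test expects CHECK-COUNT-2: with sg_layout = [8, 1] and sg_data = [16, 64], one round of 8 subgroups covers 8 x 16 = 128 of the 256 rows, so round-robin distribution hands each subgroup two 16x64 slices; the pattern's loop over adaptor.getSource() then emits one subgroup-local reduction per slice. A sketch of the per-subgroup output (value names illustrative):

    // Two round-robin slices, reduced independently against the same accumulator.
    %r0 = vector.multi_reduction <add>, %slice0, %cst [1] : vector<16x64xf32> to vector<16xf32>
    %r1 = vector.multi_reduction <add>, %slice1, %cst [1] : vector<16x64xf32> to vector<16xf32>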

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 40 additions & 0 deletions
@@ -367,6 +367,46 @@ gpu.module @test_distribution {
     gpu.return
   }
 
+  // CHECK-LABEL: @vector_reduce_dim_0
+  gpu.func @vector_reduce_dim_0(%src: memref<4x128xf32>) {
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32], sg_data = [4, 4]>, dims = [0]>} dense<1.0> : vector<128xf32>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<4x128xf32>
+      -> !xegpu.tensor_desc<4x128xf32, #xegpu.layout<sg_layout = [1, 32], sg_data = [4, 4]>>
+    %load = xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<4x128xf32, #xegpu.layout<sg_layout = [1, 32], sg_data = [4, 4]>>
+      -> vector<4x128xf32>
+    // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [0] : vector<4x4xf32> to vector<4xf32>
+    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32], sg_data = [4, 4]>, dims = [0]>} [0]
+      : vector<4x128xf32> to vector<128xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: @vector_reduce_dim_1
+  gpu.func @vector_reduce_dim_1(%src: memref<256x64xf32>) {
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [16, 1], sg_data = [16, 64]>, dims = [1]>} dense<1.0> : vector<256xf32>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x64xf32>
+      -> !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [16, 1], sg_data = [16, 64]>>
+    %load = xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [16, 1], sg_data = [16, 64]>>
+      -> vector<256x64xf32>
+    // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [1] : vector<16x64xf32> to vector<16xf32>
+    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [16, 1], sg_data = [16, 64]>, dims = [1]>} [1]
+      : vector<256x64xf32> to vector<256xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: @vector_reduce_4D
+  gpu.func @vector_reduce_4D(%src: ui64) {
+    %cst_acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>, dims = [3]>} dense<0.0> : vector<4x2x6xf16>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} dense<0> : vector<4x2x6x32xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} dense<true> : vector<4x2x6x32xi1>
+    %load = xegpu.load %src[%offset], %mask {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : ui64, vector<4x2x6x32xindex>, vector<4x2x6x32xi1> -> vector<4x2x6x32xf16>
+    // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [3] : vector<1x1x1x32xf16> to vector<1x1x1xf16>
+    %reduce = vector.multi_reduction <add>, %load, %cst_acc {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>, dims = [3]>} [3]
+      : vector<4x2x6x32xf16> to vector<4x2x6xf16>
+    gpu.return
+  }
+
   // CHECK-LABEL: vector_step_op
   gpu.func @vector_step_op_slice_attr() {
     //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
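
Returning to @vector_reduce_4D above: the pattern's guard holds because sgLayout[3] = 1 and sgData[3] = 32 = srcShape[3], while the other dims are fully distributed across 4 x 2 x 6 = 48 subgroups; each subgroup therefore owns exactly one 1x1x1x32 slice and performs one local reduction, matching the single CHECK line (sketch below, value names illustrative):

    // One slice per subgroup; reducing dim 3 collapses the 32 elements locally.
    %r_sg = vector.multi_reduction <add>, %load_sg, %acc_sg [3]
      : vector<1x1x1x32xf16> to vector<1x1x1xf16>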
