Skip to content

Commit 9be2284

Browse files
committed
Address feedback
1 parent cf1eb16 commit 9be2284

File tree

2 files changed

+31
-9
lines changed

2 files changed

+31
-9
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,22 +1027,28 @@ struct WgToSgVectorShapeCastOp
10271027
}
10281028
};
10291029

1030-
// Pattern for lowering vector.multi_reduction op to subgroup level.
1030+
/// Pattern for lowering vector.multi_reduction op to subgroup level.
1031+
/// Current limitation: only support 2D->1D reduction with single reduction
1032+
/// dimension, and the sg_layout in the reduced dimension being 1
1033+
/// so that reduction is local to subgroup & no cross-subgroup communication is
1034+
/// needed.
1035+
/// TODO: Add cases to handle more general situations which require SLM access.
10311036
struct WgToSgMultiDimReductionOp
10321037
: public OpConversionPattern<vector::MultiDimReductionOp> {
10331038
using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
10341039

10351040
LogicalResult
10361041
matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
10371042
ConversionPatternRewriter &rewriter) const override {
1038-
VectorType srcType = dyn_cast<VectorType>(op.getSource().getType());
1043+
VectorType srcType = op.getSourceVectorType();
10391044
VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
1040-
if (!srcType || !dstType)
1045+
if (!dstType)
10411046
return failure();
10421047

1043-
// TODO: generalize it
1044-
auto srcShape = srcType.getShape();
1045-
auto dstShape = dstType.getShape();
1048+
SmallVector<int64_t> srcShape(srcType.getShape().begin(),
1049+
srcType.getShape().end());
1050+
SmallVector<int64_t> dstShape(dstType.getShape().begin(),
1051+
dstType.getShape().end());
10461052
if (srcShape.size() != 2 || dstShape.size() != 1)
10471053
return failure();
10481054

@@ -1051,7 +1057,8 @@ struct WgToSgMultiDimReductionOp
10511057
if (!layout || !layout.isForWorkgroup())
10521058
return failure();
10531059

1054-
auto reductionDims = op.getReductionDims();
1060+
SmallVector<int64_t> reductionDims(op.getReductionDims().begin(),
1061+
op.getReductionDims().end());
10551062
if (reductionDims.size() != 1)
10561063
return failure();
10571064

@@ -1069,8 +1076,8 @@ struct WgToSgMultiDimReductionOp
10691076
SmallVector<Value> newReductions;
10701077
for (auto [sgSrc, sgAcc] :
10711078
llvm::zip(adaptor.getSource(), adaptor.getAcc())) {
1072-
auto newOp = rewriter.create<vector::MultiDimReductionOp>(
1073-
op.getLoc(), newDstType, op.getKind(), sgSrc, sgAcc,
1079+
auto newOp = vector::MultiDimReductionOp::create(
1080+
rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc, sgAcc,
10741081
op.getReductionDims());
10751082
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
10761083
!layout.getEffectiveInstDataAsInt().empty())

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,19 @@ gpu.module @test_distribution {
8282
: vector<256x128xf16>, vector<128x256xf16> -> vector<256x256xf32>
8383
gpu.return
8484
}
85+
86+
// CHECK-LABEL: vector_reduce_dim_1
87+
gpu.func @vector_reduce_dim_1(%src: memref<256x64xf32>) {
88+
%cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>, dims = [1]>} dense<1.0> : vector<256xf32>
89+
%tdesc = xegpu.create_nd_tdesc %src : memref<256x64xf32>
90+
-> !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>>
91+
%load = xegpu.load_nd %tdesc[0, 0]
92+
: !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>>
93+
-> vector<256x64xf32>
94+
// CHECK-COUNT-2: vector.multi_reduction <add>, {{.*}}, {{.*}} [1] : vector<16x64xf32> to vector<16xf32>
95+
// CHECK-NOT: vector.multi_reduction
96+
%reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>, dims = [1]>} [1]
97+
: vector<256x64xf32> to vector<256xf32>
98+
gpu.return
99+
}
85100
}

0 commit comments

Comments
 (0)