Commit 100341d

Add CHECKS
1 parent c3e5986 commit 100341d

2 files changed, +41 −19 lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 15 additions & 19 deletions
@@ -830,16 +830,17 @@ struct WgToSgMultiDimReductionOp
   matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    // Only support reduction with layout and on a single dimension for now.
     VectorType srcType = dyn_cast<VectorType>(op.getSource().getType());
     VectorType accType = dyn_cast<VectorType>(op.getAcc().getType());
     VectorType resType = dyn_cast<VectorType>(op.getResult().getType());
     Type elemTy = srcType.getElementType();
     if (!srcType || !accType || !resType)
       return failure();
 
+    // Support only 2D vectors
+    if (srcType.getShape().size() != 2 && resType.getShape().size() != 1)
+      return failure();
     ArrayRef<int64_t> wgShape = resType.getShape();
-    // Handle both LayoutAttr and SliceAttr for the op result.
     auto layoutName = xegpu::getLayoutName(op->getResult(0));
     auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName);
     if (!sliceAttr || sliceAttr.getRank() != 1)
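
A minimal standalone sketch of the new rank guard, using plain integers in place of mlir::VectorType; the helper name guardRejects and the main() driver are illustrative only, not part of the pass:

#include <cstdio>

// Mirrors the committed condition: it rejects only when the source rank is
// not 2 and the result rank is not 1.
static bool guardRejects(int srcRank, int resRank) {
  return srcRank != 2 && resRank != 1;
}

int main() {
  // vector<256x128xf32> reduced over dim 0 to vector<128xf32>: kept. Prints 0.
  std::printf("%d\n", guardRejects(2, 1));
  // A 3-D source producing a 2-D result: rejected. Prints 1.
  std::printf("%d\n", guardRejects(3, 2));
  return 0;
}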
@@ -871,7 +872,6 @@ struct WgToSgMultiDimReductionOp
           VectorType::get(shapeCastShape, srcType.getElementType());
       auto shapeCast = rewriter.create<vector::ShapeCastOp>(
           op.getLoc(), shapeCastTy, sgReduce.getResult());
-      // TODO: Change it to shapeCast
       newReductions.push_back(shapeCast.getResult());
     }
 
@@ -889,23 +889,23 @@ struct WgToSgMultiDimReductionOp
     auto slmTy = MemRefType::get(slmSize, rewriter.getI8Type(), {}, 3);
     auto slm = rewriter.create<memref::AllocaOp>(loc, slmTy);
 
-    // Create a view for the SLM buffer using xegpu.create_mem_desc
-    SmallVector<int64_t> viewShape;
+    // Create a SLM buffer using xegpu.create_mem_desc
+    SmallVector<int64_t> memDescShape;
     auto srcVecType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
     ArrayRef<int64_t> srcShape =
         srcVecType ? srcVecType.getShape() : ArrayRef<int64_t>();
     for (size_t i = 0; i < srcShape.size(); ++i) {
       if (static_cast<int64_t>(i) == reduceDim) {
         // For the reduced dimension, use sgLayoutParent[i]
-        viewShape.push_back(sgLayoutParent[i]);
+        memDescShape.push_back(sgLayoutParent[i]);
       } else {
         // For other dimensions, multiply sgLayoutParent[i] by sgShape[i]
-        viewShape.push_back(sgLayoutParent[i] * srcShape[i]);
+        memDescShape.push_back(sgLayoutParent[i] * srcShape[i]);
       }
     }
 
-    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), viewShape,
-                                               elemTy, nullptr);
+    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(),
+                                               memDescShape, elemTy, nullptr);
     auto memDesc =
         rewriter.create<xegpu::CreateMemDescOp>(loc, memDescType, slm);
 
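To make the shape computation above concrete, a minimal standalone sketch (not the pass itself) that replays it with the configuration from the vector_reduce test below: sg_layout = [8, 4], per-subgroup source shape sg_data = [32, 32], and reduceDim = 0:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> sgLayoutParent = {8, 4}; // sg_layout = [8, 4]
  std::vector<int64_t> srcShape = {32, 32};     // sg_data = [32, 32]
  int64_t reduceDim = 0;                        // dims = [0] in the slice attr

  std::vector<int64_t> memDescShape;
  for (size_t i = 0; i < srcShape.size(); ++i) {
    if (static_cast<int64_t>(i) == reduceDim)
      memDescShape.push_back(sgLayoutParent[i]);               // reduced dim: 8
    else
      memDescShape.push_back(sgLayoutParent[i] * srcShape[i]); // 4 * 32 = 128
  }
  // Prints "8 x 128", matching the !xegpu.mem_desc<8x128xf32> in the test.
  std::printf("%lld x %lld\n", static_cast<long long>(memDescShape[0]),
              static_cast<long long>(memDescShape[1]));
  return 0;
}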
@@ -951,16 +951,16 @@ struct WgToSgMultiDimReductionOp
     // Step 3: Load from SLM for the second reduction
     SmallVector<int64_t> slmLoadShape;
 
-    for (size_t i = 0; i < viewShape.size(); ++i) {
+    for (size_t i = 0; i < memDescShape.size(); ++i) {
       if (static_cast<int64_t>(i) == reduceDim) {
-        slmLoadShape.push_back(viewShape[i]);
+        slmLoadShape.push_back(memDescShape[i]);
       } else {
         int64_t divisor = computeProduct(sgLayoutParent);
-        slmLoadShape.push_back(viewShape[i] / divisor);
+        slmLoadShape.push_back(memDescShape[i] / divisor);
       }
     }
 
-    // Calculate offsets for create_nd_desc
+    // Calculate offsets for load_matrix op
     SmallVector<OpFoldResult> slmLoadOffsets;
     for (size_t i = 0; i < sgLayoutParent.size(); ++i) {
       Value offset = rewriter.createOrFold<index::MulOp>(
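
Continuing the same worked example, a minimal sketch of the slmLoadShape step above: with sg_layout = [8, 4] the divisor is 8 * 4 = 32, so each subgroup reloads an 8x4 tile of the 8x128 buffer (the vector<8x4xf32> seen in the test below):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> sgLayoutParent = {8, 4};
  std::vector<int64_t> memDescShape = {8, 128}; // from the previous step
  int64_t reduceDim = 0;

  int64_t divisor = 1; // stands in for computeProduct(sgLayoutParent)
  for (int64_t d : sgLayoutParent)
    divisor *= d; // 8 * 4 = 32

  std::vector<int64_t> slmLoadShape;
  for (size_t i = 0; i < memDescShape.size(); ++i) {
    if (static_cast<int64_t>(i) == reduceDim)
      slmLoadShape.push_back(memDescShape[i]);           // reduced dim kept: 8
    else
      slmLoadShape.push_back(memDescShape[i] / divisor); // 128 / 32 = 4
  }
  // Prints "8 x 4", matching the xegpu.load_matrix result vector<8x4xf32>.
  std::printf("%lld x %lld\n", static_cast<long long>(slmLoadShape[0]),
              static_cast<long long>(slmLoadShape[1]));
  return 0;
}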
@@ -975,8 +975,8 @@ struct WgToSgMultiDimReductionOp
         /*layout=*/nullptr);
 
     // Step 4: Create a constant accumulator for the second reduction
-    // with same vallue as adaptor.getAcc()[0] and shape set to
-    // the non reduce dimension of shapeCastLoad
+    // with same value as adaptor.getAcc()[0] and shape set to
+    // the non reduce dimension of load
     auto accShape = load.getType().getShape();
     SmallVector<int64_t> accShapeWithoutReduceDim;
     for (size_t i = 0; i < accShape.size(); ++i) {
@@ -1180,10 +1180,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
 
   target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
       [=](vector::MultiDimReductionOp op) -> bool {
-        // Only allow MultiDimReductionOp with a single reduction dimension
-        if (op.getReductionDims().size() != 1)
-          return true;
-
         // Check if the layout is legal
         return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
       });

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 26 additions & 0 deletions
@@ -323,7 +323,33 @@ gpu.module @test_distribution
   }
 
   //CHECK-LABEL: vector_reduce
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @vector_reduce(%src: memref<256x128xf32>) {
+    // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<32xf32>
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [8, 4], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [8, 4], lane_data = [1, 1]>> -> vector<32x32xf32>
+    // CHECK: %[[REDUCE:.*]] = vector.multi_reduction <add>, {{%.*}}, %[[CST]] [0] : vector<32x32xf32> to vector<32xf32>
+    // CHECK: %[[SHAPECAST:.*]] = vector.shape_cast %[[REDUCE]] : vector<32xf32> to vector<1x32xf32>
+    // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
+    // CHECK: %[[MDESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<8x128xf32>
+    // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK: %[[C8:.*]] = arith.constant 8 : index
+    // CHECK: %[[C4:.*]] = arith.constant 4 : index
+    // CHECK: %[[C4_1:.*]] = arith.constant 4 : index
+    // CHECK: %[[ID_Y:.*]] = affine.apply #map()[%[[SGID]]]
+    // CHECK: %[[ID_X:.*]] = affine.apply #map1()[%[[SGID]]]
+    // CHECK: %[[C1:.*]] = arith.constant 1 : index
+    // CHECK: %[[C32:.*]] = arith.constant 32 : index
+    // CHECK: %[[L_OFF_X:.*]] = index.mul %[[ID_X]], %[[C32]]
+    // CHECK: xegpu.store_matrix {{.*}}, %[[MDESC]][%[[ID_Y]], %[[L_OFF_X]]] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
+    // CHECK: gpu.barrier
+    // CHECK: %[[C8_1:.*]] = arith.constant 8 : index
+    // CHECK: %[[OFF_Y:.*]] = index.mul %[[ID_Y]], %[[C8_1]]
+    // CHECK: %[[C4_2:.*]] = arith.constant 4 : index
+    // CHECK: %[[OFF_X:.*]] = index.mul %[[ID_X]], %[[C4_2]]
+    // CHECK: %[[LOAD:.*]] = xegpu.load_matrix %[[MDESC]][%[[OFF_Y]], %[[OFF_X]]] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x4xf32>
+    // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
+    // CHECK: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST]] [0] : vector<8x4xf32> to vector<4xf32>
     %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} dense<1.0> : vector<128xf32>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
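
As a quick sanity check on the CHECK lines above, a small sketch (assuming f32, i.e. 4 bytes per element) that reproduces the 4096-byte SLM allocation: 8 * 128 * 4 = 4096:

#include <cstdio>

int main() {
  long long rows = 8, cols = 128; // !xegpu.mem_desc<8x128xf32>
  long long elemBytes = 4;        // f32
  // 8 * 128 * 4 = 4096 bytes, matching memref.alloca() : memref<4096xi8, 3>.
  std::printf("slm bytes = %lld\n", rows * cols * elemBytes);
  return 0;
}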
