Skip to content

Commit 310abe0

Browse files
authored
[MLIR] [XeGPU] Add distribution pattern for vector.constant_mask from Wg To Sg (#168118)
1 parent fbc0935 commit 310abe0

File tree

3 files changed

+113
-5
lines changed

3 files changed

+113
-5
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1285,6 +1285,71 @@ struct WgToSgVectorTransposeOp
12851285
}
12861286
};
12871287

1288+
// This pattern distributes the vector.constant_mask ops to work at subgroup
1289+
// level.
1290+
struct WgToSgVectorConstantMaskOp
1291+
: public OpConversionPattern<vector::ConstantMaskOp> {
1292+
using OpConversionPattern<vector::ConstantMaskOp>::OpConversionPattern;
1293+
1294+
LogicalResult
1295+
matchAndRewrite(vector::ConstantMaskOp op, OneToNOpAdaptor adaptor,
1296+
ConversionPatternRewriter &rewriter) const override {
1297+
xegpu::DistributeLayoutAttr layout =
1298+
xegpu::getDistributeLayoutAttr(op.getResult());
1299+
if (!layout || !layout.isForWorkgroup())
1300+
return failure();
1301+
1302+
Location loc = op.getLoc();
1303+
VectorType type = op.getResult().getType();
1304+
auto wgShape = type.getShape();
1305+
1306+
ArrayRef<int64_t> wgMaskDimSizes = op.getMaskDimSizes();
1307+
1308+
// Get subgroup ID.
1309+
Value sgId =
1310+
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
1311+
auto sgOffsets =
1312+
layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1313+
if (failed(sgOffsets))
1314+
return failure();
1315+
1316+
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1317+
VectorType resultType = VectorType::get(sgShape, type.getElementType());
1318+
1319+
// In each dimension, each subgroup computes its local mask size as:
1320+
// min(max(wgMaskSize[d] - offset[d], 0), sgDimSize[d])
1321+
SmallVector<Value> newCreateMaskOps;
1322+
for (auto offsetSet : *sgOffsets) {
1323+
SmallVector<Value> maskOperands;
1324+
1325+
for (auto [i, wgMaskSize] : llvm::enumerate(wgMaskDimSizes)) {
1326+
Value wgMaskSizeVal =
1327+
arith::ConstantIndexOp::create(rewriter, loc, wgMaskSize);
1328+
Value dimSizeVal =
1329+
arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
1330+
Value offset = offsetSet[i];
1331+
Value adjustedMaskSize =
1332+
arith::SubIOp::create(rewriter, loc, wgMaskSizeVal, offset);
1333+
Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1334+
Value nonNegative =
1335+
arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
1336+
Value sgMaskSize =
1337+
arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
1338+
maskOperands.push_back(sgMaskSize);
1339+
}
1340+
1341+
auto newCreateMaskOp =
1342+
vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
1343+
xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0),
1344+
layout.dropSgLayoutAndData());
1345+
newCreateMaskOps.push_back(newCreateMaskOp.getResult());
1346+
}
1347+
1348+
rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
1349+
return success();
1350+
}
1351+
};
1352+
12881353
} // namespace
12891354

12901355
namespace mlir {
@@ -1299,8 +1364,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
12991364
WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
13001365
WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
13011366
WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
1302-
WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>(
1303-
patterns.getContext());
1367+
WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
1368+
WgToSgVectorConstantMaskOp>(patterns.getContext());
13041369
}
13051370
} // namespace xegpu
13061371
} // namespace mlir
@@ -1427,9 +1492,9 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
14271492
return isLegal(layout);
14281493
});
14291494

1430-
target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
1431-
vector::TransposeOp, vector::BroadcastOp,
1432-
vector::MultiDimReductionOp>(
1495+
target.addDynamicallyLegalOp<
1496+
vector::ShapeCastOp, vector::StepOp, vector::TransposeOp,
1497+
vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp>(
14331498
[=](Operation *op) -> bool {
14341499
// Check for either a SliceAttr or LayoutAttr on the result.
14351500
auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,5 +130,13 @@ gpu.module @test_distribution {
130130
%trans = vector.transpose %load, [1, 0] {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 32], lane_layout = [1, 16], lane_data = [1, 1], order =[1, 0]>} : vector<256x128xf32> to vector<128x256xf32>
131131
gpu.return
132132
}
133+
134+
// Distribution of a 2-D vector.constant_mask in the "-rr" test file.
// The workgroup mask is vector<256x128xi1> with mask sizes [16, 16], annotated
// with sg_layout = [8, 4] and sg_data = [16, 16]. The directives below expect
// exactly four subgroup-level vector.create_mask ops of type
// vector<16x16xi1> and no further ones.
// NOTE(review): four is presumably the number of 16x16 tiles each subgroup
// covers ((256 / (8 * 16)) * (128 / (4 * 16))) -- confirm against the pass.
// CHECK-LABEL: vector_mask_2D
135+
gpu.func @vector_mask_2D() {
136+
// CHECK-COUNT-4: vector.create_mask {{.*}}, {{.*}} : vector<16x16xi1>
137+
// CHECK-NOT: vector.create_mask
138+
%constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
139+
gpu.return
140+
}
133141
}
134142

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,41 @@ gpu.module @test_distribution {
548548
gpu.return
549549
}
550550

551+
// Distribution of a 1-D vector.constant_mask.
// The workgroup mask is vector<32xi1> with mask size [8], annotated with
// sg_layout = [2] and sg_data = [16]. The expected IR derives an offset from
// gpu.subgroup_id (remu, mul, remu), subtracts it from the workgroup mask
// size, clamps the result with maxsi and then minsi, and feeds it to a
// vector.create_mask producing vector<16xi1>.
// CHECK-LABEL: vector_mask_1D
552+
gpu.func @vector_mask_1D() {
553+
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
554+
// CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %[[C2:.*]]
555+
// CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C16:.*]]
556+
// CHECK-DAG: %[[REMU2:.*]] = index.remu %[[MUL]], %[[C32:.*]]
557+
// CHECK-DAG: %[[SUB:.*]] = arith.subi %[[C8:.*]], %[[REMU2]] : index
558+
// CHECK-DAG: %[[MAX:.*]] = arith.maxsi %[[SUB]], %[[C0:.*]] : index
559+
// CHECK-DAG: %[[MIN:.*]] = arith.minsi %[[MAX]], %[[C16:.*]] : index
560+
// CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MIN]] : vector<16xi1>
561+
%constant_mask = vector.constant_mask [8] {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<32xi1>
562+
gpu.return
563+
}
564+
565+
// Distribution of a 2-D vector.constant_mask.
// The workgroup mask is vector<256x128xi1> with mask sizes [16, 16], annotated
// with sg_layout = [8, 4] and sg_data = [32, 32]. The expected IR splits
// gpu.subgroup_id into two indices (remu / divu), scales each by the subgroup
// tile extent to obtain row and column offsets, and per dimension computes
// subi -> maxsi -> minsi before building a vector<32x32xi1> create_mask.
// NOTE(review): the second maxsi operands are captured as C4 and C7, but the
// pattern passes a zero index constant there. Every capture in this function
// is a fresh wildcard definition, so the names are labels only and do not
// constrain the matched value.
// CHECK-LABEL: vector_mask_2D
566+
gpu.func @vector_mask_2D() {
567+
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
568+
// CHECK-DAG: %[[SGIDX:.*]] = index.remu %[[SGID]], %[[C4:.*]]
569+
// CHECK-DAG: %[[SGIDY_TMP:.*]] = index.divu %[[SGID]], %[[C4:.*]]
570+
// CHECK-DAG: %[[SGIDY:.*]] = index.remu %[[SGIDY_TMP]], %[[C8:.*]]
571+
// CHECK-DAG: %[[ROW:.*]] = index.mul %[[SGIDY]], %[[C32:.*]]
572+
// CHECK-DAG: %[[COL:.*]] = index.mul %[[SGIDX]], %[[C32:.*]]
573+
// CHECK-DAG: %[[MODROW:.*]] = index.remu %[[ROW]], %[[C256:.*]]
574+
// CHECK-DAG: %[[MODCOL:.*]] = index.remu %[[COL]], %[[C128:.*]]
575+
// CHECK-DAG: %[[SUBROW:.*]] = arith.subi %[[C16:.*]], %[[MODROW]] : index
576+
// CHECK-DAG: %[[MAXROW:.*]] = arith.maxsi %[[SUBROW]], %[[C4:.*]] : index
577+
// CHECK-DAG: %[[MINROW:.*]] = arith.minsi %[[MAXROW]], %[[C32:.*]] : index
578+
// CHECK-DAG: %[[SUBCOL:.*]] = arith.subi %[[C16:.*]], %[[MODCOL]] : index
579+
// CHECK-DAG: %[[MAXCOL:.*]] = arith.maxsi %[[SUBCOL]], %[[C7:.*]] : index
580+
// CHECK-DAG: %[[MINCOL:.*]] = arith.minsi %[[MAXCOL]], %[[C32:.*]] : index
581+
// CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MINROW]], %[[MINCOL]] : vector<32x32xi1>
582+
%constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
583+
gpu.return
584+
}
585+
551586
// CHECK-LABEL: distribute_load_slice_attr
552587
gpu.func @distribute_load_slice_attr() {
553588
%2 = memref.alloca() {alignment = 1024} : memref<4096xf32>

0 commit comments

Comments
 (0)