Skip to content

Commit be59f46

Browse files
committed
Add pattern for constant mask
1 parent 4d54fb1 commit be59f46

File tree

3 files changed

+81
-53
lines changed

3 files changed

+81
-53
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 57 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1283,57 +1283,74 @@ struct WgToSgVectorTransposeOp
12831283
}
12841284
};
12851285

1286-
/// Pattern for lowering vector.create_mask and vector.constant_mask ops to
1287-
/// subgroup level.
1288-
template <typename MaskOpType>
1289-
struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
1290-
using OpConversionPattern<MaskOpType>::OpConversionPattern;
1291-
1292-
LogicalResult matchAndRewrite(
1293-
MaskOpType op,
1294-
typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
1295-
ConversionPatternRewriter &rewriter) const override {
1296-
VectorType resultType = op.getResult().getType();
1297-
ArrayRef<int64_t> wgShape = resultType.getShape();
1286+
// This pattern distributes vector.constant_mask ops to subgroup level by
1287+
// emitting one clamped vector.create_mask per distributed offset.
1288+
struct WgToSgVectorConstantMaskOp
1289+
: public OpConversionPattern<vector::ConstantMaskOp> {
1290+
using OpConversionPattern<vector::ConstantMaskOp>::OpConversionPattern;
12981291

1292+
LogicalResult
1293+
matchAndRewrite(vector::ConstantMaskOp op, OneToNOpAdaptor adaptor,
1294+
ConversionPatternRewriter &rewriter) const override {
12991295
xegpu::DistributeLayoutAttr layout =
13001296
xegpu::getDistributeLayoutAttr(op.getResult());
13011297
if (!layout || !layout.isForWorkgroup())
13021298
return failure();
13031299

1304-
SmallVector<int64_t> sgShape;
1305-
int count;
1306-
std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
1307-
VectorType newResultType =
1308-
VectorType::get(sgShape, resultType.getElementType());
1300+
Location loc = op.getLoc();
1301+
VectorType type = op.getResult().getType();
1302+
auto wgShape = type.getShape();
1303+
1304+
ArrayRef<int64_t> originalMaskDimSizes = op.getMaskDimSizes();
13091305

1310-
SmallVector<Value> newMaskOps;
1311-
for (int i = 0; i < count; ++i) {
1312-
Value newMaskOp;
1313-
if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
1314-
newMaskOp = vector::CreateMaskOp::create(
1315-
rewriter, op.getLoc(), newResultType, op.getOperands());
1316-
} else if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
1317-
newMaskOp = vector::ConstantMaskOp::create(
1318-
rewriter, op.getLoc(), newResultType, op.getMaskDimSizes());
1319-
} else {
1320-
return rewriter.notifyMatchFailure(op,
1321-
"Unsupported mask operation type");
1306+
// Get subgroup ID.
1307+
Value sgId =
1308+
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
1309+
auto sgOffsets =
1310+
layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1311+
if (failed(sgOffsets))
1312+
return failure();
1313+
1314+
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1315+
VectorType resultType = VectorType::get(sgShape, type.getElementType());
1316+
1317+
SmallVector<Value> newCreateMaskOps;
1318+
for (auto offsetSet : *sgOffsets) {
1319+
SmallVector<Value> maskOperands;
1320+
1321+
for (auto [i, originalMaskSize] : llvm::enumerate(originalMaskDimSizes)) {
1322+
Value originalMaskSizeVal =
1323+
arith::ConstantIndexOp::create(rewriter, loc, originalMaskSize);
1324+
Value dimSizeVal =
1325+
arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
1326+
Value offset = offsetSet[i];
1327+
// Compute: originalMaskSize - offset.
1328+
Value adjustedMaskSize =
1329+
arith::SubIOp::create(rewriter, loc, originalMaskSizeVal, offset);
1330+
// Clamp to [0, dimSize]: min(max(adjustedMaskSize, 0),
1331+
// dimSize)
1332+
Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1333+
Value clampedLow =
1334+
arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
1335+
Value clampedHigh =
1336+
arith::MinSIOp::create(rewriter, loc, clampedLow, dimSizeVal);
1337+
maskOperands.push_back(clampedHigh);
13221338
}
1323-
xegpu::setDistributeLayoutAttr(cast<OpResult>(newMaskOp),
1324-
layout.dropSgLayoutAndData());
13251339

1326-
newMaskOps.push_back(newMaskOp);
1340+
auto newCreateMaskOp =
1341+
vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
1342+
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
1343+
!layout.getEffectiveInstDataAsInt().empty())
1344+
xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0),
1345+
layout.dropSgLayoutAndData());
1346+
newCreateMaskOps.push_back(newCreateMaskOp.getResult());
13271347
}
13281348

1329-
rewriter.replaceOpWithMultiple(op, {newMaskOps});
1349+
rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
13301350
return success();
13311351
}
13321352
};
13331353

1334-
using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
1335-
using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
1336-
13371354
} // namespace
13381355

13391356
namespace mlir {
@@ -1349,8 +1366,7 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
13491366
WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
13501367
WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
13511368
WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
1352-
WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
1353-
patterns.getContext());
1369+
WgToSgVectorConstantMaskOp>(patterns.getContext());
13541370
}
13551371
} // namespace xegpu
13561372
} // namespace mlir
@@ -1477,10 +1493,9 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
14771493
return isLegal(layout);
14781494
});
14791495

1480-
target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
1481-
vector::TransposeOp, vector::BroadcastOp,
1482-
vector::MultiDimReductionOp,
1483-
vector::ConstantMaskOp, vector::CreateMaskOp>(
1496+
target.addDynamicallyLegalOp<
1497+
vector::ShapeCastOp, vector::StepOp, vector::TransposeOp,
1498+
vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp>(
14841499
[=](Operation *op) -> bool {
14851500
// Check for either a SliceAttr or LayoutAttr on the result.
14861501
auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,8 @@ gpu.module @test_distribution {
134134
// CHECK-LABEL: vector_mask_2D
135135
gpu.func @vector_mask_2D() {
136136
%cst16 = arith.constant 16 : index
137-
// CHECK: %[[CST16:.*]] = arith.constant 16 : index
138-
// CHECK-COUNT-4: vector.create_mask %[[CST16:.*]], %[[CST16]] : vector<16x16xi1>
137+
// CHECK-COUNT-4: vector.create_mask {{.*}}, {{.*}} : vector<16x16xi1>
139138
// CHECK-NOT: vector.create_mask
140-
// CHECK-COUNT-4: vector.constant_mask [16, 16] : vector<16x16xi1>
141-
// CHECK-NOT: vector.constant_mask
142-
%create_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
143139
%constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
144140
gpu.return
145141
}

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -550,20 +550,37 @@ gpu.module @test_distribution {
550550

551551
// CHECK-LABEL: vector_mask_1D
552552
gpu.func @vector_mask_1D() {
553+
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
554+
// CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %[[C2:.*]]
555+
// CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C16:.*]]
556+
// CHECK-DAG: %[[REMU2:.*]] = index.remu %[[MUL]], %[[C32:.*]]
557+
// CHECK-DAG: %[[SUB:.*]] = arith.subi %[[C8:.*]], %[[REMU2]] : index
558+
// CHECK-DAG: %[[MAX:.*]] = arith.maxsi %[[SUB]], %[[C0:.*]] : index
559+
// CHECK-DAG: %[[MIN:.*]] = arith.minsi %[[MAX]], %[[C16:.*]] : index
560+
// CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MIN]] : vector<16xi1>
553561
%cst8 = arith.constant 8 : index
554-
// CHECK: vector.create_mask {{.*}} : vector<16xi1>
555-
%create_mask = vector.create_mask %cst8 {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<16xi1>
556-
// CHECK: vector.constant_mask [8] : vector<16xi1>
557562
%constant_mask = vector.constant_mask [8] {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<32xi1>
558563
gpu.return
559564
}
560565

561566
// CHECK-LABEL: vector_mask_2D
562567
gpu.func @vector_mask_2D() {
568+
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
569+
// CHECK-DAG: %[[SGIDX:.*]] = index.remu %[[SGID]], %[[C4:.*]]
570+
// CHECK-DAG: %[[SGIDY_TMP:.*]] = index.divu %[[SGID]], %[[C4:.*]]
571+
// CHECK-DAG: %[[SGIDY:.*]] = index.remu %[[SGIDY_TMP]], %[[C8:.*]]
572+
// CHECK-DAG: %[[ROW:.*]] = index.mul %[[SGIDY]], %[[C32:.*]]
573+
// CHECK-DAG: %[[COL:.*]] = index.mul %[[SGIDX]], %[[C32:.*]]
574+
// CHECK-DAG: %[[MODROW:.*]] = index.remu %[[ROW]], %[[C256:.*]]
575+
// CHECK-DAG: %[[MODCOL:.*]] = index.remu %[[COL]], %[[C128:.*]]
576+
// CHECK-DAG: %[[SUBROW:.*]] = arith.subi %[[C16:.*]], %[[MODROW]] : index
577+
// CHECK-DAG: %[[MAXROW:.*]] = arith.maxsi %[[SUBROW]], %[[C0:.*]] : index
578+
// CHECK-DAG: %[[MINROW:.*]] = arith.minsi %[[MAXROW]], %[[C32:.*]] : index
579+
// CHECK-DAG: %[[SUBCOL:.*]] = arith.subi %[[C16:.*]], %[[MODCOL]] : index
580+
// CHECK-DAG: %[[MAXCOL:.*]] = arith.maxsi %[[SUBCOL]], %[[C0:.*]] : index
581+
// CHECK-DAG: %[[MINCOL:.*]] = arith.minsi %[[MAXCOL]], %[[C32:.*]] : index
582+
// CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MINROW]], %[[MINCOL]] : vector<32x32xi1>
563583
%cst16 = arith.constant 16 : index
564-
// CHECK: vector.create_mask {{.*}}, {{.*}} : vector<32x32xi1>
565-
%create_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
566-
// CHECK: vector.constant_mask [16, 16] : vector<32x32xi1>
567584
%constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
568585
gpu.return
569586
}

0 commit comments

Comments
 (0)