Skip to content

Commit 4d54fb1

Browse files
committed
Add distribution for vector mask operations
1 parent 6a89439 commit 4d54fb1

File tree

3 files changed

+88
-2
lines changed

3 files changed

+88
-2
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1283,6 +1283,57 @@ struct WgToSgVectorTransposeOp
12831283
}
12841284
};
12851285

1286+
/// Pattern for lowering vector.create_mask and vector.constant_mask ops to
1287+
/// subgroup level.
1288+
template <typename MaskOpType>
1289+
struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
1290+
using OpConversionPattern<MaskOpType>::OpConversionPattern;
1291+
1292+
LogicalResult matchAndRewrite(
1293+
MaskOpType op,
1294+
typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
1295+
ConversionPatternRewriter &rewriter) const override {
1296+
VectorType resultType = op.getResult().getType();
1297+
ArrayRef<int64_t> wgShape = resultType.getShape();
1298+
1299+
xegpu::DistributeLayoutAttr layout =
1300+
xegpu::getDistributeLayoutAttr(op.getResult());
1301+
if (!layout || !layout.isForWorkgroup())
1302+
return failure();
1303+
1304+
SmallVector<int64_t> sgShape;
1305+
int count;
1306+
std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
1307+
VectorType newResultType =
1308+
VectorType::get(sgShape, resultType.getElementType());
1309+
1310+
SmallVector<Value> newMaskOps;
1311+
for (int i = 0; i < count; ++i) {
1312+
Value newMaskOp;
1313+
if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
1314+
newMaskOp = vector::CreateMaskOp::create(
1315+
rewriter, op.getLoc(), newResultType, op.getOperands());
1316+
} else if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
1317+
newMaskOp = vector::ConstantMaskOp::create(
1318+
rewriter, op.getLoc(), newResultType, op.getMaskDimSizes());
1319+
} else {
1320+
return rewriter.notifyMatchFailure(op,
1321+
"Unsupported mask operation type");
1322+
}
1323+
xegpu::setDistributeLayoutAttr(cast<OpResult>(newMaskOp),
1324+
layout.dropSgLayoutAndData());
1325+
1326+
newMaskOps.push_back(newMaskOp);
1327+
}
1328+
1329+
rewriter.replaceOpWithMultiple(op, {newMaskOps});
1330+
return success();
1331+
}
1332+
};
1333+
1334+
using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
1335+
using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
1336+
12861337
} // namespace
12871338

12881339
namespace mlir {
@@ -1297,7 +1348,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
12971348
WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
12981349
WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
12991350
WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
1300-
WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>(
1351+
WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
1352+
WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
13011353
patterns.getContext());
13021354
}
13031355
} // namespace xegpu
@@ -1427,7 +1479,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
14271479

14281480
target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
14291481
vector::TransposeOp, vector::BroadcastOp,
1430-
vector::MultiDimReductionOp>(
1482+
vector::MultiDimReductionOp,
1483+
vector::ConstantMaskOp, vector::CreateMaskOp>(
14311484
[=](Operation *op) -> bool {
14321485
// Check for either a SliceAttr or LayoutAttr on the result.
14331486
auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,5 +130,18 @@ gpu.module @test_distribution {
130130
%trans = vector.transpose %load, [1, 0] {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 32], lane_layout = [1, 16], lane_data = [1, 1], order =[1, 0]>} : vector<256x128xf32> to vector<128x256xf32>
131131
gpu.return
132132
}
133+
134+
// CHECK-LABEL: vector_mask_2D
135+
gpu.func @vector_mask_2D() {
136+
%cst16 = arith.constant 16 : index
137+
// CHECK: %[[CST16:.*]] = arith.constant 16 : index
138+
// CHECK-COUNT-4: vector.create_mask %[[CST16:.*]], %[[CST16]] : vector<16x16xi1>
139+
// CHECK-NOT: vector.create_mask
140+
// CHECK-COUNT-4: vector.constant_mask [16, 16] : vector<16x16xi1>
141+
// CHECK-NOT: vector.constant_mask
142+
%create_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
143+
%constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
144+
gpu.return
145+
}
133146
}
134147

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,4 +547,24 @@ gpu.module @test_distribution {
547547
%broadcast = vector.broadcast %arg0 {layout_result_0 = #xegpu.layout<sg_layout = [4, 8, 1], sg_data = [1, 1, 1]>} : index to vector<4x1x1xindex>
548548
gpu.return
549549
}
550+
551+
// CHECK-LABEL: vector_mask_1D
552+
gpu.func @vector_mask_1D() {
553+
%cst8 = arith.constant 8 : index
554+
// CHECK: vector.create_mask {{.*}} : vector<16xi1>
555+
%create_mask = vector.create_mask %cst8 {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<16xi1>
556+
// CHECK: vector.constant_mask [8] : vector<16xi1>
557+
%constant_mask = vector.constant_mask [8] {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<32xi1>
558+
gpu.return
559+
}
560+
561+
// CHECK-LABEL: vector_mask_2D
562+
gpu.func @vector_mask_2D() {
563+
%cst16 = arith.constant 16 : index
564+
// CHECK: vector.create_mask {{.*}}, {{.*}} : vector<32x32xi1>
565+
%create_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
566+
// CHECK: vector.constant_mask [16, 16] : vector<32x32xi1>
567+
%constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
568+
gpu.return
569+
}
550570
}

0 commit comments

Comments
 (0)