Skip to content

Commit c8d3b0c

Browse files
authored
[MLIR][XeGPU] Add distribution for vector.create_mask from Wg to Sg (#169571)
1 parent b70be3d commit c8d3b0c

File tree

3 files changed

+74
-20
lines changed

3 files changed

+74
-20
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,15 +1278,15 @@ struct WgToSgVectorTransposeOp
12781278
}
12791279
};
12801280

1281-
// This pattern distributes the vector.constant_mask ops to work at subgroup
1282-
// level.
1283-
struct WgToSgVectorConstantMaskOp
1284-
: public OpConversionPattern<vector::ConstantMaskOp> {
1285-
using OpConversionPattern<vector::ConstantMaskOp>::OpConversionPattern;
1286-
1287-
LogicalResult
1288-
matchAndRewrite(vector::ConstantMaskOp op, OneToNOpAdaptor adaptor,
1289-
ConversionPatternRewriter &rewriter) const override {
1281+
// Distribute vector mask ops to work at subgroup level.
1282+
template <typename MaskOpType>
1283+
struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
1284+
using OpConversionPattern<MaskOpType>::OpConversionPattern;
1285+
1286+
LogicalResult matchAndRewrite(
1287+
MaskOpType op,
1288+
typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
1289+
ConversionPatternRewriter &rewriter) const override {
12901290
xegpu::DistributeLayoutAttr layout =
12911291
xegpu::getDistributeLayoutAttr(op.getResult());
12921292
if (!layout || !layout.isForWorkgroup())
@@ -1296,9 +1296,16 @@ struct WgToSgVectorConstantMaskOp
12961296
VectorType type = op.getResult().getType();
12971297
auto wgShape = type.getShape();
12981298

1299-
ArrayRef<int64_t> wgMaskDimSizes = op.getMaskDimSizes();
1299+
SmallVector<Value> wgMaskDimSizes;
1300+
if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
1301+
for (int64_t maskSize : op.getMaskDimSizes()) {
1302+
wgMaskDimSizes.push_back(
1303+
arith::ConstantIndexOp::create(rewriter, loc, maskSize));
1304+
}
1305+
} else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
1306+
wgMaskDimSizes = llvm::to_vector(op.getOperands());
1307+
}
13001308

1301-
// Get subgroup ID.
13021309
Value sgId =
13031310
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
13041311
auto sgOffsets =
@@ -1310,19 +1317,17 @@ struct WgToSgVectorConstantMaskOp
13101317
VectorType resultType = VectorType::get(sgShape, type.getElementType());
13111318

13121319
// In each dimension, each subgroup computes its local mask size as:
1313-
// min(max(wgMaskSize[d] - offset[d], 0), sgDimSize[d])
1320+
// min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
13141321
SmallVector<Value> newCreateMaskOps;
13151322
for (auto offsetSet : *sgOffsets) {
13161323
SmallVector<Value> maskOperands;
13171324

1318-
for (auto [i, wgMaskSize] : llvm::enumerate(wgMaskDimSizes)) {
1319-
Value wgMaskSizeVal =
1320-
arith::ConstantIndexOp::create(rewriter, loc, wgMaskSize);
1325+
for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
13211326
Value dimSizeVal =
13221327
arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
13231328
Value offset = offsetSet[i];
13241329
Value adjustedMaskSize =
1325-
arith::SubIOp::create(rewriter, loc, wgMaskSizeVal, offset);
1330+
arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
13261331
Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
13271332
Value nonNegative =
13281333
arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
@@ -1343,6 +1348,8 @@ struct WgToSgVectorConstantMaskOp
13431348
}
13441349
};
13451350

1351+
using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
1352+
using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
13461353
} // namespace
13471354

13481355
namespace mlir {
@@ -1358,7 +1365,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
13581365
WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
13591366
WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
13601367
WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
1361-
WgToSgVectorConstantMaskOp>(patterns.getContext());
1368+
WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
1369+
patterns.getContext());
13621370
}
13631371
} // namespace xegpu
13641372
} // namespace mlir
@@ -1485,9 +1493,10 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
14851493
return isLegal(layout);
14861494
});
14871495

1488-
target.addDynamicallyLegalOp<
1489-
vector::ShapeCastOp, vector::StepOp, vector::TransposeOp,
1490-
vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp>(
1496+
target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
1497+
vector::TransposeOp, vector::BroadcastOp,
1498+
vector::MultiDimReductionOp,
1499+
vector::ConstantMaskOp, vector::CreateMaskOp>(
14911500
[=](Operation *op) -> bool {
14921501
// Check for either a SliceAttr or LayoutAttr on the result.
14931502
auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,4 +135,12 @@ gpu.module @test_distribution {
135135
%constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
136136
gpu.return
137137
}
138+
139+
gpu.func @vector_create_mask_2D() {
140+
// CHECK-COUNT-4: vector.create_mask {{.*}}, {{.*}} : vector<16x16xi1>
141+
// CHECK-NOT: vector.create_mask
142+
%cst16 = arith.constant 16 : index
143+
%constant_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
144+
gpu.return
145+
}
138146
}

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,43 @@ gpu.module @test_distribution {
580580
gpu.return
581581
}
582582

583+
// CHECK-LABEL: vector_create_mask_1D
584+
gpu.func @vector_create_mask_1D() {
585+
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
586+
// CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %[[C2:.*]]
587+
// CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C16:.*]]
588+
// CHECK-DAG: %[[REMU2:.*]] = index.remu %[[MUL]], %[[C32:.*]]
589+
// CHECK-DAG: %[[SUB:.*]] = arith.subi %[[C8:.*]], %[[REMU2]] : index
590+
// CHECK-DAG: %[[MAX:.*]] = arith.maxsi %[[SUB]], %[[C0:.*]] : index
591+
// CHECK-DAG: %[[MIN:.*]] = arith.minsi %[[MAX]], %[[C16:.*]] : index
592+
// CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MIN]] : vector<16xi1>
593+
%cst8 = arith.constant 8 : index
594+
%constant_mask = vector.create_mask %cst8 {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<32xi1>
595+
gpu.return
596+
}
597+
598+
// CHECK-LABEL: vector_create_mask_2D
599+
gpu.func @vector_create_mask_2D() {
600+
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
601+
// CHECK-DAG: %[[SGIDX:.*]] = index.remu %[[SGID]], %[[C4:.*]]
602+
// CHECK-DAG: %[[SGIDY_TMP:.*]] = index.divu %[[SGID]], %[[C4:.*]]
603+
// CHECK-DAG: %[[SGIDY:.*]] = index.remu %[[SGIDY_TMP]], %[[C8:.*]]
604+
// CHECK-DAG: %[[ROW:.*]] = index.mul %[[SGIDY]], %[[C32:.*]]
605+
// CHECK-DAG: %[[COL:.*]] = index.mul %[[SGIDX]], %[[C32:.*]]
606+
// CHECK-DAG: %[[MODROW:.*]] = index.remu %[[ROW]], %[[C256:.*]]
607+
// CHECK-DAG: %[[MODCOL:.*]] = index.remu %[[COL]], %[[C128:.*]]
608+
// CHECK-DAG: %[[SUBROW:.*]] = arith.subi %[[C16:.*]], %[[MODROW]] : index
609+
// CHECK-DAG: %[[MAXROW:.*]] = arith.maxsi %[[SUBROW]], %[[C0:.*]] : index
610+
// CHECK-DAG: %[[MINROW:.*]] = arith.minsi %[[MAXROW]], %[[C32:.*]] : index
611+
// CHECK-DAG: %[[SUBCOL:.*]] = arith.subi %[[C16:.*]], %[[MODCOL]] : index
612+
// CHECK-DAG: %[[MAXCOL:.*]] = arith.maxsi %[[SUBCOL]], %[[C0:.*]] : index
613+
// CHECK-DAG: %[[MINCOL:.*]] = arith.minsi %[[MAXCOL]], %[[C32:.*]] : index
614+
// CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MINROW]], %[[MINCOL]] : vector<32x32xi1>
615+
%cst16 = arith.constant 16 : index
616+
%constant_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
617+
gpu.return
618+
}
619+
583620
// CHECK-LABEL: distribute_load_slice_attr
584621
gpu.func @distribute_load_slice_attr() {
585622
%2 = memref.alloca() {alignment = 1024} : memref<4096xf32>

0 commit comments

Comments (0)