Skip to content

Commit f9c930d

Browse files
committed
Templatize mask ops
1 parent 3df7560 commit f9c930d

File tree

1 file changed

+21
-76
lines changed

1 file changed

+21
-76
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 21 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,15 +1270,15 @@ struct WgToSgVectorTransposeOp
12701270
}
12711271
};
12721272

1273-
// This pattern distributes the vector.constant_mask ops to work at subgroup
1274-
// level.
1275-
struct WgToSgVectorConstantMaskOp
1276-
: public OpConversionPattern<vector::ConstantMaskOp> {
1277-
using OpConversionPattern<vector::ConstantMaskOp>::OpConversionPattern;
1278-
1279-
LogicalResult
1280-
matchAndRewrite(vector::ConstantMaskOp op, OneToNOpAdaptor adaptor,
1281-
ConversionPatternRewriter &rewriter) const override {
1273+
// Distribute vector.constant_mask / vector.create_mask ops to subgroup level.
1274+
template <typename MaskOpType>
1275+
struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
1276+
using OpConversionPattern<MaskOpType>::OpConversionPattern;
1277+
1278+
LogicalResult matchAndRewrite(
1279+
MaskOpType op,
1280+
typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
1281+
ConversionPatternRewriter &rewriter) const override {
12821282
xegpu::DistributeLayoutAttr layout =
12831283
xegpu::getDistributeLayoutAttr(op.getResult());
12841284
if (!layout || !layout.isForWorkgroup())
@@ -1288,73 +1288,16 @@ struct WgToSgVectorConstantMaskOp
12881288
VectorType type = op.getResult().getType();
12891289
auto wgShape = type.getShape();
12901290

1291-
ArrayRef<int64_t> wgMaskDimSizes = op.getMaskDimSizes();
1292-
1293-
// Get subgroup ID.
1294-
Value sgId =
1295-
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
1296-
auto sgOffsets =
1297-
layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
1298-
if (failed(sgOffsets))
1299-
return failure();
1300-
1301-
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
1302-
VectorType resultType = VectorType::get(sgShape, type.getElementType());
1303-
1304-
// In each dimension, each subgroup computes its local mask size as:
1305-
// min(max(wgMaskSize[d] - offset[d], 0), sgDimSize[d])
1306-
SmallVector<Value> newCreateMaskOps;
1307-
for (auto offsetSet : *sgOffsets) {
1308-
SmallVector<Value> maskOperands;
1309-
1310-
for (auto [i, wgMaskSize] : llvm::enumerate(wgMaskDimSizes)) {
1311-
Value wgMaskSizeVal =
1312-
arith::ConstantIndexOp::create(rewriter, loc, wgMaskSize);
1313-
Value dimSizeVal =
1314-
arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
1315-
Value offset = offsetSet[i];
1316-
Value adjustedMaskSize =
1317-
arith::SubIOp::create(rewriter, loc, wgMaskSizeVal, offset);
1318-
Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1319-
Value nonNegative =
1320-
arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
1321-
Value sgMaskSize =
1322-
arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
1323-
maskOperands.push_back(sgMaskSize);
1291+
SmallVector<Value> wgMaskDimSizes;
1292+
if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
1293+
for (int64_t maskSize : op.getMaskDimSizes()) {
1294+
wgMaskDimSizes.push_back(
1295+
arith::ConstantIndexOp::create(rewriter, loc, maskSize));
13241296
}
1325-
1326-
auto newCreateMaskOp =
1327-
vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
1328-
xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0),
1329-
layout.dropSgLayoutAndData());
1330-
newCreateMaskOps.push_back(newCreateMaskOp.getResult());
1297+
} else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
1298+
wgMaskDimSizes = llvm::to_vector(op.getOperands());
13311299
}
13321300

1333-
rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
1334-
return success();
1335-
}
1336-
};
1337-
1338-
// This pattern distributes the vector.create_mask ops to work at subgroup
1339-
// level.
1340-
struct WgToSgVectorCreateMaskOp
1341-
: public OpConversionPattern<vector::CreateMaskOp> {
1342-
using OpConversionPattern<vector::CreateMaskOp>::OpConversionPattern;
1343-
1344-
LogicalResult
1345-
matchAndRewrite(vector::CreateMaskOp op, OneToNOpAdaptor adaptor,
1346-
ConversionPatternRewriter &rewriter) const override {
1347-
xegpu::DistributeLayoutAttr layout =
1348-
xegpu::getDistributeLayoutAttr(op.getResult());
1349-
if (!layout || !layout.isForWorkgroup())
1350-
return failure();
1351-
1352-
Location loc = op.getLoc();
1353-
VectorType type = op.getResult().getType();
1354-
auto wgShape = type.getShape();
1355-
1356-
auto wgMaskOperands = op.getOperands();
1357-
13581301
Value sgId =
13591302
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
13601303
auto sgOffsets =
@@ -1366,17 +1309,17 @@ struct WgToSgVectorCreateMaskOp
13661309
VectorType resultType = VectorType::get(sgShape, type.getElementType());
13671310

13681311
// In each dimension, each subgroup computes its local mask size as:
1369-
// min(max(wgMaskSize[d] - offset[d], 0), sgDimSize[d])
1312+
// min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
13701313
SmallVector<Value> newCreateMaskOps;
13711314
for (auto offsetSet : *sgOffsets) {
13721315
SmallVector<Value> maskOperands;
13731316

1374-
for (auto [i, wgMaskOperand] : llvm::enumerate(wgMaskOperands)) {
1317+
for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
13751318
Value dimSizeVal =
13761319
arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
13771320
Value offset = offsetSet[i];
13781321
Value adjustedMaskSize =
1379-
arith::SubIOp::create(rewriter, loc, wgMaskOperand, offset);
1322+
arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
13801323
Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
13811324
Value nonNegative =
13821325
arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
@@ -1397,6 +1340,8 @@ struct WgToSgVectorCreateMaskOp
13971340
}
13981341
};
13991342

1343+
// Instantiation of the templated mask pattern for vector::ConstantMaskOp.
using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
1344+
// Instantiation of the templated mask pattern for vector::CreateMaskOp.
using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
14001345
} // namespace
14011346

14021347
namespace mlir {

0 commit comments

Comments
 (0)