@@ -1270,15 +1270,15 @@ struct WgToSgVectorTransposeOp
12701270 }
12711271};
12721272
1273- // This pattern distributes the vector.constant_mask ops to work at subgroup
1274- // level.
1275- struct WgToSgVectorConstantMaskOp
1276- : public OpConversionPattern<vector::ConstantMaskOp> {
1277- using OpConversionPattern<vector::ConstantMaskOp>::OpConversionPattern;
1278-
1279- LogicalResult
1280- matchAndRewrite (vector::ConstantMaskOp op, OneToNOpAdaptor adaptor,
1281- ConversionPatternRewriter &rewriter) const override {
1273+ // Distributes vector.constant_mask / vector.create_mask ops to subgroup level.
1274+ template < typename MaskOpType>
1275+ struct WgToSgVectorMaskOp : public OpConversionPattern <MaskOpType> {
1276+ using OpConversionPattern<MaskOpType>::OpConversionPattern;
1277+
1278+ LogicalResult matchAndRewrite (
1279+ MaskOpType op,
1280+ typename OpConversionPattern<MaskOpType>:: OneToNOpAdaptor adaptor,
1281+ ConversionPatternRewriter &rewriter) const override {
12821282 xegpu::DistributeLayoutAttr layout =
12831283 xegpu::getDistributeLayoutAttr (op.getResult ());
12841284 if (!layout || !layout.isForWorkgroup ())
@@ -1288,73 +1288,16 @@ struct WgToSgVectorConstantMaskOp
12881288 VectorType type = op.getResult ().getType ();
12891289 auto wgShape = type.getShape ();
12901290
1291- ArrayRef<int64_t > wgMaskDimSizes = op.getMaskDimSizes ();
1292-
1293- // Get subgroup ID.
1294- Value sgId =
1295- gpu::SubgroupIdOp::create (rewriter, loc, /* upper_bound=*/ nullptr );
1296- auto sgOffsets =
1297- layout.computeDistributedCoords (rewriter, loc, sgId, wgShape);
1298- if (failed (sgOffsets))
1299- return failure ();
1300-
1301- SmallVector<int64_t > sgShape = getSgShapeAndCount (wgShape, layout).first ;
1302- VectorType resultType = VectorType::get (sgShape, type.getElementType ());
1303-
1304- // In each dimension, each subgroup computes its local mask size as:
1305- // min(max(wgMaskSize[d] - offset[d], 0), sgDimSize[d])
1306- SmallVector<Value> newCreateMaskOps;
1307- for (auto offsetSet : *sgOffsets) {
1308- SmallVector<Value> maskOperands;
1309-
1310- for (auto [i, wgMaskSize] : llvm::enumerate (wgMaskDimSizes)) {
1311- Value wgMaskSizeVal =
1312- arith::ConstantIndexOp::create (rewriter, loc, wgMaskSize);
1313- Value dimSizeVal =
1314- arith::ConstantIndexOp::create (rewriter, loc, sgShape[i]);
1315- Value offset = offsetSet[i];
1316- Value adjustedMaskSize =
1317- arith::SubIOp::create (rewriter, loc, wgMaskSizeVal, offset);
1318- Value zero = arith::ConstantIndexOp::create (rewriter, loc, 0 );
1319- Value nonNegative =
1320- arith::MaxSIOp::create (rewriter, loc, adjustedMaskSize, zero);
1321- Value sgMaskSize =
1322- arith::MinSIOp::create (rewriter, loc, nonNegative, dimSizeVal);
1323- maskOperands.push_back (sgMaskSize);
1291+ SmallVector<Value> wgMaskDimSizes;
1292+ if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
1293+ for (int64_t maskSize : op.getMaskDimSizes ()) {
1294+ wgMaskDimSizes.push_back (
1295+ arith::ConstantIndexOp::create (rewriter, loc, maskSize));
13241296 }
1325-
1326- auto newCreateMaskOp =
1327- vector::CreateMaskOp::create (rewriter, loc, resultType, maskOperands);
1328- xegpu::setDistributeLayoutAttr (newCreateMaskOp->getResult (0 ),
1329- layout.dropSgLayoutAndData ());
1330- newCreateMaskOps.push_back (newCreateMaskOp.getResult ());
1297+ } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
1298+ wgMaskDimSizes = llvm::to_vector (op.getOperands ());
13311299 }
13321300
1333- rewriter.replaceOpWithMultiple (op, {newCreateMaskOps});
1334- return success ();
1335- }
1336- };
1337-
1338- // This pattern distributes the vector.create_mask ops to work at subgroup
1339- // level.
1340- struct WgToSgVectorCreateMaskOp
1341- : public OpConversionPattern<vector::CreateMaskOp> {
1342- using OpConversionPattern<vector::CreateMaskOp>::OpConversionPattern;
1343-
1344- LogicalResult
1345- matchAndRewrite (vector::CreateMaskOp op, OneToNOpAdaptor adaptor,
1346- ConversionPatternRewriter &rewriter) const override {
1347- xegpu::DistributeLayoutAttr layout =
1348- xegpu::getDistributeLayoutAttr (op.getResult ());
1349- if (!layout || !layout.isForWorkgroup ())
1350- return failure ();
1351-
1352- Location loc = op.getLoc ();
1353- VectorType type = op.getResult ().getType ();
1354- auto wgShape = type.getShape ();
1355-
1356- auto wgMaskOperands = op.getOperands ();
1357-
13581301 Value sgId =
13591302 gpu::SubgroupIdOp::create (rewriter, loc, /* upper_bound=*/ nullptr );
13601303 auto sgOffsets =
@@ -1366,17 +1309,17 @@ struct WgToSgVectorCreateMaskOp
13661309 VectorType resultType = VectorType::get (sgShape, type.getElementType ());
13671310
13681311 // In each dimension, each subgroup computes its local mask size as:
1369- // min(max(wgMaskSize[d] - offset[d], 0), sgDimSize[d])
1312+ // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
13701313 SmallVector<Value> newCreateMaskOps;
13711314 for (auto offsetSet : *sgOffsets) {
13721315 SmallVector<Value> maskOperands;
13731316
1374- for (auto [i, wgMaskOperand ] : llvm::enumerate (wgMaskOperands )) {
1317+ for (auto [i, wgMaskDimSize ] : llvm::enumerate (wgMaskDimSizes )) {
13751318 Value dimSizeVal =
13761319 arith::ConstantIndexOp::create (rewriter, loc, sgShape[i]);
13771320 Value offset = offsetSet[i];
13781321 Value adjustedMaskSize =
1379- arith::SubIOp::create (rewriter, loc, wgMaskOperand , offset);
1322+ arith::SubIOp::create (rewriter, loc, wgMaskDimSize , offset);
13801323 Value zero = arith::ConstantIndexOp::create (rewriter, loc, 0 );
13811324 Value nonNegative =
13821325 arith::MaxSIOp::create (rewriter, loc, adjustedMaskSize, zero);
@@ -1397,6 +1340,8 @@ struct WgToSgVectorCreateMaskOp
13971340 }
13981341};
13991342
1343+ using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
1344+ using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
14001345} // namespace
14011346
14021347namespace mlir {
0 commit comments