@@ -1285,6 +1285,71 @@ struct WgToSgVectorTransposeOp
12851285 }
12861286};
12871287
1288+ // This pattern distributes the vector.constant_mask ops to work at subgroup
1289+ // level.
1290+ struct WgToSgVectorConstantMaskOp
1291+ : public OpConversionPattern<vector::ConstantMaskOp> {
1292+ using OpConversionPattern<vector::ConstantMaskOp>::OpConversionPattern;
1293+
1294+ LogicalResult
1295+ matchAndRewrite (vector::ConstantMaskOp op, OneToNOpAdaptor adaptor,
1296+ ConversionPatternRewriter &rewriter) const override {
1297+ xegpu::DistributeLayoutAttr layout =
1298+ xegpu::getDistributeLayoutAttr (op.getResult ());
1299+ if (!layout || !layout.isForWorkgroup ())
1300+ return failure ();
1301+
1302+ Location loc = op.getLoc ();
1303+ VectorType type = op.getResult ().getType ();
1304+ auto wgShape = type.getShape ();
1305+
1306+ ArrayRef<int64_t > wgMaskDimSizes = op.getMaskDimSizes ();
1307+
1308+ // Get subgroup ID.
1309+ Value sgId =
1310+ gpu::SubgroupIdOp::create (rewriter, loc, /* upper_bound=*/ nullptr );
1311+ auto sgOffsets =
1312+ layout.computeDistributedCoords (rewriter, loc, sgId, wgShape);
1313+ if (failed (sgOffsets))
1314+ return failure ();
1315+
1316+ SmallVector<int64_t > sgShape = getSgShapeAndCount (wgShape, layout).first ;
1317+ VectorType resultType = VectorType::get (sgShape, type.getElementType ());
1318+
1319+ // In each dimension, each subgroup computes its local mask size as:
1320+ // min(max(wgMaskSize[d] - offset[d], 0), sgDimSize[d])
1321+ SmallVector<Value> newCreateMaskOps;
1322+ for (auto offsetSet : *sgOffsets) {
1323+ SmallVector<Value> maskOperands;
1324+
1325+ for (auto [i, wgMaskSize] : llvm::enumerate (wgMaskDimSizes)) {
1326+ Value wgMaskSizeVal =
1327+ arith::ConstantIndexOp::create (rewriter, loc, wgMaskSize);
1328+ Value dimSizeVal =
1329+ arith::ConstantIndexOp::create (rewriter, loc, sgShape[i]);
1330+ Value offset = offsetSet[i];
1331+ Value adjustedMaskSize =
1332+ arith::SubIOp::create (rewriter, loc, wgMaskSizeVal, offset);
1333+ Value zero = arith::ConstantIndexOp::create (rewriter, loc, 0 );
1334+ Value nonNegative =
1335+ arith::MaxSIOp::create (rewriter, loc, adjustedMaskSize, zero);
1336+ Value sgMaskSize =
1337+ arith::MinSIOp::create (rewriter, loc, nonNegative, dimSizeVal);
1338+ maskOperands.push_back (sgMaskSize);
1339+ }
1340+
1341+ auto newCreateMaskOp =
1342+ vector::CreateMaskOp::create (rewriter, loc, resultType, maskOperands);
1343+ xegpu::setDistributeLayoutAttr (newCreateMaskOp->getResult (0 ),
1344+ layout.dropSgLayoutAndData ());
1345+ newCreateMaskOps.push_back (newCreateMaskOp.getResult ());
1346+ }
1347+
1348+ rewriter.replaceOpWithMultiple (op, {newCreateMaskOps});
1349+ return success ();
1350+ }
1351+ };
1352+
12881353} // namespace
12891354
12901355namespace mlir {
@@ -1299,8 +1364,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
12991364 WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
13001365 WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
13011366 WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
1302- WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>(
1303- patterns.getContext ());
1367+ WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
1368+ WgToSgVectorConstantMaskOp>( patterns.getContext ());
13041369}
13051370} // namespace xegpu
13061371} // namespace mlir
@@ -1427,9 +1492,9 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
14271492 return isLegal (layout);
14281493 });
14291494
1430- target.addDynamicallyLegalOp <vector::ShapeCastOp, vector::StepOp,
1431- vector::TransposeOp , vector::BroadcastOp ,
1432- vector::MultiDimReductionOp>(
1495+ target.addDynamicallyLegalOp <
1496+ vector::ShapeCastOp, vector::StepOp , vector::TransposeOp ,
1497+ vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp >(
14331498 [=](Operation *op) -> bool {
14341499 // Check for either a SliceAttr or LayoutAttr on the result.
14351500 auto layout = xegpu::getDistributeLayoutAttr (op->getResult (0 ));
0 commit comments