@@ -1283,6 +1283,57 @@ struct WgToSgVectorTransposeOp
12831283 }
12841284};
12851285
1286+ // / Pattern for lowering vector.create_mask and vector.constant_mask ops to
1287+ // / subgroup level.
1288+ template <typename MaskOpType>
1289+ struct WgToSgVectorMaskOp : public OpConversionPattern <MaskOpType> {
1290+ using OpConversionPattern<MaskOpType>::OpConversionPattern;
1291+
1292+ LogicalResult matchAndRewrite (
1293+ MaskOpType op,
1294+ typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
1295+ ConversionPatternRewriter &rewriter) const override {
1296+ VectorType resultType = op.getResult ().getType ();
1297+ ArrayRef<int64_t > wgShape = resultType.getShape ();
1298+
1299+ xegpu::DistributeLayoutAttr layout =
1300+ xegpu::getDistributeLayoutAttr (op.getResult ());
1301+ if (!layout || !layout.isForWorkgroup ())
1302+ return failure ();
1303+
1304+ SmallVector<int64_t > sgShape;
1305+ int count;
1306+ std::tie (sgShape, count) = getSgShapeAndCount (wgShape, layout);
1307+ VectorType newResultType =
1308+ VectorType::get (sgShape, resultType.getElementType ());
1309+
1310+ SmallVector<Value> newMaskOps;
1311+ for (int i = 0 ; i < count; ++i) {
1312+ Value newMaskOp;
1313+ if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
1314+ newMaskOp = vector::CreateMaskOp::create (
1315+ rewriter, op.getLoc (), newResultType, op.getOperands ());
1316+ } else if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
1317+ newMaskOp = vector::ConstantMaskOp::create (
1318+ rewriter, op.getLoc (), newResultType, op.getMaskDimSizes ());
1319+ } else {
1320+ return rewriter.notifyMatchFailure (op,
1321+ " Unsupported mask operation type" );
1322+ }
1323+ xegpu::setDistributeLayoutAttr (cast<OpResult>(newMaskOp),
1324+ layout.dropSgLayoutAndData ());
1325+
1326+ newMaskOps.push_back (newMaskOp);
1327+ }
1328+
1329+ rewriter.replaceOpWithMultiple (op, {newMaskOps});
1330+ return success ();
1331+ }
1332+ };
1333+
1334+ using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
1335+ using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
1336+
12861337} // namespace
12871338
12881339namespace mlir {
@@ -1297,7 +1348,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
12971348 WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
12981349 WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
12991350 WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
1300- WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>(
1351+ WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
1352+ WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
13011353 patterns.getContext ());
13021354}
13031355} // namespace xegpu
@@ -1427,7 +1479,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
14271479
14281480 target.addDynamicallyLegalOp <vector::ShapeCastOp, vector::StepOp,
14291481 vector::TransposeOp, vector::BroadcastOp,
1430- vector::MultiDimReductionOp>(
1482+ vector::MultiDimReductionOp,
1483+ vector::ConstantMaskOp, vector::CreateMaskOp>(
14311484 [=](Operation *op) -> bool {
14321485 // Check for either a SliceAttr or LayoutAttr on the result.
14331486 auto layout = xegpu::getDistributeLayoutAttr (op->getResult (0 ));
0 commit comments