@@ -1278,15 +1278,15 @@ struct WgToSgVectorTransposeOp
12781278 }
12791279};
12801280
1281- // This pattern distributes the vector.constant_mask ops to work at subgroup
1282- // level.
1283- struct WgToSgVectorConstantMaskOp
1284- : public OpConversionPattern<vector::ConstantMaskOp> {
1285- using OpConversionPattern<vector::ConstantMaskOp>::OpConversionPattern;
1286-
1287- LogicalResult
1288- matchAndRewrite (vector::ConstantMaskOp op, OneToNOpAdaptor adaptor,
1289- ConversionPatternRewriter &rewriter) const override {
1281+ // Distribute vector mask ops to work at subgroup level.
1282+ template < typename MaskOpType>
1283+ struct WgToSgVectorMaskOp : public OpConversionPattern <MaskOpType> {
1284+ using OpConversionPattern<MaskOpType>::OpConversionPattern;
1285+
1286+ LogicalResult matchAndRewrite (
1287+ MaskOpType op,
1288+ typename OpConversionPattern<MaskOpType>:: OneToNOpAdaptor adaptor,
1289+ ConversionPatternRewriter &rewriter) const override {
12901290 xegpu::DistributeLayoutAttr layout =
12911291 xegpu::getDistributeLayoutAttr (op.getResult ());
12921292 if (!layout || !layout.isForWorkgroup ())
@@ -1296,9 +1296,16 @@ struct WgToSgVectorConstantMaskOp
12961296 VectorType type = op.getResult ().getType ();
12971297 auto wgShape = type.getShape ();
12981298
1299- ArrayRef<int64_t > wgMaskDimSizes = op.getMaskDimSizes ();
1299+ SmallVector<Value> wgMaskDimSizes;
1300+ if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
1301+ for (int64_t maskSize : op.getMaskDimSizes ()) {
1302+ wgMaskDimSizes.push_back (
1303+ arith::ConstantIndexOp::create (rewriter, loc, maskSize));
1304+ }
1305+ } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
1306+ wgMaskDimSizes = llvm::to_vector (op.getOperands ());
1307+ }
13001308
1301- // Get subgroup ID.
13021309 Value sgId =
13031310 gpu::SubgroupIdOp::create (rewriter, loc, /* upper_bound=*/ nullptr );
13041311 auto sgOffsets =
@@ -1310,19 +1317,17 @@ struct WgToSgVectorConstantMaskOp
13101317 VectorType resultType = VectorType::get (sgShape, type.getElementType ());
13111318
13121319 // In each dimension, each subgroup computes its local mask size as:
1313- // min(max(wgMaskSize [d] - offset[d], 0), sgDimSize[d])
1320+ // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
13141321 SmallVector<Value> newCreateMaskOps;
13151322 for (auto offsetSet : *sgOffsets) {
13161323 SmallVector<Value> maskOperands;
13171324
1318- for (auto [i, wgMaskSize] : llvm::enumerate (wgMaskDimSizes)) {
1319- Value wgMaskSizeVal =
1320- arith::ConstantIndexOp::create (rewriter, loc, wgMaskSize);
1325+ for (auto [i, wgMaskDimSize] : llvm::enumerate (wgMaskDimSizes)) {
13211326 Value dimSizeVal =
13221327 arith::ConstantIndexOp::create (rewriter, loc, sgShape[i]);
13231328 Value offset = offsetSet[i];
13241329 Value adjustedMaskSize =
1325- arith::SubIOp::create (rewriter, loc, wgMaskSizeVal , offset);
1330+ arith::SubIOp::create (rewriter, loc, wgMaskDimSize , offset);
13261331 Value zero = arith::ConstantIndexOp::create (rewriter, loc, 0 );
13271332 Value nonNegative =
13281333 arith::MaxSIOp::create (rewriter, loc, adjustedMaskSize, zero);
@@ -1343,6 +1348,8 @@ struct WgToSgVectorConstantMaskOp
13431348 }
13441349};
13451350
1351+ using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
1352+ using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
13461353} // namespace
13471354
13481355namespace mlir {
@@ -1358,7 +1365,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
13581365 WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
13591366 WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
13601367 WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
1361- WgToSgVectorConstantMaskOp>(patterns.getContext ());
1368+ WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
1369+ patterns.getContext ());
13621370}
13631371} // namespace xegpu
13641372} // namespace mlir
@@ -1485,9 +1493,10 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
14851493 return isLegal (layout);
14861494 });
14871495
1488- target.addDynamicallyLegalOp <
1489- vector::ShapeCastOp, vector::StepOp, vector::TransposeOp,
1490- vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp>(
1496+ target.addDynamicallyLegalOp <vector::ShapeCastOp, vector::StepOp,
1497+ vector::TransposeOp, vector::BroadcastOp,
1498+ vector::MultiDimReductionOp,
1499+ vector::ConstantMaskOp, vector::CreateMaskOp>(
14911500 [=](Operation *op) -> bool {
14921501 // Check for either a SliceAttr or LayoutAttr on the result.
14931502 auto layout = xegpu::getDistributeLayoutAttr (op->getResult (0 ));
0 commit comments