@@ -1283,57 +1283,74 @@ struct WgToSgVectorTransposeOp
12831283 }
12841284};
12851285
1286- /// Pattern for lowering vector.create_mask and vector.constant_mask ops to
1287- /// subgroup level.
1288- template <typename MaskOpType>
1289- struct WgToSgVectorMaskOp : public OpConversionPattern <MaskOpType> {
1290- using OpConversionPattern<MaskOpType>::OpConversionPattern;
1291-
1292- LogicalResult matchAndRewrite (
1293- MaskOpType op,
1294- typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
1295- ConversionPatternRewriter &rewriter) const override {
1296- VectorType resultType = op.getResult ().getType ();
1297- ArrayRef<int64_t > wgShape = resultType.getShape ();
1286+ // This pattern distributes vector.constant_mask ops to subgroup level by
1287+ // creating vector.create_mask ops whose sizes are offset-adjusted and clamped.
1288+ struct WgToSgVectorConstantMaskOp
1289+ : public OpConversionPattern<vector::ConstantMaskOp> {
1290+ using OpConversionPattern<vector::ConstantMaskOp>::OpConversionPattern;
12981291
1292+ LogicalResult
1293+ matchAndRewrite (vector::ConstantMaskOp op, OneToNOpAdaptor adaptor,
1294+ ConversionPatternRewriter &rewriter) const override {
12991295 xegpu::DistributeLayoutAttr layout =
13001296 xegpu::getDistributeLayoutAttr (op.getResult ());
13011297 if (!layout || !layout.isForWorkgroup ())
13021298 return failure ();
13031299
1304- SmallVector< int64_t > sgShape ;
1305- int count ;
1306- std::tie (sgShape, count) = getSgShapeAndCount (wgShape, layout );
1307- VectorType newResultType =
1308- VectorType::get (sgShape, resultType. getElementType () );
1300+ Location loc = op. getLoc () ;
1301+ VectorType type = op. getResult (). getType () ;
1302+ auto wgShape = type. getShape ( );
1303+
1304+ ArrayRef< int64_t > originalMaskDimSizes = op. getMaskDimSizes ( );
13091305
1310- SmallVector<Value> newMaskOps;
1311- for (int i = 0 ; i < count; ++i) {
1312- Value newMaskOp;
1313- if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
1314- newMaskOp = vector::CreateMaskOp::create (
1315- rewriter, op.getLoc (), newResultType, op.getOperands ());
1316- } else if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
1317- newMaskOp = vector::ConstantMaskOp::create (
1318- rewriter, op.getLoc (), newResultType, op.getMaskDimSizes ());
1319- } else {
1320- return rewriter.notifyMatchFailure (op,
1321- " Unsupported mask operation type" );
1306+ // Get subgroup ID.
1307+ Value sgId =
1308+ gpu::SubgroupIdOp::create (rewriter, loc, /* upper_bound=*/ nullptr );
1309+ auto sgOffsets =
1310+ layout.computeDistributedCoords (rewriter, loc, sgId, wgShape);
1311+ if (failed (sgOffsets))
1312+ return failure ();
1313+
1314+ SmallVector<int64_t > sgShape = getSgShapeAndCount (wgShape, layout).first ;
1315+ VectorType resultType = VectorType::get (sgShape, type.getElementType ());
1316+
1317+ SmallVector<Value> newCreateMaskOps;
1318+ for (auto offsetSet : *sgOffsets) {
1319+ SmallVector<Value> maskOperands;
1320+
1321+ for (auto [i, originalMaskSize] : llvm::enumerate (originalMaskDimSizes)) {
1322+ Value originalMaskSizeVal =
1323+ arith::ConstantIndexOp::create (rewriter, loc, originalMaskSize);
1324+ Value dimSizeVal =
1325+ arith::ConstantIndexOp::create (rewriter, loc, sgShape[i]);
1326+ Value offset = offsetSet[i];
1327+ // Compute: originalMaskSize - offset.
1328+ Value adjustedMaskSize =
1329+ arith::SubIOp::create (rewriter, loc, originalMaskSizeVal, offset);
1330+ // Clamp to [0, dimSize]: min(max(adjustedMaskSize, 0),
1331+ // dimSize)
1332+ Value zero = arith::ConstantIndexOp::create (rewriter, loc, 0 );
1333+ Value clampedLow =
1334+ arith::MaxSIOp::create (rewriter, loc, adjustedMaskSize, zero);
1335+ Value clampedHigh =
1336+ arith::MinSIOp::create (rewriter, loc, clampedLow, dimSizeVal);
1337+ maskOperands.push_back (clampedHigh);
13221338 }
1323- xegpu::setDistributeLayoutAttr (cast<OpResult>(newMaskOp),
1324- layout.dropSgLayoutAndData ());
13251339
1326- newMaskOps.push_back (newMaskOp);
1340+ auto newCreateMaskOp =
1341+ vector::CreateMaskOp::create (rewriter, loc, resultType, maskOperands);
1342+ if (!layout.getEffectiveLaneLayoutAsInt ().empty () ||
1343+ !layout.getEffectiveInstDataAsInt ().empty ())
1344+ xegpu::setDistributeLayoutAttr (newCreateMaskOp->getResult (0 ),
1345+ layout.dropSgLayoutAndData ());
1346+ newCreateMaskOps.push_back (newCreateMaskOp.getResult ());
13271347 }
13281348
1329- rewriter.replaceOpWithMultiple (op, {newMaskOps });
1349+ rewriter.replaceOpWithMultiple (op, {newCreateMaskOps });
13301350 return success ();
13311351 }
13321352};
13331353
1334- using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
1335- using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
1336-
13371354} // namespace
13381355
13391356namespace mlir {
@@ -1349,8 +1366,7 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
13491366 WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
13501367 WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
13511368 WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
1352- WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
1353- patterns.getContext ());
1369+ WgToSgVectorConstantMaskOp>(patterns.getContext ());
13541370}
13551371} // namespace xegpu
13561372} // namespace mlir
@@ -1477,10 +1493,9 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
14771493 return isLegal (layout);
14781494 });
14791495
1480- target.addDynamicallyLegalOp <vector::ShapeCastOp, vector::StepOp,
1481- vector::TransposeOp, vector::BroadcastOp,
1482- vector::MultiDimReductionOp,
1483- vector::ConstantMaskOp, vector::CreateMaskOp>(
1496+ target.addDynamicallyLegalOp <
1497+ vector::ShapeCastOp, vector::StepOp, vector::TransposeOp,
1498+ vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp>(
14841499 [=](Operation *op) -> bool {
14851500 // Check for either a SliceAttr or LayoutAttr on the result.
14861501 auto layout = xegpu::getDistributeLayoutAttr (op->getResult (0 ));
0 commit comments