@@ -1303,7 +1303,7 @@ struct WgToSgVectorConstantMaskOp
13031303 VectorType type = op.getResult ().getType ();
13041304 auto wgShape = type.getShape ();
13051305
1306- ArrayRef<int64_t > originalMaskDimSizes = op.getMaskDimSizes ();
1306+ ArrayRef<int64_t > wgMaskDimSizes = op.getMaskDimSizes ();
13071307
13081308 // Get subgroup ID.
13091309 Value sgId =
@@ -1316,35 +1316,32 @@ struct WgToSgVectorConstantMaskOp
13161316 SmallVector<int64_t > sgShape = getSgShapeAndCount (wgShape, layout).first ;
13171317 VectorType resultType = VectorType::get (sgShape, type.getElementType ());
13181318
1319- // Each subgroup computes its local mask size as: min(max(originalMaskSize -
1319+ // Each subgroup computes its local mask size as: min(max(wgMaskSize -
13201320 // offset, 0), sgDimSize)
13211321 SmallVector<Value> newCreateMaskOps;
13221322 for (auto offsetSet : *sgOffsets) {
13231323 SmallVector<Value> maskOperands;
13241324
1325- for (auto [i, originalMaskSize ] : llvm::enumerate (originalMaskDimSizes )) {
1326- Value originalMaskSizeVal =
1327- arith::ConstantIndexOp::create (rewriter, loc, originalMaskSize );
1325+ for (auto [i, wgMaskSize ] : llvm::enumerate (wgMaskDimSizes )) {
1326+ Value wgMaskSizeVal =
1327+ arith::ConstantIndexOp::create (rewriter, loc, wgMaskSize );
13281328 Value dimSizeVal =
13291329 arith::ConstantIndexOp::create (rewriter, loc, sgShape[i]);
13301330 Value offset = offsetSet[i];
1331- // Compute: originalMaskSize - offset.
13321331 Value adjustedMaskSize =
1333- arith::SubIOp::create (rewriter, loc, originalMaskSizeVal , offset);
1332+ arith::SubIOp::create (rewriter, loc, wgMaskSizeVal , offset);
13341333 Value zero = arith::ConstantIndexOp::create (rewriter, loc, 0 );
1335- Value clampedLow =
1334+ Value nonNegative =
13361335 arith::MaxSIOp::create (rewriter, loc, adjustedMaskSize, zero);
1337- Value clampedHigh =
1338- arith::MinSIOp::create (rewriter, loc, clampedLow , dimSizeVal);
1339- maskOperands.push_back (clampedHigh );
1336+ Value sgMaskSize =
1337+ arith::MinSIOp::create (rewriter, loc, nonNegative , dimSizeVal);
1338+ maskOperands.push_back (sgMaskSize );
13401339 }
13411340
13421341 auto newCreateMaskOp =
13431342 vector::CreateMaskOp::create (rewriter, loc, resultType, maskOperands);
1344- if (!layout.getEffectiveLaneLayoutAsInt ().empty () ||
1345- !layout.getEffectiveInstDataAsInt ().empty ())
1346- xegpu::setDistributeLayoutAttr (newCreateMaskOp->getResult (0 ),
1347- layout.dropSgLayoutAndData ());
1343+ xegpu::setDistributeLayoutAttr (newCreateMaskOp->getResult (0 ),
1344+ layout.dropSgLayoutAndData ());
13481345 newCreateMaskOps.push_back (newCreateMaskOp.getResult ());
13491346 }
13501347
0 commit comments