Skip to content

Commit e162350

Browse files
committed
Update variable names
1 parent 05f0cf3 commit e162350

File tree

1 file changed

+12
-15
lines changed

1 file changed

+12
-15
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,7 +1303,7 @@ struct WgToSgVectorConstantMaskOp
13031303
VectorType type = op.getResult().getType();
13041304
auto wgShape = type.getShape();
13051305

1306-
ArrayRef<int64_t> originalMaskDimSizes = op.getMaskDimSizes();
1306+
ArrayRef<int64_t> wgMaskDimSizes = op.getMaskDimSizes();
13071307

13081308
// Get subgroup ID.
13091309
Value sgId =
@@ -1316,35 +1316,32 @@ struct WgToSgVectorConstantMaskOp
13161316
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
13171317
VectorType resultType = VectorType::get(sgShape, type.getElementType());
13181318

1319-
// Each subgroup computes its local mask size as: min(max(originalMaskSize -
1319+
// Each subgroup computes its local mask size as: min(max(wgMaskSize -
13201320
// offset, 0), sgDimSize)
13211321
SmallVector<Value> newCreateMaskOps;
13221322
for (auto offsetSet : *sgOffsets) {
13231323
SmallVector<Value> maskOperands;
13241324

1325-
for (auto [i, originalMaskSize] : llvm::enumerate(originalMaskDimSizes)) {
1326-
Value originalMaskSizeVal =
1327-
arith::ConstantIndexOp::create(rewriter, loc, originalMaskSize);
1325+
for (auto [i, wgMaskSize] : llvm::enumerate(wgMaskDimSizes)) {
1326+
Value wgMaskSizeVal =
1327+
arith::ConstantIndexOp::create(rewriter, loc, wgMaskSize);
13281328
Value dimSizeVal =
13291329
arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
13301330
Value offset = offsetSet[i];
1331-
// Compute: originalMaskSize - offset.
13321331
Value adjustedMaskSize =
1333-
arith::SubIOp::create(rewriter, loc, originalMaskSizeVal, offset);
1332+
arith::SubIOp::create(rewriter, loc, wgMaskSizeVal, offset);
13341333
Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1335-
Value clampedLow =
1334+
Value nonNegative =
13361335
arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
1337-
Value clampedHigh =
1338-
arith::MinSIOp::create(rewriter, loc, clampedLow, dimSizeVal);
1339-
maskOperands.push_back(clampedHigh);
1336+
Value sgMaskSize =
1337+
arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
1338+
maskOperands.push_back(sgMaskSize);
13401339
}
13411340

13421341
auto newCreateMaskOp =
13431342
vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
1344-
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
1345-
!layout.getEffectiveInstDataAsInt().empty())
1346-
xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0),
1347-
layout.dropSgLayoutAndData());
1343+
xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0),
1344+
layout.dropSgLayoutAndData());
13481345
newCreateMaskOps.push_back(newCreateMaskOp.getResult());
13491346
}
13501347

0 commit comments

Comments
 (0)