Skip to content

Commit f57e4bd

Browse files
authored
[VectorDistribution] Add support for distributing vector.constant_mask (iree-org#20708)
This patch implements support for distributing the vector.constant_mask op. It is the same as vector.create_mask and reuses the same implementation.
1 parent 1cbcb4e commit f57e4bd

File tree

3 files changed

+228
-130
lines changed

3 files changed

+228
-130
lines changed

compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp

Lines changed: 181 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -1589,6 +1589,127 @@ struct DistributeStep final : OpDistributionPattern<vector::StepOp> {
15891589
int64_t subgroupSize;
15901590
};
15911591

1592+
/// Computes the distributed (per-thread) upper bounds for a mask op.
///
/// For every masked dimension, the pre-distribution bound in `upperBounds`
/// is translated into a bound on the thread-local element tile described by
/// `layout`. Each un-distributed dimension is packed as
/// [subgroup][batch][outer][thread][element]; the returned bound counts
/// valid elements in the linearized [batch][outer][element] tile owned by
/// the (subgroup, thread) given by `subgroupIndices` / `threadIndices`.
SmallVector<Value> createDistributedMaskBounds(PatternRewriter &rewriter,
                                               Location loc,
                                               ValueRange upperBounds,
                                               NestedLayoutAttr layout,
                                               ArrayRef<Value> subgroupIndices,
                                               ArrayRef<Value> threadIndices) {
  // Positions within the packed per-dimension shape
  // [subgroup][batch][outer][thread][element].
  constexpr int64_t subgroupIdx = 0;
  constexpr int64_t batchIdx = 1;
  constexpr int64_t outerIdx = 2;
  constexpr int64_t threadIdx = 3;
  constexpr int64_t elementIdx = 4;
  SmallVector<Value> bounds;
  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);

  for (auto [unDistributedDim, upperBound] : llvm::enumerate(upperBounds)) {
    SmallVector<int64_t> undistributedShape =
        layout.getPackedShapeForUndistributedDim(unDistributedDim);
    // Shape of the thread-local tile: [batch][outer][element].
    std::array<int64_t, 3> distrShape{undistributedShape[batchIdx],
                                      undistributedShape[outerIdx],
                                      undistributedShape[elementIdx]};
    int64_t elementPerThread = ShapedType::getNumElements(distrShape);
    auto allValid =
        rewriter.create<arith::ConstantIndexOp>(loc, elementPerThread);
    int64_t elementTileSize = distrShape.back();
    auto elementTileLastIdx =
        rewriter.create<arith::ConstantIndexOp>(loc, elementTileSize - 1);

    // A special condition if the pre-distribution bounds match
    // the mask dimension length, then the distributed bounds
    // should exhibit the same property.
    // NOTE: match on the Value (not on upperBound.getDefiningOp()) so that
    // block/function arguments — which have no defining op — are handled
    // safely instead of dereferencing a null Operation*.
    APInt constUpperBound;
    if (matchPattern(upperBound, m_ConstantInt(&constUpperBound))) {
      int64_t undistributedDimLen =
          ShapedType::getNumElements(undistributedShape);
      if (constUpperBound.getZExtValue() == undistributedDimLen) {
        bounds.push_back(allValid);
        continue;
      }
    }
    auto lastValidIdx = rewriter.create<arith::SubIOp>(loc, upperBound, one);
    auto delinearizedLastValidIdx =
        rewriter.create<affine::AffineDelinearizeIndexOp>(loc, lastValidIdx,
                                                          undistributedShape);
    SmallVector<Value> packedLastValidIdx =
        delinearizedLastValidIdx.getResults();

    // When subgroup id is equal to the subgroup that encounters the bound,
    // Every [vtid] less than [vtid that encounters last valid element] should
    // have a all valid element tile
    auto linearizedLastValidIdxPreThreads =
        rewriter.create<affine::AffineLinearizeIndexOp>(
            loc,
            ValueRange{packedLastValidIdx[batchIdx],
                       packedLastValidIdx[outerIdx], elementTileLastIdx},
            distrShape);
    // Bound is defined as lastIdx + 1;
    auto distrUpperBoundPreThreads = rewriter.create<arith::AddIOp>(
        loc, linearizedLastValidIdxPreThreads, one);

    auto linearizedLastValidIdx =
        rewriter.create<affine::AffineLinearizeIndexOp>(
            loc,
            ValueRange{packedLastValidIdx[batchIdx],
                       packedLastValidIdx[outerIdx],
                       packedLastValidIdx[elementIdx]},
            distrShape);
    auto distrUpperBound =
        rewriter.create<arith::AddIOp>(loc, linearizedLastValidIdx, one);

    // The following code constructs a selection tree
    // that in effect follows the code:
    // * upperbound --> delinearize --> u0, u1, u2, u3, u4
    //
    // if sg < u0,
    //   all valid.
    // elif sg > u0,
    //   all invalid.
    // elif sg == u0,
    //   if tid < u3:
    //     [u1][u2][max]
    //   if tid > u3:
    //     all invalid.
    //   if tid == u3:
    //     [u1][u2][u4]

    // tid == u3
    auto cmpBoundTidEq = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, threadIndices[unDistributedDim],
        packedLastValidIdx[threadIdx]);
    // tid < u3
    auto cmpBoundTidSlt = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, threadIndices[unDistributedDim],
        packedLastValidIdx[threadIdx]);
    // sg == u0
    auto cmpBoundSgEq = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, subgroupIndices[unDistributedDim],
        packedLastValidIdx[subgroupIdx]);
    // sg < u0
    auto cmpBoundSgSlt = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, subgroupIndices[unDistributedDim],
        packedLastValidIdx[subgroupIdx]);

    // selectTid0 = tid < u3 ? [u1][u2][max] : all invalid
    auto selectTid0 = rewriter.create<arith::SelectOp>(
        loc, cmpBoundTidSlt, distrUpperBoundPreThreads, zero);
    // selectTid1 = tid == u3 : [u1][u2][u4] : selectTid0
    auto selectTid1 = rewriter.create<arith::SelectOp>(
        loc, cmpBoundTidEq, distrUpperBound, selectTid0);
    // selectSg0 = sg < u0 ? all valid : all invalid
    auto selectSg0 =
        rewriter.create<arith::SelectOp>(loc, cmpBoundSgSlt, allValid, zero);
    // selectSg1 = sg == u0 ? selectTid1 : selectSg0
    auto selectSg1 = rewriter.create<arith::SelectOp>(loc, cmpBoundSgEq,
                                                      selectTid1, selectSg0);
    bounds.push_back(selectSg1);
  }
  return bounds;
}
1712+
15921713
struct DistributeCreateMask final
15931714
: OpDistributionPattern<vector::CreateMaskOp> {
15941715
using OpDistributionPattern::OpDistributionPattern;
@@ -1597,157 +1718,88 @@ struct DistributeCreateMask final
15971718
: OpDistributionPattern(context), threadId(threadId),
15981719
subgroupSize(subgroupSize) {}
15991720

1600-
SmallVector<Value>
1601-
createDistributedBounds(PatternRewriter &rewriter, Location loc,
1602-
OperandRange upperBounds, NestedLayoutAttr layout,
1603-
ArrayRef<Value> subgroupIndices,
1604-
ArrayRef<Value> threadIndices) const {
1605-
constexpr int64_t subgroupIdx = 0;
1606-
constexpr int64_t batchIdx = 1;
1607-
constexpr int64_t outerIdx = 2;
1608-
constexpr int64_t threadIdx = 3;
1609-
constexpr int64_t elementIdx = 4;
1610-
SmallVector<Value> bounds;
1611-
auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1612-
auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
1613-
1614-
for (auto [unDistributedDim, upperBound] : llvm::enumerate(upperBounds)) {
1615-
SmallVector<int64_t> undistributedShape =
1616-
layout.getPackedShapeForUndistributedDim(unDistributedDim);
1617-
SmallVector<int64_t> distrShape{undistributedShape[batchIdx],
1618-
undistributedShape[outerIdx],
1619-
undistributedShape[elementIdx]};
1620-
int64_t elementPerThread = ShapedType::getNumElements(distrShape);
1621-
auto allValid =
1622-
rewriter.create<arith::ConstantIndexOp>(loc, elementPerThread);
1623-
int64_t elementTileSize = distrShape.back();
1624-
auto elementTileLastIdx =
1625-
rewriter.create<arith::ConstantIndexOp>(loc, elementTileSize - 1);
1626-
1627-
// A special condition if the pre-distribution bounds match
1628-
// the mask dimension length, then the distributed bounds
1629-
// should exhibit the same property.
1630-
if (auto constUpperBound = dyn_cast_or_null<arith::ConstantIndexOp>(
1631-
upperBound.getDefiningOp())) {
1632-
int64_t undistributedDimLen =
1633-
ShapedType::getNumElements(undistributedShape);
1634-
if (constUpperBound.value() == undistributedDimLen) {
1635-
bounds.push_back(allValid);
1636-
continue;
1637-
}
1638-
}
1639-
auto lastValidIdx = rewriter.create<arith::SubIOp>(loc, upperBound, one);
1640-
auto delineraizedLastValidIdx =
1641-
rewriter.create<affine::AffineDelinearizeIndexOp>(loc, lastValidIdx,
1642-
undistributedShape);
1643-
SmallVector<Value> packedLastValidIdx =
1644-
delineraizedLastValidIdx.getResults();
1645-
1646-
// When subgroup id is equal to the subgroup that encounters the bound,
1647-
// Every [vtid] less than [vtid that encounters last valid element] should
1648-
// have a all valid element tile
1649-
auto linearizedLastValidIdxPreThreads =
1650-
rewriter.create<affine::AffineLinearizeIndexOp>(
1651-
loc,
1652-
ValueRange{packedLastValidIdx[batchIdx],
1653-
packedLastValidIdx[outerIdx], elementTileLastIdx},
1654-
distrShape);
1655-
// Bound is defined as lastIdx + 1;
1656-
auto distrUpperBoundPreThreads = rewriter.create<arith::AddIOp>(
1657-
loc, linearizedLastValidIdxPreThreads, one);
1658-
1659-
auto linearizedLastValidIdx =
1660-
rewriter.create<affine::AffineLinearizeIndexOp>(
1661-
loc,
1662-
ValueRange{packedLastValidIdx[batchIdx],
1663-
packedLastValidIdx[outerIdx],
1664-
packedLastValidIdx[elementIdx]},
1665-
distrShape);
1666-
auto distrUpperBound =
1667-
rewriter.create<arith::AddIOp>(loc, linearizedLastValidIdx, one);
1668-
1669-
// The following code constructs a selection tree
1670-
// that in effect follows the code:
1671-
// * upperbound --> delinearize --> u0, u1, u2, u3, u4
1672-
//
1673-
// if sg < u0,
1674-
// all valid.
1675-
// elif sg > u0,
1676-
// all invalid.
1677-
// elif sg == u0,
1678-
// if tid < u3:
1679-
// [u1][u2][max]
1680-
// if tid > u3:
1681-
// all invalid.
1682-
// if tid == u3:
1683-
// [u1][u2][u4]
1684-
1685-
// tid == u3
1686-
auto cmpBoundTidEq = rewriter.create<arith::CmpIOp>(
1687-
loc, arith::CmpIPredicate::eq, threadIndices[unDistributedDim],
1688-
packedLastValidIdx[threadIdx]);
1689-
// tid < u3
1690-
auto cmpBoundTidSlt = rewriter.create<arith::CmpIOp>(
1691-
loc, arith::CmpIPredicate::slt, threadIndices[unDistributedDim],
1692-
packedLastValidIdx[threadIdx]);
1693-
// sg == u0
1694-
auto cmpBoundSgEq = rewriter.create<arith::CmpIOp>(
1695-
loc, arith::CmpIPredicate::eq, subgroupIndices[unDistributedDim],
1696-
packedLastValidIdx[subgroupIdx]);
1697-
// sg < u0
1698-
auto cmpBoundSgSlt = rewriter.create<arith::CmpIOp>(
1699-
loc, arith::CmpIPredicate::slt, subgroupIndices[unDistributedDim],
1700-
packedLastValidIdx[subgroupIdx]);
1701-
1702-
// selectTid0 = tid < u3 ? [u1][u2][max] : all invalid
1703-
auto selectTid0 = rewriter.create<arith::SelectOp>(
1704-
loc, cmpBoundTidSlt, distrUpperBoundPreThreads, zero);
1705-
// selectTid1 = tid == u3 : [u1][u2][u4] : selectTid0
1706-
auto selectTid1 = rewriter.create<arith::SelectOp>(
1707-
loc, cmpBoundTidEq, distrUpperBound, selectTid0);
1708-
// selectSg0 = sg < u0 ? all valid : all invalid
1709-
auto selectSg0 =
1710-
rewriter.create<arith::SelectOp>(loc, cmpBoundSgSlt, allValid, zero);
1711-
// selectSg1 = sg == u0 ? selectTid1 : selectSg0
1712-
auto selectSg1 = rewriter.create<arith::SelectOp>(loc, cmpBoundSgEq,
1713-
selectTid1, selectSg0);
1714-
bounds.push_back(selectSg1);
1715-
}
1716-
return bounds;
1721+
LogicalResult matchAndRewrite(vector::CreateMaskOp maskOp,
1722+
DistributionSignature &signature,
1723+
PatternRewriter &rewriter) const override {
1724+
Location loc = maskOp.getLoc();
1725+
VectorValue result = maskOp.getResult();
1726+
NestedLayoutAttr resultLayout =
1727+
dyn_cast<NestedLayoutAttr>(signature[result]);
1728+
if (!resultLayout) {
1729+
return rewriter.notifyMatchFailure(
1730+
maskOp, "missing nested layout for step op result");
1731+
}
1732+
SmallVector<Value> subgroupIndices, threadIndices;
1733+
if (failed(populateWarpAndThreadIndices(rewriter, threadId, subgroupSize,
1734+
resultLayout, subgroupIndices,
1735+
threadIndices))) {
1736+
return rewriter.notifyMatchFailure(
1737+
maskOp, "warp or thread tiles have overlapping strides");
1738+
}
1739+
1740+
SmallVector<Value> distributedBounds = createDistributedMaskBounds(
1741+
rewriter, loc, maskOp.getOperands(), resultLayout, subgroupIndices,
1742+
threadIndices);
1743+
1744+
Type elemType = maskOp.getType().getElementType();
1745+
auto distrUnpackedType =
1746+
VectorType::get(resultLayout.getDistributedUnpackedShape(), elemType);
1747+
auto distrMask = rewriter.create<vector::CreateMaskOp>(
1748+
loc, distrUnpackedType, distributedBounds);
1749+
VectorValue interleavedDistrMask =
1750+
getInterleavedPackedForm(rewriter, distrMask, resultLayout);
1751+
replaceOpWithDistributedValues(rewriter, maskOp, {interleavedDistrMask});
1752+
return success();
17171753
}
1754+
Value threadId;
1755+
int64_t subgroupSize;
1756+
};
1757+
1758+
struct DistributeConstantMask final
1759+
: OpDistributionPattern<vector::ConstantMaskOp> {
1760+
using OpDistributionPattern::OpDistributionPattern;
1761+
DistributeConstantMask(MLIRContext *context, Value threadId,
1762+
int64_t subgroupSize)
1763+
: OpDistributionPattern(context), threadId(threadId),
1764+
subgroupSize(subgroupSize) {}
17181765

1719-
LogicalResult matchAndRewrite(vector::CreateMaskOp creatMaskOp,
1766+
LogicalResult matchAndRewrite(vector::ConstantMaskOp maskOp,
17201767
DistributionSignature &signature,
17211768
PatternRewriter &rewriter) const override {
1722-
Location loc = creatMaskOp.getLoc();
1723-
VectorValue result = creatMaskOp.getResult();
1769+
Location loc = maskOp.getLoc();
1770+
VectorValue result = maskOp.getResult();
17241771
NestedLayoutAttr resultLayout =
17251772
dyn_cast<NestedLayoutAttr>(signature[result]);
17261773
if (!resultLayout) {
17271774
return rewriter.notifyMatchFailure(
1728-
creatMaskOp, "missing nested layout for step op result");
1775+
maskOp, "missing nested layout for step op result");
17291776
}
17301777
SmallVector<Value> subgroupIndices, threadIndices;
17311778
if (failed(populateWarpAndThreadIndices(rewriter, threadId, subgroupSize,
17321779
resultLayout, subgroupIndices,
17331780
threadIndices))) {
17341781
return rewriter.notifyMatchFailure(
1735-
creatMaskOp, "warp or thread tiles have overlapping strides");
1782+
maskOp, "warp or thread tiles have overlapping strides");
1783+
}
1784+
1785+
SmallVector<Value> constOperands;
1786+
for (int64_t size : maskOp.getMaskDimSizes()) {
1787+
Value index = rewriter.create<arith::ConstantIndexOp>(loc, size);
1788+
constOperands.push_back(index);
17361789
}
17371790

17381791
SmallVector<Value> distributedBounds =
1739-
createDistributedBounds(rewriter, loc, creatMaskOp.getOperands(),
1740-
resultLayout, subgroupIndices, threadIndices);
1792+
createDistributedMaskBounds(rewriter, loc, constOperands, resultLayout,
1793+
subgroupIndices, threadIndices);
17411794

1742-
Type elemType = creatMaskOp.getType().getElementType();
1795+
Type elemType = maskOp.getType().getElementType();
17431796
auto distrUnpackedType =
17441797
VectorType::get(resultLayout.getDistributedUnpackedShape(), elemType);
17451798
auto distrMask = rewriter.create<vector::CreateMaskOp>(
17461799
loc, distrUnpackedType, distributedBounds);
17471800
VectorValue interleavedDistrMask =
17481801
getInterleavedPackedForm(rewriter, distrMask, resultLayout);
1749-
replaceOpWithDistributedValues(rewriter, creatMaskOp,
1750-
{interleavedDistrMask});
1802+
replaceOpWithDistributedValues(rewriter, maskOp, {interleavedDistrMask});
17511803
return success();
17521804
}
17531805
Value threadId;
@@ -1768,8 +1820,8 @@ void populateGPUDistributeNestedLayoutAttrPatterns(RewritePatternSet &patterns,
17681820
patterns.add<DistributeContract>(patterns.getContext());
17691821
patterns.add<DistributeBatchOuterToLayoutConversions>(patterns.getContext());
17701822
patterns.add<DistributeStep>(patterns.getContext(), threadId, subgroupSize);
1771-
patterns.add<DistributeCreateMask>(patterns.getContext(), threadId,
1772-
subgroupSize);
1823+
patterns.add<DistributeCreateMask, DistributeConstantMask>(
1824+
patterns.getContext(), threadId, subgroupSize);
17731825
}
17741826

17751827
}; // namespace mlir::iree_compiler

0 commit comments

Comments
 (0)