@@ -1589,6 +1589,127 @@ struct DistributeStep final : OpDistributionPattern<vector::StepOp> {
   int64_t subgroupSize;
 };
 
+SmallVector<Value>
+createDistributedMaskBounds(PatternRewriter &rewriter, Location loc,
+                            ValueRange upperBounds, NestedLayoutAttr layout,
+                            ArrayRef<Value> subgroupIndices,
+                            ArrayRef<Value> threadIndices) {
+  constexpr int64_t subgroupIdx = 0;
+  constexpr int64_t batchIdx = 1;
+  constexpr int64_t outerIdx = 2;
+  constexpr int64_t threadIdx = 3;
+  constexpr int64_t elementIdx = 4;
+  SmallVector<Value> bounds;
+  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+  for (auto [unDistributedDim, upperBound] : llvm::enumerate(upperBounds)) {
+    SmallVector<int64_t> undistributedShape =
+        layout.getPackedShapeForUndistributedDim(unDistributedDim);
+    std::array<int64_t, 3> distrShape{undistributedShape[batchIdx],
+                                      undistributedShape[outerIdx],
+                                      undistributedShape[elementIdx]};
+    int64_t elementPerThread = ShapedType::getNumElements(distrShape);
+    auto allValid =
+        rewriter.create<arith::ConstantIndexOp>(loc, elementPerThread);
+    int64_t elementTileSize = distrShape.back();
+    auto elementTileLastIdx =
+        rewriter.create<arith::ConstantIndexOp>(loc, elementTileSize - 1);
+
+    // A special condition: if the pre-distribution bound matches the mask
+    // dimension length, then the distributed bounds should exhibit the same
+    // property.
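+    // As a hypothetical illustration: with a packed shape of [2, 1, 1, 4, 4]
+    // the undistributed length is 32, so a constant bound of 32 lets every
+    // thread keep the all-valid bound of 4 elements.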
+    APInt constUpperBound;
+    if (matchPattern(upperBound.getDefiningOp(),
+                     m_ConstantInt(&constUpperBound))) {
+      int64_t undistributedDimLen =
+          ShapedType::getNumElements(undistributedShape);
+      if (constUpperBound.getZExtValue() == undistributedDimLen) {
+        bounds.push_back(allValid);
+        continue;
+      }
+    }
+    auto lastValidIdx = rewriter.create<arith::SubIOp>(loc, upperBound, one);
+    auto delinearizedLastValidIdx =
+        rewriter.create<affine::AffineDelinearizeIndexOp>(loc, lastValidIdx,
+                                                          undistributedShape);
+    SmallVector<Value> packedLastValidIdx =
+        delinearizedLastValidIdx.getResults();
+
+    // When the subgroup id equals the subgroup that encounters the bound,
+    // every [vtid] less than [the vtid that encounters the last valid
+    // element] should have an all-valid element tile.
+    auto linearizedLastValidIdxPreThreads =
+        rewriter.create<affine::AffineLinearizeIndexOp>(
+            loc,
+            ValueRange{packedLastValidIdx[batchIdx],
+                       packedLastValidIdx[outerIdx], elementTileLastIdx},
+            distrShape);
+    // The bound is defined as lastIdx + 1.
+    auto distrUpperBoundPreThreads = rewriter.create<arith::AddIOp>(
+        loc, linearizedLastValidIdxPreThreads, one);
+
+    auto linearizedLastValidIdx =
+        rewriter.create<affine::AffineLinearizeIndexOp>(
+            loc,
+            ValueRange{packedLastValidIdx[batchIdx],
+                       packedLastValidIdx[outerIdx],
+                       packedLastValidIdx[elementIdx]},
+            distrShape);
+    auto distrUpperBound =
+        rewriter.create<arith::AddIOp>(loc, linearizedLastValidIdx, one);
+
+    // The following code constructs a selection tree
+    // that in effect follows the pseudo-code:
+    //   upperbound --> delinearize --> u0, u1, u2, u3, u4
+    //
+    //   if sg < u0:
+    //     all valid.
+    //   elif sg > u0:
+    //     all invalid.
+    //   elif sg == u0:
+    //     if tid < u3:
+    //       [u1][u2][max]
+    //     elif tid > u3:
+    //       all invalid.
+    //     elif tid == u3:
+    //       [u1][u2][u4]
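+    //
+    // As a hypothetical worked example, take a per-dim packed shape of
+    // [sg=2, batch=1, outer=1, tid=4, e=4] and an undistributed bound of 19.
+    // Then lastValidIdx = 18 delinearizes to (u0, u1, u2, u3, u4) =
+    // (1, 0, 0, 0, 2), so every thread of subgroup 0 (sg < u0) keeps the
+    // all-valid bound 4, thread 0 of subgroup 1 (sg == u0, tid == u3) gets
+    // linearize(0, 0, 2) + 1 = 3, and the remaining threads of subgroup 1
+    // (tid > u3) get bound 0.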
+
+    // tid == u3
+    auto cmpBoundTidEq = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, threadIndices[unDistributedDim],
+        packedLastValidIdx[threadIdx]);
+    // tid < u3
+    auto cmpBoundTidSlt = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::slt, threadIndices[unDistributedDim],
+        packedLastValidIdx[threadIdx]);
+    // sg == u0
+    auto cmpBoundSgEq = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, subgroupIndices[unDistributedDim],
+        packedLastValidIdx[subgroupIdx]);
+    // sg < u0
+    auto cmpBoundSgSlt = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::slt, subgroupIndices[unDistributedDim],
+        packedLastValidIdx[subgroupIdx]);
+
+    // selectTid0 = tid < u3 ? [u1][u2][max] : all invalid
+    auto selectTid0 = rewriter.create<arith::SelectOp>(
+        loc, cmpBoundTidSlt, distrUpperBoundPreThreads, zero);
+    // selectTid1 = tid == u3 ? [u1][u2][u4] : selectTid0
+    auto selectTid1 = rewriter.create<arith::SelectOp>(
+        loc, cmpBoundTidEq, distrUpperBound, selectTid0);
+    // selectSg0 = sg < u0 ? all valid : all invalid
+    auto selectSg0 =
+        rewriter.create<arith::SelectOp>(loc, cmpBoundSgSlt, allValid, zero);
+    // selectSg1 = sg == u0 ? selectTid1 : selectSg0
+    auto selectSg1 = rewriter.create<arith::SelectOp>(loc, cmpBoundSgEq,
+                                                      selectTid1, selectSg0);
+    bounds.push_back(selectSg1);
+  }
+  return bounds;
+}
+
 struct DistributeCreateMask final
     : OpDistributionPattern<vector::CreateMaskOp> {
   using OpDistributionPattern::OpDistributionPattern;
@@ -1597,157 +1718,88 @@ struct DistributeCreateMask final
       : OpDistributionPattern(context), threadId(threadId),
         subgroupSize(subgroupSize) {}
 
-  SmallVector<Value>
-  createDistributedBounds(PatternRewriter &rewriter, Location loc,
-                          OperandRange upperBounds, NestedLayoutAttr layout,
-                          ArrayRef<Value> subgroupIndices,
-                          ArrayRef<Value> threadIndices) const {
-    constexpr int64_t subgroupIdx = 0;
-    constexpr int64_t batchIdx = 1;
-    constexpr int64_t outerIdx = 2;
-    constexpr int64_t threadIdx = 3;
-    constexpr int64_t elementIdx = 4;
-    SmallVector<Value> bounds;
-    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-
-    for (auto [unDistributedDim, upperBound] : llvm::enumerate(upperBounds)) {
-      SmallVector<int64_t> undistributedShape =
-          layout.getPackedShapeForUndistributedDim(unDistributedDim);
-      SmallVector<int64_t> distrShape{undistributedShape[batchIdx],
-                                      undistributedShape[outerIdx],
-                                      undistributedShape[elementIdx]};
-      int64_t elementPerThread = ShapedType::getNumElements(distrShape);
-      auto allValid =
-          rewriter.create<arith::ConstantIndexOp>(loc, elementPerThread);
-      int64_t elementTileSize = distrShape.back();
-      auto elementTileLastIdx =
-          rewriter.create<arith::ConstantIndexOp>(loc, elementTileSize - 1);
-
-      // A special condition if the pre-distribution bounds match
-      // the mask dimension length, then the distributed bounds
-      // should exhibit the same property.
-      if (auto constUpperBound = dyn_cast_or_null<arith::ConstantIndexOp>(
-              upperBound.getDefiningOp())) {
-        int64_t undistributedDimLen =
-            ShapedType::getNumElements(undistributedShape);
-        if (constUpperBound.value() == undistributedDimLen) {
-          bounds.push_back(allValid);
-          continue;
-        }
-      }
-      auto lastValidIdx = rewriter.create<arith::SubIOp>(loc, upperBound, one);
-      auto delineraizedLastValidIdx =
-          rewriter.create<affine::AffineDelinearizeIndexOp>(loc, lastValidIdx,
-                                                            undistributedShape);
-      SmallVector<Value> packedLastValidIdx =
-          delineraizedLastValidIdx.getResults();
-
-      // When subgroup id is equal to the subgroup that encounters the bound,
-      // Every [vtid] less than [vtid that encounters last valid element] should
-      // have a all valid element tile
-      auto linearizedLastValidIdxPreThreads =
-          rewriter.create<affine::AffineLinearizeIndexOp>(
-              loc,
-              ValueRange{packedLastValidIdx[batchIdx],
-                         packedLastValidIdx[outerIdx], elementTileLastIdx},
-              distrShape);
-      // Bound is defined as lastIdx + 1;
-      auto distrUpperBoundPreThreads = rewriter.create<arith::AddIOp>(
-          loc, linearizedLastValidIdxPreThreads, one);
-
-      auto linearizedLastValidIdx =
-          rewriter.create<affine::AffineLinearizeIndexOp>(
-              loc,
-              ValueRange{packedLastValidIdx[batchIdx],
-                         packedLastValidIdx[outerIdx],
-                         packedLastValidIdx[elementIdx]},
-              distrShape);
-      auto distrUpperBound =
-          rewriter.create<arith::AddIOp>(loc, linearizedLastValidIdx, one);
-
-      // The following code constructs a selection tree
-      // that in effect follows the code:
-      // * upperbound --> delinearize --> u0, u1, u2, u3, u4
-      //
-      // if sg < u0,
-      //   all valid.
-      // elif sg > u0,
-      //   all invalid.
-      // elif sg == u0,
-      //   if tid < u3:
-      //     [u1][u2][max]
-      //   if tid > u3:
-      //     all invalid.
-      //   if tid == u3:
-      //     [u1][u2][u4]
-
-      // tid == u3
-      auto cmpBoundTidEq = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::eq, threadIndices[unDistributedDim],
-          packedLastValidIdx[threadIdx]);
-      // tid < u3
-      auto cmpBoundTidSlt = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::slt, threadIndices[unDistributedDim],
-          packedLastValidIdx[threadIdx]);
-      // sg == u0
-      auto cmpBoundSgEq = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::eq, subgroupIndices[unDistributedDim],
-          packedLastValidIdx[subgroupIdx]);
-      // sg < u0
-      auto cmpBoundSgSlt = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::slt, subgroupIndices[unDistributedDim],
-          packedLastValidIdx[subgroupIdx]);
-
-      // selectTid0 = tid < u3 ? [u1][u2][max] : all invalid
-      auto selectTid0 = rewriter.create<arith::SelectOp>(
-          loc, cmpBoundTidSlt, distrUpperBoundPreThreads, zero);
-      // selectTid1 = tid == u3 : [u1][u2][u4] : selectTid0
-      auto selectTid1 = rewriter.create<arith::SelectOp>(
-          loc, cmpBoundTidEq, distrUpperBound, selectTid0);
-      // selectSg0 = sg < u0 ? all valid : all invalid
-      auto selectSg0 =
-          rewriter.create<arith::SelectOp>(loc, cmpBoundSgSlt, allValid, zero);
-      // selectSg1 = sg == u0 ? selectTid1 : selectSg0
-      auto selectSg1 = rewriter.create<arith::SelectOp>(loc, cmpBoundSgEq,
-                                                        selectTid1, selectSg0);
-      bounds.push_back(selectSg1);
-    }
-    return bounds;
+  LogicalResult matchAndRewrite(vector::CreateMaskOp maskOp,
+                                DistributionSignature &signature,
+                                PatternRewriter &rewriter) const override {
+    Location loc = maskOp.getLoc();
+    VectorValue result = maskOp.getResult();
+    NestedLayoutAttr resultLayout =
+        dyn_cast<NestedLayoutAttr>(signature[result]);
+    if (!resultLayout) {
+      return rewriter.notifyMatchFailure(
+          maskOp, "missing nested layout for mask op result");
+    }
+    SmallVector<Value> subgroupIndices, threadIndices;
+    if (failed(populateWarpAndThreadIndices(rewriter, threadId, subgroupSize,
+                                            resultLayout, subgroupIndices,
+                                            threadIndices))) {
+      return rewriter.notifyMatchFailure(
+          maskOp, "warp or thread tiles have overlapping strides");
+    }
+
+    SmallVector<Value> distributedBounds = createDistributedMaskBounds(
+        rewriter, loc, maskOp.getOperands(), resultLayout, subgroupIndices,
+        threadIndices);
+
+    Type elemType = maskOp.getType().getElementType();
+    auto distrUnpackedType =
+        VectorType::get(resultLayout.getDistributedUnpackedShape(), elemType);
+    auto distrMask = rewriter.create<vector::CreateMaskOp>(
+        loc, distrUnpackedType, distributedBounds);
+    VectorValue interleavedDistrMask =
+        getInterleavedPackedForm(rewriter, distrMask, resultLayout);
+    replaceOpWithDistributedValues(rewriter, maskOp, {interleavedDistrMask});
+    return success();
   }
+  Value threadId;
+  int64_t subgroupSize;
+};
+
+struct DistributeConstantMask final
+    : OpDistributionPattern<vector::ConstantMaskOp> {
+  using OpDistributionPattern::OpDistributionPattern;
+  DistributeConstantMask(MLIRContext *context, Value threadId,
+                         int64_t subgroupSize)
+      : OpDistributionPattern(context), threadId(threadId),
+        subgroupSize(subgroupSize) {}
 
-  LogicalResult matchAndRewrite(vector::CreateMaskOp creatMaskOp,
+  LogicalResult matchAndRewrite(vector::ConstantMaskOp maskOp,
                                 DistributionSignature &signature,
                                 PatternRewriter &rewriter) const override {
-    Location loc = creatMaskOp.getLoc();
-    VectorValue result = creatMaskOp.getResult();
+    Location loc = maskOp.getLoc();
+    VectorValue result = maskOp.getResult();
     NestedLayoutAttr resultLayout =
         dyn_cast<NestedLayoutAttr>(signature[result]);
     if (!resultLayout) {
       return rewriter.notifyMatchFailure(
-          creatMaskOp, "missing nested layout for step op result");
+          maskOp, "missing nested layout for mask op result");
     }
     SmallVector<Value> subgroupIndices, threadIndices;
     if (failed(populateWarpAndThreadIndices(rewriter, threadId, subgroupSize,
                                             resultLayout, subgroupIndices,
                                             threadIndices))) {
       return rewriter.notifyMatchFailure(
-          creatMaskOp, "warp or thread tiles have overlapping strides");
+          maskOp, "warp or thread tiles have overlapping strides");
+    }
+
+    SmallVector<Value> constOperands;
+    for (int64_t size : maskOp.getMaskDimSizes()) {
+      Value index = rewriter.create<arith::ConstantIndexOp>(loc, size);
+      constOperands.push_back(index);
     }
 
     SmallVector<Value> distributedBounds =
-        createDistributedBounds(rewriter, loc, creatMaskOp.getOperands(),
-                                resultLayout, subgroupIndices, threadIndices);
+        createDistributedMaskBounds(rewriter, loc, constOperands, resultLayout,
+                                    subgroupIndices, threadIndices);
 
-    Type elemType = creatMaskOp.getType().getElementType();
+    Type elemType = maskOp.getType().getElementType();
     auto distrUnpackedType =
         VectorType::get(resultLayout.getDistributedUnpackedShape(), elemType);
     auto distrMask = rewriter.create<vector::CreateMaskOp>(
         loc, distrUnpackedType, distributedBounds);
     VectorValue interleavedDistrMask =
         getInterleavedPackedForm(rewriter, distrMask, resultLayout);
-    replaceOpWithDistributedValues(rewriter, creatMaskOp,
-                                   {interleavedDistrMask});
+    replaceOpWithDistributedValues(rewriter, maskOp, {interleavedDistrMask});
     return success();
   }
   Value threadId;
@@ -1768,8 +1820,8 @@ void populateGPUDistributeNestedLayoutAttrPatterns(RewritePatternSet &patterns,
   patterns.add<DistributeContract>(patterns.getContext());
   patterns.add<DistributeBatchOuterToLayoutConversions>(patterns.getContext());
   patterns.add<DistributeStep>(patterns.getContext(), threadId, subgroupSize);
-  patterns.add<DistributeCreateMask>(patterns.getContext(), threadId,
-                                     subgroupSize);
+  patterns.add<DistributeCreateMask, DistributeConstantMask>(
+      patterns.getContext(), threadId, subgroupSize);
 }
 
 }; // namespace mlir::iree_compiler
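
For intuition, the following standalone sketch (plain C++, not part of the patch; the names distributedBound, linearize3, and the packed shape are hypothetical) mirrors the selection tree of createDistributedMaskBounds with scalar integers for a single mask dimension, using the same packed order [subgroup, batch, outer, thread, element]. It models only the selection tree, not the constant-bound fast path.

// Minimal scalar model of the distributed mask-bound selection tree.
// Illustrative only; assumes upperBound >= 1 and a hypothetical packed shape.
#include <array>
#include <cstdint>
#include <cstdio>

// Delinearize a linear index over a row-major basis (innermost dimension last).
static std::array<int64_t, 5> delinearize(int64_t idx,
                                          const std::array<int64_t, 5> &shape) {
  std::array<int64_t, 5> out{};
  for (int d = 4; d >= 0; --d) {
    out[d] = idx % shape[d];
    idx /= shape[d];
  }
  return out;
}

// Linearize (batch, outer, element) over the per-thread distributed shape.
static int64_t linearize3(int64_t b, int64_t o, int64_t e,
                          const std::array<int64_t, 3> &shape) {
  return (b * shape[1] + o) * shape[2] + e;
}

// Distributed bound for one (subgroup, thread) pair, mirroring the select tree
// in createDistributedMaskBounds (without the constant-bound fast path).
static int64_t distributedBound(int64_t upperBound, int64_t sg, int64_t tid,
                                const std::array<int64_t, 5> &packed) {
  const std::array<int64_t, 3> distrShape{packed[1], packed[2], packed[4]};
  const int64_t allValid = distrShape[0] * distrShape[1] * distrShape[2];
  const std::array<int64_t, 5> u = delinearize(upperBound - 1, packed);
  if (sg < u[0])
    return allValid; // whole per-thread tile valid
  if (sg > u[0])
    return 0; // whole per-thread tile invalid
  if (tid < u[3])
    return linearize3(u[1], u[2], distrShape[2] - 1, distrShape) + 1;
  if (tid > u[3])
    return 0;
  return linearize3(u[1], u[2], u[4], distrShape) + 1; // tid == u[3]
}

int main() {
  // Hypothetical per-dimension packed shape: [sg, batch, outer, tid, element].
  const std::array<int64_t, 5> packed{2, 1, 1, 4, 4};
  const int64_t upperBound = 19; // undistributed vector.create_mask operand
  for (int64_t sg = 0; sg < packed[0]; ++sg)
    for (int64_t tid = 0; tid < packed[3]; ++tid)
      std::printf("sg=%lld tid=%lld bound=%lld\n", (long long)sg,
                  (long long)tid,
                  (long long)distributedBound(upperBound, sg, tid, packed));
  return 0;
}

With these numbers it prints an all-valid bound of 4 for every thread of subgroup 0, 3 for thread 0 of subgroup 1, and 0 for the remaining threads, matching the worked example in the comments above.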