|
17 | 17 | //===----------------------------------------------------------------------===// |
18 | 18 |
|
19 | 19 | #include <math.h> |
| 20 | +#include <numeric> |
20 | 21 |
|
| 22 | +#include "mlir/Dialect/Traits.h" |
21 | 23 | #include "mlir/IR/Matchers.h" |
22 | 24 | #include "mlir/IR/PatternMatch.h" |
23 | 25 | #include "mlir/IR/TypeUtilities.h" |
@@ -1633,66 +1635,271 @@ struct RecomposeConcatPattern : public OpRewritePattern<ONNXConcatOp> { |
1633 | 1635 | // ============================================================================= |
1634 | 1636 | // Rewrite pattern LayerNormalization |
1635 | 1637 | // ============================================================================= |
| 1638 | +namespace { |
1636 | 1639 |
|
1637 | | -template <typename OP_TYPE> |
1638 | | -struct PropagateBiasIntoLayerNormRewritePattern |
1639 | | - : public OpRewritePattern<ONNXAddOp> { |
1640 | | - using OpRewritePattern<ONNXAddOp>::OpRewritePattern; |
 | 1640 | +// Checks if B is unidirectionally broadcastable to A. Requires static shapes.
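 | | +// Illustrative (hypothetical shapes): a=[2,3,4] with b=[3,4] or b=[1,1,4] is
 | | +// compatible; a=[2,3,4] with b=[5,4], or with a higher-rank b, is not.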
| 1641 | +[[nodiscard]] bool areUnidirectionalBroadcastCompatible(Type a, Type b) { |
| 1642 | + auto aShaped = dyn_cast<ShapedType>(a); |
| 1643 | + auto bShaped = dyn_cast<ShapedType>(b); |
| 1644 | + if (!aShaped || !bShaped || !aShaped.hasStaticShape() || |
| 1645 | + !bShaped.hasStaticShape()) { |
| 1646 | + return false; |
| 1647 | + } |
| 1648 | + SmallVector<int64_t> broadcastedShape; |
| 1649 | + if (!OpTrait::util::getBroadcastedShape( |
| 1650 | + aShaped.getShape(), bShaped.getShape(), broadcastedShape)) { |
| 1651 | + return false; |
| 1652 | + } |
| 1653 | + // For unidirectional broadcasting, a and the resulting shape need to match |
| 1654 | + return aShaped.getShape() == ArrayRef<int64_t>(broadcastedShape); |
| 1655 | +} |
1641 | 1656 |
|
1642 | | - PropagateBiasIntoLayerNormRewritePattern(MLIRContext *context) |
1643 | | - : OpRewritePattern(context) {} |
| 1657 | +[[nodiscard]] bool isValueNoneOrConstZero(Value value) { |
| 1658 | + if (!value) { |
| 1659 | + return false; |
| 1660 | + } |
| 1661 | + if (isNoneValue(value)) { |
| 1662 | + return true; |
| 1663 | + } |
| 1664 | + auto elementsAttr = getElementAttributeFromONNXValue(value); |
| 1665 | + if (!elementsAttr) { |
| 1666 | + return false; |
| 1667 | + } |
| 1668 | + if (!elementsAttr.isSplat()) { |
| 1669 | + return false; |
| 1670 | + } |
| 1671 | + return elementsAttr.template getSplatValue<APFloat>().isZero(); |
| 1672 | +} |
| 1673 | + |
| 1674 | +template <typename LN_TYPE, typename MATCH_OP_TYPE, |
| 1675 | + size_t OPERAND_TO_MODIFY_INDEX> |
| 1676 | +struct PropagateBiasOrScaleIntoLayerNormRewritePatternBase |
| 1677 | + : public OpRewritePattern<MATCH_OP_TYPE> { |
| 1678 | + using OpRewritePattern<MATCH_OP_TYPE>::OpRewritePattern; |
| 1679 | + |
| 1680 | + static_assert(std::is_same_v<MATCH_OP_TYPE, ONNXAddOp> || |
| 1681 | + std::is_same_v<MATCH_OP_TYPE, ONNXMulOp>, |
| 1682 | + "MATCH_OP_TYPE must be ONNXAddOp or ONNXMulOp"); |
| 1683 | + |
 | 1684 | +  [[nodiscard]] virtual bool doExistingScaleAndBiasAllowFusion(
| 1685 | + LN_TYPE lnOp) const = 0; |
| 1686 | + |
| 1687 | + FailureOr<SmallVector<int64_t>> verifyAndCalculateNewReshapeShapes( |
| 1688 | + Operation *reshapeOp, MATCH_OP_TYPE matchOp, PatternRewriter &rewriter, |
| 1689 | + Value scaleOrBias) const { |
 | 1690 | +    // If we have a reshape, check that the add/mul does not change the shape
 | 1691 | +    // via broadcasting.
| 1692 | + auto reshapeResultType = |
| 1693 | + dyn_cast<ShapedType>(reshapeOp->getResult(0).getType()); |
| 1694 | + auto addOrMulResultType = |
| 1695 | + dyn_cast<ShapedType>(matchOp->getResult(0).getType()); |
| 1696 | + if (!reshapeResultType || !addOrMulResultType || |
| 1697 | + !reshapeResultType.hasStaticShape() || |
| 1698 | + !addOrMulResultType.hasStaticShape() || |
| 1699 | + reshapeResultType.getShape() != addOrMulResultType.getShape()) { |
| 1700 | + return rewriter.notifyMatchFailure( |
 | 1701 | +          matchOp, "incompatible shapes, add/mul is broadcasting");
| 1702 | + } |
 | 1703 | +    // Check that the bias/scale applies to only a single dimension that is
 | 1704 | +    // not affected by the reshape. The bias/scale could be multi-dimensional,
 | 1705 | +    // but supporting that adds complexity and was not seen in models.
| 1706 | + auto scaleOrBiasType = dyn_cast<ShapedType>(scaleOrBias.getType()); |
| 1707 | + if (!scaleOrBiasType || !scaleOrBiasType.hasStaticShape()) { |
| 1708 | + return rewriter.notifyMatchFailure( |
 | 1709 | +          matchOp, "bias/scale does not have a static shape");
| 1710 | + } |
| 1711 | + |
| 1712 | + SmallVector<int64_t> biasOrScaleRankFixedShape; |
| 1713 | + biasOrScaleRankFixedShape.append( |
| 1714 | + addOrMulResultType.getRank() - scaleOrBiasType.getRank(), 1); |
| 1715 | + biasOrScaleRankFixedShape.append( |
| 1716 | + scaleOrBiasType.getShape().begin(), scaleOrBiasType.getShape().end()); |
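 | | +    // E.g. (illustrative): a bias of shape [768] against a rank-3 result
 | | +    // becomes [1, 1, 768] here.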
| 1717 | + |
| 1718 | + // biasOrScaleRankFixedShape should have exactly one dimension that is not |
| 1719 | + // one |
| 1720 | + std::optional<int64_t> afterReshapeComputationDim; |
| 1721 | + for (auto [idx, dimSize] : enumerate(biasOrScaleRankFixedShape)) { |
| 1722 | + if (dimSize != 1) { |
| 1723 | + if (afterReshapeComputationDim) { |
| 1724 | + return rewriter.notifyMatchFailure( |
| 1725 | + matchOp, "scale/bias has more than one non-one dimension"); |
| 1726 | + } |
| 1727 | + afterReshapeComputationDim = idx; |
| 1728 | + } |
| 1729 | + } |
| 1730 | + if (!afterReshapeComputationDim) { |
| 1731 | + return rewriter.notifyMatchFailure( |
| 1732 | + matchOp, "scale/bias has no non-one dimension"); |
| 1733 | + } |
| 1734 | + |
| 1735 | + const auto shapeIncludingComputationDim = |
| 1736 | + ArrayRef<int64_t>(reshapeResultType.getShape()) |
| 1737 | + .slice(0, *afterReshapeComputationDim + 1); |
| 1738 | + const uint64_t computationRelevantSize = |
| 1739 | + std::accumulate(shapeIncludingComputationDim.begin(), |
 | 1740 | +            shapeIncludingComputationDim.end(), 1ULL, std::multiplies<uint64_t>());
| 1741 | + |
 | 1742 | +    // The bias/scale dim must not be affected by the reshape, so we map it
 | 1743 | +    // back through the reshape to the corresponding input dim.
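 | | +    // Illustrative example (hypothetical shapes): reshaping [2, 128, 768] to
 | | +    // [256, 768] followed by an add of a [768] bias gives
 | | +    // computationRelevantSize = 256 * 768; walking the input shape, the running
 | | +    // product reaches that value at input dim 2 (size 768), so the bias maps to
 | | +    // input dim 2 and its new shape is [768].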
| 1744 | + size_t reshapeInComputationDim; |
| 1745 | + auto reshapeInType = |
| 1746 | + dyn_cast<ShapedType>(reshapeOp->getOperand(0).getType()); |
| 1747 | + if (!reshapeInType || !reshapeInType.hasStaticShape()) { |
| 1748 | + return rewriter.notifyMatchFailure( |
 | 1749 | +          matchOp, "reshape input does not have a static shape");
| 1750 | + } |
| 1751 | + const auto reshapeInShape = reshapeInType.getShape(); |
| 1752 | + |
| 1753 | + // trace the dim through the reshape |
| 1754 | + uint64_t acc = 1; |
| 1755 | + for (auto [idx, dimSize] : enumerate(reshapeInShape)) { |
| 1756 | + acc *= dimSize; |
| 1757 | + if (acc == computationRelevantSize) { |
| 1758 | + if (dimSize != biasOrScaleRankFixedShape[*afterReshapeComputationDim]) { |
| 1759 | + return rewriter.notifyMatchFailure( |
| 1760 | + matchOp, "bias/scale shape is not compatible with reshape input"); |
| 1761 | + } |
| 1762 | + reshapeInComputationDim = idx; |
| 1763 | + break; |
| 1764 | + } |
| 1765 | + if (acc > computationRelevantSize) { |
| 1766 | + return rewriter.notifyMatchFailure( |
| 1767 | + matchOp, "bias/scale shape is not compatible with reshape input"); |
| 1768 | + } |
| 1769 | + } |
| 1770 | + SmallVector<int64_t> newScaleOrBiasShape; |
| 1771 | + newScaleOrBiasShape.push_back(reshapeInShape[reshapeInComputationDim]); |
| 1772 | + newScaleOrBiasShape.append( |
| 1773 | + reshapeInShape.size() - reshapeInComputationDim - 1, 1); |
| 1774 | + return newScaleOrBiasShape; |
| 1775 | + } |
1644 | 1776 |
|
1645 | 1777 | LogicalResult matchAndRewrite( |
1646 | | - ONNXAddOp addOp, PatternRewriter &rewriter) const final { |
| 1778 | + MATCH_OP_TYPE matchOp, PatternRewriter &rewriter) const final { |
| 1779 | + PatternRewriter::InsertionGuard guard(rewriter); |
1647 | 1780 | using namespace onnx_mlir; |
1648 | | - Value y, bias; |
1649 | | - Operation *yLayerNormOp; |
1650 | | - Operation *ywbAddOp = addOp.getOperation(); |
| 1781 | + Value y, scaleOrBias; |
| 1782 | + Operation *yLayerNormOp = nullptr; |
| 1783 | + Operation *reshapeOp = nullptr; |
| 1784 | + SmallVector<int64_t> newScaleOrBiasShape; // only used if there is a reshape |
| 1785 | + |
1651 | 1786 | // Match |
1652 | 1787 | // %noBias = "onnx.NoValue"() |
1653 | 1788 | // %y, %mean, %invStdDev = "onnx.LayerNormalization"(%x, %scale, %noBias) |
1654 | 1789 | // {axis = 2 : si64, epsilon = 9.994E-6 : f32, stash_type = 1 : si64} |
1655 | | - // %yBias = "onnx.Add"(%y, %bias) |
1656 | | - if (!onnx_mlir::operandOfOpDefinedBy<OP_TYPE>( |
1657 | | - yLayerNormOp, ywbAddOp, y, bias, 0) && |
1658 | | - !onnx_mlir::operandOfOpDefinedBy<OP_TYPE>( |
1659 | | - yLayerNormOp, ywbAddOp, bias, y, 1)) |
1660 | | - return reportFailure("missing y, layer norm op"); |
| 1790 | + // optional reshape between norm and add |
| 1791 | + // %yBias = "onnx.Add/onnx.Mul"(%y, %scaleOrBias) |
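 | | +    // Rewrite (sketch): the bias/scale constant is folded into the norm's
 | | +    // B/Scale operand (reshaped if a reshape sits in between), the add/mul is
 | | +    // removed, and the reshape (if any) is moved after the norm.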
| 1792 | + |
| 1793 | + if (onnx_mlir::operandOfOpDefinedBy<ONNXReshapeOp>( |
| 1794 | + reshapeOp, matchOp, y, scaleOrBias, 0) || |
| 1795 | + onnx_mlir::operandOfOpDefinedBy<ONNXReshapeOp>( |
| 1796 | + reshapeOp, matchOp, scaleOrBias, y, 1)) { |
| 1797 | + yLayerNormOp = reshapeOp->getOperand(0).getDefiningOp<LN_TYPE>(); |
| 1798 | + if (!yLayerNormOp) { |
| 1799 | + return rewriter.notifyMatchFailure( |
| 1800 | + reshapeOp, "reshape op does not have a layer norm as input"); |
| 1801 | + } |
| 1802 | + if (!reshapeOp->hasOneUse()) { |
| 1803 | + return rewriter.notifyMatchFailure( |
| 1804 | + reshapeOp, "reshape op does not have a single use"); |
| 1805 | + } |
| 1806 | + } else { |
| 1807 | + if (!onnx_mlir::operandOfOpDefinedBy<LN_TYPE>( |
| 1808 | + yLayerNormOp, matchOp, y, scaleOrBias, 0) && |
| 1809 | + !onnx_mlir::operandOfOpDefinedBy<LN_TYPE>( |
| 1810 | + yLayerNormOp, matchOp, scaleOrBias, y, 1)) |
| 1811 | + return rewriter.notifyMatchFailure(matchOp, "missing y, layer norm op"); |
| 1812 | + } |
| 1813 | + |
1661 | 1814 | // Study the layer norm op; make sure it is used only once and that its
1662 | 1815 | // existing scale/bias allow the fusion.
1663 | | - if (!yLayerNormOp->hasOneUse()) |
1664 | | - return reportFailure("y/layer norm has too many uses"); |
1665 | | - auto lnOp = mlir::cast<OP_TYPE>(yLayerNormOp); |
1666 | | - if (!onnx_mlir::isNoneValue(lnOp.getB())) |
1667 | | - return reportFailure("layer norm already has a bias"); |
1668 | | - // We are fine. |
1669 | | - Value x = lnOp.getX(); |
1670 | | - Value scale = lnOp.getScale(); |
1671 | | - FloatAttr epsilon = lnOp.getEpsilonAttr(); |
1672 | | - int64_t axis = lnOp.getAxis(); |
1673 | | - LLVM_DEBUG(llvm::dbgs() << "LayerNorm from add, axis : " << axis << "\n"); |
1674 | | - |
1675 | | - // Replace |
1676 | | - MultiDialectBuilder<OnnxBuilder> create( |
1677 | | - rewriter, rewriter.getFusedLoc({lnOp.getLoc(), addOp->getLoc()})); |
1678 | | - Type xType = x.getType(); |
1679 | | - Value res; |
1680 | | - if constexpr (std::is_same<OP_TYPE, ONNXLayerNormalizationOp>::value) |
1681 | | - res = create.onnx.layerNorm(xType, x, scale, bias, axis, epsilon); |
1682 | | - else if constexpr (std::is_same<OP_TYPE, |
1683 | | - ONNXRMSLayerNormalizationOp>::value) |
1684 | | - res = create.onnx.RMSLayerNorm(xType, x, scale, bias, axis, epsilon); |
1685 | | - else |
1686 | | - llvm_unreachable("unsupported op"); |
1687 | | - rewriter.replaceOp(addOp, res); |
| 1816 | + assert(yLayerNormOp && "yLayerNormOp should not be null"); |
| 1817 | + if (!yLayerNormOp->hasOneUse()) { |
| 1818 | + return rewriter.notifyMatchFailure( |
| 1819 | + yLayerNormOp, "y/layer norm has too many uses"); |
| 1820 | + } |
| 1821 | + auto lnOp = mlir::cast<LN_TYPE>(yLayerNormOp); |
 | 1822 | +    if (!doExistingScaleAndBiasAllowFusion(lnOp))
| 1823 | + return rewriter.notifyMatchFailure( |
| 1824 | + lnOp, "existing scale and bias do not allow fusion"); |
| 1825 | + |
| 1826 | + if (reshapeOp) { |
| 1827 | + auto newShape = verifyAndCalculateNewReshapeShapes( |
| 1828 | + reshapeOp, matchOp, rewriter, scaleOrBias); |
| 1829 | + if (failed(newShape)) { |
| 1830 | + return failure(); |
| 1831 | + } |
| 1832 | + newScaleOrBiasShape = std::move(*newShape); |
| 1833 | + } |
| 1834 | + |
 | 1835 | +    // The norm's scale/bias must broadcast unidirectionally to x.
| 1836 | + if (!reshapeOp && !areUnidirectionalBroadcastCompatible( |
| 1837 | + lnOp.getX().getType(), scaleOrBias.getType())) { |
| 1838 | + return rewriter.notifyMatchFailure(matchOp, |
| 1839 | + "layer norm and bias/scale are not unidirectional broadcast " |
| 1840 | + "compatible"); |
| 1841 | + } |
| 1842 | + |
 | 1843 | +    // Move the norm after the add/mul so that scaleOrBias dominates it.
 | 1844 | +    rewriter.moveOpAfter(lnOp, matchOp);
| 1845 | + rewriter.setInsertionPoint(matchOp); |
| 1846 | + if (reshapeOp) { |
| 1847 | + onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> create( |
| 1848 | + rewriter, reshapeOp->getLoc()); |
| 1849 | + const auto newShapeConst = create.onnx.constantInt64(newScaleOrBiasShape); |
| 1850 | + scaleOrBias = create.onnx.reshape( |
| 1851 | + RankedTensorType::get(newScaleOrBiasShape, |
| 1852 | + cast<ShapedType>(scaleOrBias.getType()).getElementType()), |
| 1853 | + scaleOrBias, newShapeConst); |
| 1854 | + } |
| 1855 | + rewriter.modifyOpInPlace(lnOp, [&] { |
| 1856 | + lnOp.setOperand(OPERAND_TO_MODIFY_INDEX, scaleOrBias); |
| 1857 | + lnOp->setLoc(rewriter.getFusedLoc({lnOp.getLoc(), matchOp->getLoc()})); |
| 1858 | + }); |
| 1859 | + if (reshapeOp) { |
| 1860 | + rewriter.moveOpAfter(reshapeOp, lnOp); |
| 1861 | + rewriter.replaceOp(matchOp, reshapeOp->getResult(0)); |
| 1862 | + } else { |
| 1863 | + rewriter.replaceOp(matchOp, lnOp.getY()); |
| 1864 | + } |
1688 | 1865 | return success(); |
1689 | 1866 | } |
| 1867 | +}; |
1690 | 1868 |
|
1691 | | -private: |
1692 | | - LogicalResult reportFailure(std::string msg) const { |
1693 | | - // Can disable line below if not needed. |
1694 | | - LLVM_DEBUG(llvm::dbgs() << "LayerNorm failure:" << msg << "\n"); |
1695 | | - return failure(); |
| 1869 | +} // namespace |
| 1870 | + |
| 1871 | +template <typename LN_TYPE> |
| 1872 | +struct PropagateScaleIntoLayerNormPattern |
| 1873 | + : public PropagateBiasOrScaleIntoLayerNormRewritePatternBase<LN_TYPE, |
| 1874 | + ONNXMulOp, /*scale*/ 1> { |
| 1875 | + using PropagateBiasOrScaleIntoLayerNormRewritePatternBase<LN_TYPE, ONNXMulOp, |
| 1876 | + /*scale*/ 1>::PropagateBiasOrScaleIntoLayerNormRewritePatternBase; |
| 1877 | + |
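 | | +  // Folding a Mul into the norm's scale is only sound if the current scale is
 | | +  // the constant 1 and the current bias is absent or zero, since then
 | | +  // (normalized * 1 + 0) * s == normalized * s.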
 | 1878 | +  bool doExistingScaleAndBiasAllowFusion(LN_TYPE lnOp) const override {
| 1879 | + if (!isValueNoneOrConstZero(lnOp.getB())) { |
| 1880 | + return false; |
| 1881 | + } |
| 1882 | + |
| 1883 | + const auto elementsAttr = getElementAttributeFromONNXValue(lnOp.getScale()); |
| 1884 | + if (!elementsAttr) { |
| 1885 | + return false; |
| 1886 | + } |
| 1887 | + if (!elementsAttr.isSplat()) { |
| 1888 | + return false; |
| 1889 | + } |
| 1890 | + return elementsAttr.template getSplatValue<APFloat>().isExactlyValue(1.0); |
| 1891 | + } |
| 1892 | +}; |
| 1893 | + |
| 1894 | +template <typename LN_TYPE> |
| 1895 | +struct PropagateBiasIntoLayerNormRewritePattern |
| 1896 | + : public PropagateBiasOrScaleIntoLayerNormRewritePatternBase<LN_TYPE, |
| 1897 | + ONNXAddOp, /*bias*/ 2> { |
| 1898 | + using PropagateBiasOrScaleIntoLayerNormRewritePatternBase<LN_TYPE, ONNXAddOp, |
| 1899 | + /*bias*/ 2>::PropagateBiasOrScaleIntoLayerNormRewritePatternBase; |
| 1900 | + |
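 | | +  // Folding an Add into the norm's bias is only sound if the current bias is
 | | +  // absent or zero: (normalized * scale + 0) + b == normalized * scale + b.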
 | 1901 | +  bool doExistingScaleAndBiasAllowFusion(LN_TYPE lnOp) const override {
| 1902 | + return isValueNoneOrConstZero(lnOp.getB()); |
1696 | 1903 | } |
1697 | 1904 | }; |
1698 | 1905 |
|
@@ -1839,7 +2046,8 @@ struct RemoveInstanceNormPattern |
1839 | 2046 | rewriter, instanceNormOp.getLoc()); |
1840 | 2047 | int64_t axis = nonSpacialRank; |
1841 | 2048 | int64_t numInNorm = inputRank - axis; |
1842 | | - // Unsqueeze scale/bias from [C] to [C x 1 x 1 x ... x 1] with numInNorm 1s. |
| 2049 | + // Unsqueeze scale/bias from [C] to [C x 1 x 1 x ... x 1] with numInNorm |
| 2050 | + // 1s. |
1843 | 2051 | llvm::SmallVector<int64_t, 4> axesList, biasScaleShape; |
1844 | 2052 | biasScaleShape.emplace_back(C); |
1845 | 2053 | for (int64_t i = 1; i <= numInNorm; ++i) { |
@@ -2189,6 +2397,11 @@ void ONNXAddOp::getCanonicalizationPatterns( |
2189 | 2397 | results.insert<FuseAddConvNullBiasPattern>(context); |
2190 | 2398 | results.insert<BinaryOpBroadcastAxisPattern<ONNXAddOp>>(context); |
2191 | 2399 | results.insert<PropagateScalarConstantExpandPattern<ONNXAddOp>>(context); |
| 2400 | + results.insert<PropagateScaleIntoLayerNormPattern<ONNXLayerNormalizationOp>>( |
| 2401 | + context); |
| 2402 | + results |
| 2403 | + .insert<PropagateScaleIntoLayerNormPattern<ONNXRMSLayerNormalizationOp>>( |
| 2404 | + context); |
2192 | 2405 | results.insert< |
2193 | 2406 | PropagateBiasIntoLayerNormRewritePattern<ONNXLayerNormalizationOp>>( |
2194 | 2407 | context); |
|