|
17 | 17 | //===----------------------------------------------------------------------===//
18 | 18 |
|
19 | 19 | #include <math.h> |
| 20 | +#include <numeric> |
20 | 21 |
|
| 22 | +#include "mlir/Dialect/Traits.h" |
21 | 23 | #include "mlir/IR/Matchers.h" |
22 | 24 | #include "mlir/IR/PatternMatch.h" |
23 | 25 | #include "mlir/IR/TypeUtilities.h" |
@@ -1634,7 +1636,25 @@ struct RecomposeConcatPattern : public OpRewritePattern<ONNXConcatOp>
1634 | 1636 | // Rewrite pattern LayerNormalization |
1635 | 1637 | // ============================================================================= |
1636 | 1638 | namespace { |
1637 | | -bool isValueNoneOrConstZero(Value value) { |
| 1639 | + |
| 1640 | +// Checks if B is unidirectionally broadcastable to A. Requires static shapes.
| 1641 | +[[nodiscard]] bool areUnidirectionalBroadcastCompatible(Type a, Type b) { |
| 1642 | + auto aShaped = dyn_cast<ShapedType>(a); |
| 1643 | + auto bShaped = dyn_cast<ShapedType>(b); |
| 1644 | + if (!aShaped || !bShaped || !aShaped.hasStaticShape() || |
| 1645 | + !bShaped.hasStaticShape()) { |
| 1646 | + return false; |
| 1647 | + } |
| 1648 | + SmallVector<int64_t> broadcastedShape; |
| 1649 | + if (!OpTrait::util::getBroadcastedShape( |
| 1650 | + aShaped.getShape(), bShaped.getShape(), broadcastedShape)) { |
| 1651 | + return false; |
| 1652 | + } |
| 1653 | + // For unidirectional broadcasting, a and the resulting shape need to match |
| 1654 | + return aShaped.getShape() == ArrayRef<int64_t>(broadcastedShape); |
| 1655 | +} |
| 1656 | + |
| 1657 | +[[nodiscard]] bool isValueNoneOrConstZero(Value value) { |
1638 | 1658 | if (!value) { |
1639 | 1659 | return false; |
1640 | 1660 | } |
@@ -1727,63 +1747,166 @@ struct PropagateBiasIntoLayerNormRewritePattern |
1727 | 1747 |
|
1728 | 1748 | LogicalResult matchAndRewrite( |
1729 | 1749 | ONNXAddOp addOp, PatternRewriter &rewriter) const final { |
| 1750 | + PatternRewriter::InsertionGuard guard(rewriter); |
1730 | 1751 | using namespace onnx_mlir; |
1731 | 1752 | Value y, bias; |
1732 | | - Operation *yLayerNormOp; |
1733 | | - Operation *ywbAddOp = addOp.getOperation(); |
| 1753 | + Operation *yLayerNormOp = nullptr; |
| 1754 | + Operation *reshapeOp = nullptr; |
| 1755 | + SmallVector<int64_t> newBiasShape; // only used if there is a reshape |
| 1756 | + |
1734 | 1757 | // Match |
1735 | 1758 | // %noBias = "onnx.NoValue"() |
1736 | 1759 | // %y, %mean, %invStdDev = "onnx.LayerNormalization"(%x, %scale, %noBias) |
1737 | 1760 | // {axis = 2 : si64, epsilon = 9.994E-6 : f32, stash_type = 1 : si64} |
| 1761 | + // optionally an onnx.Reshape between the norm and the add
1738 | 1762 | // %yBias = "onnx.Add"(%y, %bias) |
1739 | | - if (!onnx_mlir::operandOfOpDefinedBy<OP_TYPE>( |
1740 | | - yLayerNormOp, ywbAddOp, y, bias, 0) && |
1741 | | - !onnx_mlir::operandOfOpDefinedBy<OP_TYPE>( |
1742 | | - yLayerNormOp, ywbAddOp, bias, y, 1)) |
1743 | | - return reportFailure("missing y, layer norm op"); |
| 1763 | + |
| 1764 | + if (onnx_mlir::operandOfOpDefinedBy<ONNXReshapeOp>( |
| 1765 | + reshapeOp, addOp, y, bias, 0) || |
| 1766 | + onnx_mlir::operandOfOpDefinedBy<ONNXReshapeOp>( |
| 1767 | + reshapeOp, addOp, bias, y, 1)) { |
| 1768 | + yLayerNormOp = reshapeOp->getOperand(0).getDefiningOp<OP_TYPE>(); |
| 1769 | + if (!yLayerNormOp) { |
| 1770 | + return rewriter.notifyMatchFailure( |
| 1771 | + reshapeOp, "reshape op does not have a layer norm as input"); |
| 1772 | + } |
| 1773 | + if (!reshapeOp->hasOneUse()) { |
| 1774 | + return rewriter.notifyMatchFailure( |
| 1775 | + reshapeOp, "reshape op does not have a single use"); |
| 1776 | + } |
| 1777 | + } else { |
| 1778 | + if (!onnx_mlir::operandOfOpDefinedBy<OP_TYPE>( |
| 1779 | + yLayerNormOp, addOp, y, bias, 0) && |
| 1780 | + !onnx_mlir::operandOfOpDefinedBy<OP_TYPE>( |
| 1781 | + yLayerNormOp, addOp, bias, y, 1)) |
| 1782 | + return rewriter.notifyMatchFailure(addOp, "missing y, layer norm op"); |
| 1783 | + } |
| 1784 | + |
1744 | 1785 | // Study the layer norm op; make sure it is used only once and that its
1745 | 1786 | // bias is not set.
1746 | | - if (!yLayerNormOp->hasOneUse()) |
1747 | | - return reportFailure("y/layer norm has too many uses"); |
| 1787 | + assert(yLayerNormOp && "yLayerNormOp should not be null"); |
| 1788 | + if (!yLayerNormOp->hasOneUse()) { |
| 1789 | + return rewriter.notifyMatchFailure( |
| 1790 | + yLayerNormOp, "y/layer norm has too many uses"); |
| 1791 | + } |
1748 | 1792 | auto lnOp = mlir::cast<OP_TYPE>(yLayerNormOp); |
1749 | 1793 | if (!isValueNoneOrConstZero(lnOp.getB())) |
1750 | | - return reportFailure("layer norm already has a bias"); |
| 1794 | + return rewriter.notifyMatchFailure(lnOp, "layer norm already has a bias"); |
| 1795 | + |
| 1796 | + if (reshapeOp) { |
| 1797 | + // If we have a reshape, check that the add does not change the shape by
| 1798 | + // broadcasting.
| 1799 | + auto reshapeResultType = |
| 1800 | + dyn_cast<ShapedType>(reshapeOp->getResult(0).getType()); |
| 1801 | + auto addResultType = dyn_cast<ShapedType>(addOp->getResult(0).getType()); |
| 1802 | + if (!reshapeResultType || !addResultType || |
| 1803 | + !reshapeResultType.hasStaticShape() || |
| 1804 | + !addResultType.hasStaticShape() || |
| 1805 | + reshapeResultType.getShape() != addResultType.getShape()) { |
| 1806 | + return rewriter.notifyMatchFailure( |
| 1807 | + addOp, "incompatible shapes, add is broadcasting"); |
| 1808 | + } |
| 1809 | + // Check that the bias applies to a single dimension that is not affected
| 1810 | + // by the reshape. The bias could be multi-dimensional, but handling that
| 1811 | + // adds complexity and has not been seen in models.
| 1812 | + auto biasType = dyn_cast<ShapedType>(bias.getType()); |
| 1813 | + if (!biasType || !biasType.hasStaticShape()) { |
| 1814 | + return rewriter.notifyMatchFailure( |
| 1815 | + addOp, "bias does not have a static shape");
| 1816 | + } |
| 1817 | + |
| 1818 | + SmallVector<int64_t> biasRankFixedShape; |
| 1819 | + biasRankFixedShape.append( |
| 1820 | + addResultType.getRank() - biasType.getRank(), 1); |
| 1821 | + biasRankFixedShape.append( |
| 1822 | + biasType.getShape().begin(), biasType.getShape().end()); |
| 1823 | + |
| 1824 | + // biasRankFixedShape should have exactly one dimension that is not one |
| 1825 | + std::optional<int64_t> biasDim; |
| 1826 | + for (auto [idx, dimSize] : enumerate(biasRankFixedShape)) { |
| 1827 | + if (dimSize != 1) { |
| 1828 | + if (biasDim) { |
| 1829 | + return rewriter.notifyMatchFailure( |
| 1830 | + addOp, "bias has more than one non-one dimension"); |
| 1831 | + } |
| 1832 | + biasDim = idx; |
| 1833 | + } |
| 1834 | + } |
| 1835 | + if (!biasDim) { |
| 1836 | + return rewriter.notifyMatchFailure( |
| 1837 | + addOp, "bias has no non-one dimension"); |
| 1838 | + } |
| 1839 | + |
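| | + // Number of elements covered by the reshape-output dims up to and
| | + // including the bias dim; used below to find the matching reshape-input dim.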
| 1840 | + const auto biasShapeUntilDim = |
| 1841 | + ArrayRef<int64_t>(reshapeResultType.getShape()) |
| 1842 | + .slice(0, *biasDim + 1); |
| 1843 | + const uint64_t biasShapeRelevantSize = |
| 1844 | + std::accumulate(biasShapeUntilDim.begin(), biasShapeUntilDim.end(), uint64_t{1},
| 1845 | + std::multiplies<uint64_t>()); |
| 1846 | + |
| 1847 | + // The bias dim should not be affected by the reshape. We need to map it
| 1848 | + // back through the reshape.
| 1849 | + size_t reshapeInBiasDim; |
| 1850 | + auto reshapeInType = |
| 1851 | + dyn_cast<ShapedType>(reshapeOp->getOperand(0).getType()); |
| 1852 | + if (!reshapeInType || !reshapeInType.hasStaticShape()) { |
| 1853 | + return rewriter.notifyMatchFailure( |
| 1854 | + addOp, "reshape input does not have a static shape");
| 1855 | + } |
| 1856 | + const auto reshapeInShape = reshapeInType.getShape(); |
| 1857 | + |
| 1858 | + // trace the dim through the reshape |
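| | + // Accumulate reshape-input dim sizes; once the running element count equals
| | + // biasShapeRelevantSize, that input dim lines up with the bias dim and must
| | + // have the same size. If the count overshoots, the reshape split the bias
| | + // dim and the pattern does not apply.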
| 1859 | + uint64_t acc = 1; |
| 1860 | + for (auto [idx, dimSize] : enumerate(reshapeInShape)) { |
| 1861 | + acc *= dimSize; |
| 1862 | + if (acc == biasShapeRelevantSize) { |
| 1863 | + if (dimSize != biasRankFixedShape[*biasDim]) { |
| 1864 | + return rewriter.notifyMatchFailure( |
| 1865 | + addOp, "bias shape is not compatible with reshape input"); |
| 1866 | + } |
| 1867 | + reshapeInBiasDim = idx; |
| 1868 | + break; |
| 1869 | + } |
| 1870 | + if (acc > biasShapeRelevantSize) { |
| 1871 | + return rewriter.notifyMatchFailure( |
| 1872 | + addOp, "bias shape is not compatible with reshape input"); |
| 1873 | + } |
| 1874 | + } |
| 1875 | + |
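| | + // New bias shape: the traced dim size followed by trailing 1s, so the
| | + // reshaped bias right-aligns with the layer-norm output and broadcasts
| | + // unidirectionally.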
| 1876 | + newBiasShape.push_back(reshapeInShape[reshapeInBiasDim]); |
| 1877 | + newBiasShape.append(reshapeInShape.size() - reshapeInBiasDim - 1, 1); |
| 1878 | + } |
1751 | 1879 |
|
1752 | 1880 | // Norms only support unidirectional broadcasting from bias to y
1753 | | - const auto yType = dyn_cast<ShapedType>(y.getType()); |
1754 | | - const auto addType = dyn_cast<ShapedType>(addOp.getType()); |
1755 | | - if (!yType || !addType || !yType.hasStaticShape() || |
1756 | | - !addType.hasStaticShape() || yType.getShape() != addType.getShape()) { |
1757 | | - return rewriter.notifyMatchFailure(addOp, "incompatible shapes"); |
| 1881 | + if (!reshapeOp && !areUnidirectionalBroadcastCompatible( |
| 1882 | + lnOp.getX().getType(), bias.getType())) { |
| 1883 | + return rewriter.notifyMatchFailure(addOp, |
| 1884 | + "layer norm and bias are not unidirectional broadcast compatible"); |
1758 | 1885 | } |
1759 | | - // We are fine. |
1760 | | - Value x = lnOp.getX(); |
1761 | | - Value scale = lnOp.getScale(); |
1762 | | - FloatAttr epsilon = lnOp.getEpsilonAttr(); |
1763 | | - int64_t axis = lnOp.getAxis(); |
1764 | | - LLVM_DEBUG(llvm::dbgs() << "LayerNorm from add, axis : " << axis << "\n"); |
1765 | | - |
1766 | | - // Replace |
1767 | | - MultiDialectBuilder<OnnxBuilder> create( |
1768 | | - rewriter, rewriter.getFusedLoc({lnOp.getLoc(), addOp->getLoc()})); |
1769 | | - Type xType = x.getType(); |
1770 | | - Value res; |
1771 | | - if constexpr (std::is_same<OP_TYPE, ONNXLayerNormalizationOp>::value) |
1772 | | - res = create.onnx.layerNorm(xType, x, scale, bias, axis, epsilon); |
1773 | | - else if constexpr (std::is_same<OP_TYPE, |
1774 | | - ONNXRMSLayerNormalizationOp>::value) |
1775 | | - res = create.onnx.RMSLayerNorm(xType, x, scale, bias, axis, epsilon); |
1776 | | - else |
1777 | | - llvm_unreachable("unsupported op"); |
1778 | | - rewriter.replaceOp(addOp, res); |
1779 | | - return success(); |
1780 | | - } |
1781 | 1886 |
|
1782 | | -private: |
1783 | | - LogicalResult reportFailure(std::string msg) const { |
1784 | | - // Can disable line below if not needed. |
1785 | | - LLVM_DEBUG(llvm::dbgs() << "LayerNorm failure:" << msg << "\n"); |
1786 | | - return failure(); |
| 1887 | + rewriter.moveOpAfter( |
| 1888 | + lnOp, addOp); // Make sure the bias (an operand of the add) dominates lnOp
| 1889 | + rewriter.setInsertionPoint(addOp); |
| 1890 | + if (reshapeOp) { |
| 1891 | + onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> create( |
| 1892 | + rewriter, reshapeOp->getLoc()); |
| 1893 | + const auto newShapeConst = create.onnx.constantInt64(newBiasShape); |
| 1894 | + bias = create.onnx.reshape( |
| 1895 | + RankedTensorType::get( |
| 1896 | + newBiasShape, cast<ShapedType>(bias.getType()).getElementType()), |
| 1897 | + bias, newShapeConst); |
| 1898 | + } |
| 1899 | + rewriter.modifyOpInPlace(lnOp, [&] { |
| 1900 | + lnOp.setOperand(/*bias*/ 2, bias); |
| 1901 | + lnOp->setLoc(rewriter.getFusedLoc({lnOp.getLoc(), addOp->getLoc()})); |
| 1902 | + }); |
| 1903 | + if (reshapeOp) { |
| 1904 | + rewriter.moveOpAfter(reshapeOp, lnOp); |
| 1905 | + rewriter.replaceOp(addOp, reshapeOp->getResult(0)); |
| 1906 | + } else { |
| 1907 | + rewriter.replaceOp(addOp, lnOp.getY()); |
| 1908 | + } |
| 1909 | + return success(); |
1787 | 1910 | } |
1788 | 1911 | }; |
1789 | 1912 |
|
@@ -1930,7 +2053,8 @@ struct RemoveInstanceNormPattern |
1930 | 2053 | rewriter, instanceNormOp.getLoc()); |
1931 | 2054 | int64_t axis = nonSpacialRank; |
1932 | 2055 | int64_t numInNorm = inputRank - axis; |
1933 | | - // Unsqueeze scale/bias from [C] to [C x 1 x 1 x ... x 1] with numInNorm 1s. |
| 2056 | + // Unsqueeze scale/bias from [C] to [C x 1 x 1 x ... x 1] with numInNorm |
| 2057 | + // 1s. |
1934 | 2058 | llvm::SmallVector<int64_t, 4> axesList, biasScaleShape; |
1935 | 2059 | biasScaleShape.emplace_back(C); |
1936 | 2060 | for (int64_t i = 1; i <= numInNorm; ++i) { |
|