
Commit 2129d5d

Merge pull request #416 from Xilinx/jrickert.fuse_norms
Fuse Mul/Add into norms
2 parents 61843f1 + edd033d commit 2129d5d
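
At the ONNX dialect level, this change teaches canonicalization to fold a trailing elementwise Mul into the scale operand (and a trailing Add into the bias operand) of onnx.LayerNormalization and onnx.RMSLayerNormalization, optionally looking through a Reshape between the norm and the elementwise op. A minimal sketch of the Mul case, in the same shorthand as the pattern's match comment in the diff below (types omitted; the constants %ones and %gamma are illustrative, with %ones a splat 1.0 scale and no bias on the norm):

// Before: unit-scale norm followed by an elementwise Mul
%noBias = "onnx.NoValue"()
%y, %mean, %invStdDev = "onnx.LayerNormalization"(%x, %ones, %noBias)
    {axis = 2 : si64, epsilon = 9.994E-6 : f32, stash_type = 1 : si64}
%out = "onnx.Mul"(%y, %gamma)

// After: %gamma becomes the norm's scale and the Mul disappears
%out, %mean, %invStdDev = "onnx.LayerNormalization"(%x, %gamma, %noBias)
    {axis = 2 : si64, epsilon = 9.994E-6 : f32, stash_type = 1 : si64}

The Add case is analogous: the added constant becomes the bias operand, provided the norm's existing bias is none or a constant zero.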

File tree

3 files changed: +520 -62 lines changed

src/Dialect/ONNX/ONNXOps/Canonicalize.cpp

Lines changed: 260 additions & 47 deletions
@@ -17,7 +17,9 @@
 //===----------------------------------------------------------------------===//

 #include <math.h>
+#include <numeric>

+#include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -1633,66 +1635,271 @@ struct RecomposeConcatPattern : public OpRewritePattern<ONNXConcatOp> {
 // =============================================================================
 // Rewrite pattern LayerNormalization
 // =============================================================================
+namespace {

-template <typename OP_TYPE>
-struct PropagateBiasIntoLayerNormRewritePattern
-    : public OpRewritePattern<ONNXAddOp> {
-  using OpRewritePattern<ONNXAddOp>::OpRewritePattern;
+// Checks if B is unidiretional broadcastable to A. Requires static shapes
+[[nodiscard]] bool areUnidirectionalBroadcastCompatible(Type a, Type b) {
+  auto aShaped = dyn_cast<ShapedType>(a);
+  auto bShaped = dyn_cast<ShapedType>(b);
+  if (!aShaped || !bShaped || !aShaped.hasStaticShape() ||
+      !bShaped.hasStaticShape()) {
+    return false;
+  }
+  SmallVector<int64_t> broadcastedShape;
+  if (!OpTrait::util::getBroadcastedShape(
+          aShaped.getShape(), bShaped.getShape(), broadcastedShape)) {
+    return false;
+  }
+  // For unidirectional broadcasting, a and the resulting shape need to match
+  return aShaped.getShape() == ArrayRef<int64_t>(broadcastedShape);
+}

-  PropagateBiasIntoLayerNormRewritePattern(MLIRContext *context)
-      : OpRewritePattern(context) {}
+[[nodiscard]] bool isValueNoneOrConstZero(Value value) {
+  if (!value) {
+    return false;
+  }
+  if (isNoneValue(value)) {
+    return true;
+  }
+  auto elementsAttr = getElementAttributeFromONNXValue(value);
+  if (!elementsAttr) {
+    return false;
+  }
+  if (!elementsAttr.isSplat()) {
+    return false;
+  }
+  return elementsAttr.template getSplatValue<APFloat>().isZero();
+}
+
+template <typename LN_TYPE, typename MATCH_OP_TYPE,
+    size_t OPERAND_TO_MODIFY_INDEX>
+struct PropagateBiasOrScaleIntoLayerNormRewritePatternBase
+    : public OpRewritePattern<MATCH_OP_TYPE> {
+  using OpRewritePattern<MATCH_OP_TYPE>::OpRewritePattern;
+
+  static_assert(std::is_same_v<MATCH_OP_TYPE, ONNXAddOp> ||
+                    std::is_same_v<MATCH_OP_TYPE, ONNXMulOp>,
+      "MATCH_OP_TYPE must be ONNXAddOp or ONNXMulOp");
+
+  [[nodiscard]] virtual bool doExisitingScaleAndBiasAllowFusion(
+      LN_TYPE lnOp) const = 0;
+
+  FailureOr<SmallVector<int64_t>> verifyAndCalculateNewReshapeShapes(
+      Operation *reshapeOp, MATCH_OP_TYPE matchOp, PatternRewriter &rewriter,
+      Value scaleOrBias) const {
+    // if we have a reshape, check that the add/mul is not changing the shape
+    // by broadcasting
+    auto reshapeResultType =
+        dyn_cast<ShapedType>(reshapeOp->getResult(0).getType());
+    auto addOrMulResultType =
+        dyn_cast<ShapedType>(matchOp->getResult(0).getType());
+    if (!reshapeResultType || !addOrMulResultType ||
+        !reshapeResultType.hasStaticShape() ||
+        !addOrMulResultType.hasStaticShape() ||
+        reshapeResultType.getShape() != addOrMulResultType.getShape()) {
+      return rewriter.notifyMatchFailure(
+          matchOp, "incompatible shapes, add is broadcasting");
+    }
+    // Check that the bias/scale is only on a single dimension, that is not
+    // affected by the reshape. The bias/scale could be multi-dimentional, but
+    // this increases the complexity and was not seen in models
+    auto scaleOrBiasType = dyn_cast<ShapedType>(scaleOrBias.getType());
+    if (!scaleOrBiasType || !scaleOrBiasType.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(
+          matchOp, "bias/scale has not a static shape");
+    }
+
+    SmallVector<int64_t> biasOrScaleRankFixedShape;
+    biasOrScaleRankFixedShape.append(
+        addOrMulResultType.getRank() - scaleOrBiasType.getRank(), 1);
+    biasOrScaleRankFixedShape.append(
+        scaleOrBiasType.getShape().begin(), scaleOrBiasType.getShape().end());
+
+    // biasOrScaleRankFixedShape should have exactly one dimension that is not
+    // one
+    std::optional<int64_t> afterReshapeComputationDim;
+    for (auto [idx, dimSize] : enumerate(biasOrScaleRankFixedShape)) {
+      if (dimSize != 1) {
+        if (afterReshapeComputationDim) {
+          return rewriter.notifyMatchFailure(
+              matchOp, "scale/bias has more than one non-one dimension");
+        }
+        afterReshapeComputationDim = idx;
+      }
+    }
+    if (!afterReshapeComputationDim) {
+      return rewriter.notifyMatchFailure(
+          matchOp, "scale/bias has no non-one dimension");
+    }
+
+    const auto shapeIncludingComputationDim =
+        ArrayRef<int64_t>(reshapeResultType.getShape())
+            .slice(0, *afterReshapeComputationDim + 1);
+    const uint64_t computationRelevantSize =
+        std::accumulate(shapeIncludingComputationDim.begin(),
+            shapeIncludingComputationDim.end(), 1, std::multiplies<uint64_t>());
+
+    // The bias/scale dim should be not affected by the reshape. We need to
+    // map it back through it.
+    size_t reshapeInComputationDim;
+    auto reshapeInType =
+        dyn_cast<ShapedType>(reshapeOp->getOperand(0).getType());
+    if (!reshapeInType || !reshapeInType.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(
+          matchOp, "reshape input has not a static shape");
+    }
+    const auto reshapeInShape = reshapeInType.getShape();
+
+    // trace the dim through the reshape
+    uint64_t acc = 1;
+    for (auto [idx, dimSize] : enumerate(reshapeInShape)) {
+      acc *= dimSize;
+      if (acc == computationRelevantSize) {
+        if (dimSize != biasOrScaleRankFixedShape[*afterReshapeComputationDim]) {
+          return rewriter.notifyMatchFailure(
+              matchOp, "bias/scale shape is not compatible with reshape input");
+        }
+        reshapeInComputationDim = idx;
+        break;
+      }
+      if (acc > computationRelevantSize) {
+        return rewriter.notifyMatchFailure(
+            matchOp, "bias/scale shape is not compatible with reshape input");
+      }
+    }
+    SmallVector<int64_t> newScaleOrBiasShape;
+    newScaleOrBiasShape.push_back(reshapeInShape[reshapeInComputationDim]);
+    newScaleOrBiasShape.append(
+        reshapeInShape.size() - reshapeInComputationDim - 1, 1);
+    return newScaleOrBiasShape;
+  }

   LogicalResult matchAndRewrite(
-      ONNXAddOp addOp, PatternRewriter &rewriter) const final {
+      MATCH_OP_TYPE matchOp, PatternRewriter &rewriter) const final {
+    PatternRewriter::InsertionGuard guard(rewriter);
     using namespace onnx_mlir;
-    Value y, bias;
-    Operation *yLayerNormOp;
-    Operation *ywbAddOp = addOp.getOperation();
+    Value y, scaleOrBias;
+    Operation *yLayerNormOp = nullptr;
+    Operation *reshapeOp = nullptr;
+    SmallVector<int64_t> newScaleOrBiasShape; // only used if there is a reshape
+
     // Match
     // %noBias = "onnx.NoValue"()
     // %y, %mean, %invStdDev = "onnx.LayerNormalization"(%x, %scale, %noBias)
     //     {axis = 2 : si64, epsilon = 9.994E-6 : f32, stash_type = 1 : si64}
-    // %yBias = "onnx.Add"(%y, %bias)
-    if (!onnx_mlir::operandOfOpDefinedBy<OP_TYPE>(
-            yLayerNormOp, ywbAddOp, y, bias, 0) &&
-        !onnx_mlir::operandOfOpDefinedBy<OP_TYPE>(
-            yLayerNormOp, ywbAddOp, bias, y, 1))
-      return reportFailure("missing y, layer norm op");
+    // optional reshape between norm and add
+    // %yBias = "onnx.Add/onnx.Mul"(%y, %scaleOrBias)
+
+    if (onnx_mlir::operandOfOpDefinedBy<ONNXReshapeOp>(
+            reshapeOp, matchOp, y, scaleOrBias, 0) ||
+        onnx_mlir::operandOfOpDefinedBy<ONNXReshapeOp>(
+            reshapeOp, matchOp, scaleOrBias, y, 1)) {
+      yLayerNormOp = reshapeOp->getOperand(0).getDefiningOp<LN_TYPE>();
+      if (!yLayerNormOp) {
+        return rewriter.notifyMatchFailure(
+            reshapeOp, "reshape op does not have a layer norm as input");
+      }
+      if (!reshapeOp->hasOneUse()) {
+        return rewriter.notifyMatchFailure(
+            reshapeOp, "reshape op does not have a single use");
+      }
+    } else {
+      if (!onnx_mlir::operandOfOpDefinedBy<LN_TYPE>(
+              yLayerNormOp, matchOp, y, scaleOrBias, 0) &&
+          !onnx_mlir::operandOfOpDefinedBy<LN_TYPE>(
+              yLayerNormOp, matchOp, scaleOrBias, y, 1))
+        return rewriter.notifyMatchFailure(matchOp, "missing y, layer norm op");
+    }
+
     // Study layer norm op; make sure its used only one and that bias is not
     // used.
-    if (!yLayerNormOp->hasOneUse())
-      return reportFailure("y/layer norm has too many uses");
-    auto lnOp = mlir::cast<OP_TYPE>(yLayerNormOp);
-    if (!onnx_mlir::isNoneValue(lnOp.getB()))
-      return reportFailure("layer norm already has a bias");
-    // We are fine.
-    Value x = lnOp.getX();
-    Value scale = lnOp.getScale();
-    FloatAttr epsilon = lnOp.getEpsilonAttr();
-    int64_t axis = lnOp.getAxis();
-    LLVM_DEBUG(llvm::dbgs() << "LayerNorm from add, axis : " << axis << "\n");
-
-    // Replace
-    MultiDialectBuilder<OnnxBuilder> create(
-        rewriter, rewriter.getFusedLoc({lnOp.getLoc(), addOp->getLoc()}));
-    Type xType = x.getType();
-    Value res;
-    if constexpr (std::is_same<OP_TYPE, ONNXLayerNormalizationOp>::value)
-      res = create.onnx.layerNorm(xType, x, scale, bias, axis, epsilon);
-    else if constexpr (std::is_same<OP_TYPE,
-                           ONNXRMSLayerNormalizationOp>::value)
-      res = create.onnx.RMSLayerNorm(xType, x, scale, bias, axis, epsilon);
-    else
-      llvm_unreachable("unsupported op");
-    rewriter.replaceOp(addOp, res);
+    assert(yLayerNormOp && "yLayerNormOp should not be null");
+    if (!yLayerNormOp->hasOneUse()) {
+      return rewriter.notifyMatchFailure(
+          yLayerNormOp, "y/layer norm has too many uses");
+    }
+    auto lnOp = mlir::cast<LN_TYPE>(yLayerNormOp);
+    if (!doExisitingScaleAndBiasAllowFusion(lnOp))
+      return rewriter.notifyMatchFailure(
+          lnOp, "existing scale and bias do not allow fusion");
+
+    if (reshapeOp) {
+      auto newShape = verifyAndCalculateNewReshapeShapes(
+          reshapeOp, matchOp, rewriter, scaleOrBias);
+      if (failed(newShape)) {
+        return failure();
+      }
+      newScaleOrBiasShape = std::move(*newShape);
+    }
+
+    // Norms only support unidirectional broadcasting to x
+    if (!reshapeOp && !areUnidirectionalBroadcastCompatible(
+                          lnOp.getX().getType(), scaleOrBias.getType())) {
+      return rewriter.notifyMatchFailure(matchOp,
+          "layer norm and bias/scale are not unidirectional broadcast "
+          "compatible");
+    }
+
+    rewriter.moveOpAfter(
+        lnOp, matchOp); // Make sure we can use the const of the mul
+    rewriter.setInsertionPoint(matchOp);
+    if (reshapeOp) {
+      onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> create(
+          rewriter, reshapeOp->getLoc());
+      const auto newShapeConst = create.onnx.constantInt64(newScaleOrBiasShape);
+      scaleOrBias = create.onnx.reshape(
+          RankedTensorType::get(newScaleOrBiasShape,
+              cast<ShapedType>(scaleOrBias.getType()).getElementType()),
+          scaleOrBias, newShapeConst);
+    }
+    rewriter.modifyOpInPlace(lnOp, [&] {
+      lnOp.setOperand(OPERAND_TO_MODIFY_INDEX, scaleOrBias);
+      lnOp->setLoc(rewriter.getFusedLoc({lnOp.getLoc(), matchOp->getLoc()}));
+    });
+    if (reshapeOp) {
+      rewriter.moveOpAfter(reshapeOp, lnOp);
+      rewriter.replaceOp(matchOp, reshapeOp->getResult(0));
+    } else {
+      rewriter.replaceOp(matchOp, lnOp.getY());
+    }
     return success();
   }
+};

-private:
-  LogicalResult reportFailure(std::string msg) const {
-    // Can disable line below if not needed.
-    LLVM_DEBUG(llvm::dbgs() << "LayerNorm failure:" << msg << "\n");
-    return failure();
+} // namespace
+
+template <typename LN_TYPE>
+struct PropagateScaleIntoLayerNormPattern
+    : public PropagateBiasOrScaleIntoLayerNormRewritePatternBase<LN_TYPE,
+          ONNXMulOp, /*scale*/ 1> {
+  using PropagateBiasOrScaleIntoLayerNormRewritePatternBase<LN_TYPE, ONNXMulOp,
+      /*scale*/ 1>::PropagateBiasOrScaleIntoLayerNormRewritePatternBase;
+
+  bool doExisitingScaleAndBiasAllowFusion(LN_TYPE lnOp) const override {
+    if (!isValueNoneOrConstZero(lnOp.getB())) {
+      return false;
+    }
+
+    const auto elementsAttr = getElementAttributeFromONNXValue(lnOp.getScale());
+    if (!elementsAttr) {
+      return false;
+    }
+    if (!elementsAttr.isSplat()) {
+      return false;
+    }
+    return elementsAttr.template getSplatValue<APFloat>().isExactlyValue(1.0);
+  }
+};
+
+template <typename LN_TYPE>
+struct PropagateBiasIntoLayerNormRewritePattern
+    : public PropagateBiasOrScaleIntoLayerNormRewritePatternBase<LN_TYPE,
+          ONNXAddOp, /*bias*/ 2> {
+  using PropagateBiasOrScaleIntoLayerNormRewritePatternBase<LN_TYPE, ONNXAddOp,
+      /*bias*/ 2>::PropagateBiasOrScaleIntoLayerNormRewritePatternBase;
+
+  bool doExisitingScaleAndBiasAllowFusion(LN_TYPE lnOp) const override {
+    return isValueNoneOrConstZero(lnOp.getB());
   }
 };

@@ -1839,7 +2046,8 @@ struct RemoveInstanceNormPattern
         rewriter, instanceNormOp.getLoc());
     int64_t axis = nonSpacialRank;
     int64_t numInNorm = inputRank - axis;
-    // Unsqueeze scale/bias from [C] to [C x 1 x 1 x ... x 1] with numInNorm 1s.
+    // Unsqueeze scale/bias from [C] to [C x 1 x 1 x ... x 1] with numInNorm
+    // 1s.
     llvm::SmallVector<int64_t, 4> axesList, biasScaleShape;
     biasScaleShape.emplace_back(C);
     for (int64_t i = 1; i <= numInNorm; ++i) {
@@ -2189,6 +2397,11 @@ void ONNXAddOp::getCanonicalizationPatterns(
   results.insert<FuseAddConvNullBiasPattern>(context);
   results.insert<BinaryOpBroadcastAxisPattern<ONNXAddOp>>(context);
   results.insert<PropagateScalarConstantExpandPattern<ONNXAddOp>>(context);
+  results.insert<PropagateScaleIntoLayerNormPattern<ONNXLayerNormalizationOp>>(
+      context);
+  results
+      .insert<PropagateScaleIntoLayerNormPattern<ONNXRMSLayerNormalizationOp>>(
+          context);
   results.insert<
       PropagateBiasIntoLayerNormRewritePattern<ONNXLayerNormalizationOp>>(
           context);
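
Canonicalization patterns like these are usually exercised through FileCheck-based lit tests; a hypothetical sketch for the Mul-into-scale case (the RUN line, function name, shapes, and exact op syntax are illustrative, not taken from this commit) could look like:

// RUN: onnx-mlir-opt --canonicalize %s | FileCheck %s

func.func @fuse_mul_into_layernorm(%x: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
  %none = "onnx.NoValue"() {value} : () -> none
  %ones = onnx.Constant dense<1.0> : tensor<8xf32>
  %gamma = onnx.Constant dense<0.5> : tensor<8xf32>
  %y, %mean, %inv = "onnx.LayerNormalization"(%x, %ones, %none)
      {axis = 2 : si64, epsilon = 9.994E-6 : f32, stash_type = 1 : si64}
      : (tensor<2x4x8xf32>, tensor<8xf32>, none) -> (tensor<2x4x8xf32>, none, none)
  %out = "onnx.Mul"(%y, %gamma) : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
  return %out : tensor<2x4x8xf32>
}
// CHECK-LABEL: func.func @fuse_mul_into_layernorm
// CHECK:     "onnx.LayerNormalization"
// CHECK-NOT: "onnx.Mul"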
