
Commit da6cff5

Fuse batch normalization into convolution weights
1 parent b1c1115 commit da6cff5

2 files changed: 211 additions & 0 deletions

stablehlo/testdata/bn_conv_fuse_float32.mlir

Lines changed: 60 additions & 0 deletions
Large diffs are not rendered by default.
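The testdata diff itself is not rendered here. Purely as an illustration of the shape of IR the new pattern targets (the function name, tensor shapes, and constant values below are invented for this sketch and are not taken from bn_conv_fuse_float32.mlir), a convolution feeding batch_norm_inference in the required [b, f, ...]x[o, i, ...]->[b, f, ...] layout looks roughly like:

// Constant weight and batch norm statistics; splat values are placeholders.
func.func @conv_bn(%input: tensor<1x3x8x8xf32>) -> tensor<1x2x6x6xf32> {
  %weight = stablehlo.constant dense<1.0> : tensor<2x3x3x3xf32>
  %scale = stablehlo.constant dense<2.0> : tensor<2xf32>
  %offset = stablehlo.constant dense<0.5> : tensor<2xf32>
  %mean = stablehlo.constant dense<1.0> : tensor<2xf32>
  %variance = stablehlo.constant dense<4.0> : tensor<2xf32>
  %conv = stablehlo.convolution(%input, %weight)
            dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1],
            window = {stride = [1, 1], pad = [[0, 0], [0, 0]]}
            {batch_group_count = 1 : i64, feature_group_count = 1 : i64}
          : (tensor<1x3x8x8xf32>, tensor<2x3x3x3xf32>) -> tensor<1x2x6x6xf32>
  %bn = "stablehlo.batch_norm_inference"(%conv, %scale, %offset, %mean, %variance)
          {epsilon = 1.000000e-05 : f32, feature_index = 1 : i64}
        : (tensor<1x2x6x6xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>)
          -> tensor<1x2x6x6xf32>
  return %bn : tensor<1x2x6x6xf32>
}

After the rewrite, the batch_norm_inference disappears: the convolution consumes a rescaled constant weight, and a folded per-feature bias constant is broadcast with broadcast_in_dim and added to the convolution result.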

stablehlo/transforms/StablehloAggressiveSimplification.cpp

Lines changed: 151 additions & 0 deletions
@@ -21,6 +21,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Builders.h"
@@ -1467,6 +1468,154 @@ struct ReorderElementwiseAndShapeOp final
   }
 };
 
+// Fuses a batch normalization operation into the convolution weights:
+//   X = conv(input, weight)
+//   Y = batch_norm_inference(X, ...)
+// is rewritten into
+//   X = conv(input, weight(new))
+//   Y = add(X, broadcast_in_dim(bias(new)))
+//
+struct FuseConvolutionBatchNormalization final
+    : OpRewritePattern<BatchNormInferenceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BatchNormInferenceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto bnOperandType = op.getOperand().getType();
+    auto bnOperandShape = bnOperandType.getShape();
+    auto bnResultType = op.getResult().getType();
+    uint64_t bnFeatureIndex = op.getFeatureIndex();
+
+    auto convOp = op.getOperand().getDefiningOp<ConvolutionOp>();
+    if (!convOp) return failure();
+
+    auto convWeight = convOp.getRhs();
+    auto convWeightType = convWeight.getType();
+    auto convWeightShape = convWeightType.getShape();
+
+    auto dimNumbers = convOp.getDimensionNumbers();
+    if (dimNumbers.getInputBatchDimension() != 0 ||
+        dimNumbers.getInputFeatureDimension() != 1 ||
+        dimNumbers.getOutputBatchDimension() != 0 ||
+        dimNumbers.getOutputFeatureDimension() != 1 ||
+        dimNumbers.getKernelOutputFeatureDimension() != 0 ||
+        dimNumbers.getKernelInputFeatureDimension() != 1)
+      return rewriter.notifyMatchFailure(convOp,
+                                         "Only [b, f, ...]x[o, i, ...]->[b, f, "
+                                         "...] configuration is supported");
+
+    if (convOp.getFeatureGroupCount() > 1 || convOp.getBatchGroupCount() > 1)
+      return rewriter.notifyMatchFailure(
+          convOp, "feature or batch grouping is not supported");
+
+    if (bnOperandShape[bnFeatureIndex] != convWeightShape.front())
+      return failure();
+
+    DenseFPElementsAttr convWeightElems;
+    DenseFPElementsAttr scaleElems;
+    DenseFPElementsAttr offsetElems;
+    DenseFPElementsAttr meanElems;
+    DenseFPElementsAttr varianceElems;
+
+    auto epsilon = op.getEpsilon();
+
+    if (!matchPattern(convWeight, m_Constant(&convWeightElems)))
+      return rewriter.notifyMatchFailure(
+          op, "expected constant convolution weight");
+
+    if (!matchPattern(op.getScale(), m_Constant(&scaleElems)) ||
+        !matchPattern(op.getOffset(), m_Constant(&offsetElems)) ||
+        !matchPattern(op.getMean(), m_Constant(&meanElems)) ||
+        !matchPattern(op.getVariance(), m_Constant(&varianceElems)))
+      return failure();
+
+    const auto &convWeightSemantics =
+        cast<FloatType>(convWeightType.getElementType()).getFloatSemantics();
+
+    // W(new) = W(old) * gamma * rsqrt(variance + epsilon)
+    // B(new) = (B(old) - mean) * rsqrt(variance + epsilon) * gamma + beta
+    // where: gamma - scaling factor
+    //        beta  - shifting factor
+    //        rsqrt - reciprocal square root function
+    //        W     - weight
+    //        B     - bias
+    //
+    const SmallVector<double> multipliers = llvm::map_to_vector(
+        llvm::zip_equal(varianceElems, scaleElems),
+        [&epsilon](const std::tuple<APFloat, APFloat> &pack) -> double {
+          const auto &[variance, scale] = pack;
+          auto varEps = (variance + epsilon).convertToDouble();
+          auto rsqrt = 1.0 / std::sqrt(varEps);
+          return rsqrt * scale.convertToDouble();
+        });
+
+    SmallVector<APFloat> newWeight;
+    newWeight.reserve(convWeightType.getNumElements());
+
+    const size_t outFeatureTileSize =
+        computeProduct(convWeightShape.drop_front());
+    auto it = convWeightElems.begin();
+    for (const auto &multiplier : multipliers) {
+      for (size_t i = 0; i < outFeatureTileSize; ++i) {
+        double v = (*it).convertToDouble() * multiplier;
+        APFloat result(v);
+        bool losesInfo;
+        if (APFloat::opStatus::opInvalidOp ==
+            result.convert(convWeightSemantics, APFloat::rmNearestTiesToEven,
+                           &losesInfo))
+          return failure();
+        newWeight.push_back(result);
+        ++it;
+      }
+    }
+
+    SmallVector<APFloat> biasValues;
+    biasValues.reserve(multipliers.size());
+
+    for (const auto &[off, multiplier, mean] :
+         llvm::zip_equal(offsetElems, multipliers, meanElems)) {
+      // The stablehlo convolution op doesn't have a built-in bias.
+      double convBias = 0;
+      double v = (convBias - mean.convertToDouble()) * multiplier +
+                 off.convertToDouble();
+      APFloat result(v);
+
+      bool losesInfo;
+      if (APFloat::opStatus::opInvalidOp ==
+          result.convert(convWeightSemantics, APFloat::rmNearestTiesToEven,
+                         &losesInfo))
+        return failure();
+
+      biasValues.push_back(result);
+    }
+
+    rewriter.setInsertionPoint(op);
+    auto newConvWeight = rewriter.create<ConstantOp>(
+        convWeight.getLoc(), convWeightType,
+        DenseFPElementsAttr::get(convWeightType, newWeight));
+
+    // Keep the old convolution since it might have other users.
+    auto newConvOp = rewriter.create<ConvolutionOp>(
+        convOp.getLoc(), convOp->getResultTypes(),
+        ValueRange{convOp.getLhs(), newConvWeight}, convOp->getAttrs());
+
+    SmallVector<int64_t> biasShape{static_cast<int64_t>(biasValues.size())};
+    auto biasType =
+        convWeightType.cloneWith(biasShape, convWeightType.getElementType());
+    auto bias = rewriter.create<ConstantOp>(
+        op.getLoc(), biasType, DenseFPElementsAttr::get(biasType, biasValues));
+
+    auto indices =
+        rewriter.getDenseI64ArrayAttr({static_cast<int64_t>(bnFeatureIndex)});
+    auto bcast = rewriter.create<BroadcastInDimOp>(op.getLoc(), bnResultType,
+                                                   bias, indices);
+    auto add = rewriter.create<AddOp>(op.getLoc(), newConvOp, bcast);
+
+    rewriter.replaceOp(op, add);
+    return success();
+  }
+};
+
 struct StablehloAggressiveSimplificationPass final
     : impl::StablehloAggressiveSimplificationPassBase<
           StablehloAggressiveSimplificationPass> {
@@ -1513,6 +1662,8 @@ void populateStablehloCanonicalizationPatterns(MLIRContext *context,
   patterns
       ->add<GetDimensionSizeOpCanon, DynamicBroadcastInDimOpNotActuallyDynamic,
             DynamicReshapeOpIsStatic, DynamicIotaIsStatic>(context);
+
+  patterns->add<FuseConvolutionBatchNormalization>(context);
 }
 
 }  // namespace stablehlo
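
For reference, the folded constants follow directly from the batch norm inference formula and the linearity of convolution in its weights. With per-output-feature scale, offset, mean, and variance (written gamma, beta, mu, sigma^2 below) and no pre-existing convolution bias (convBias = 0 in the code), the algebra is:

y = \gamma \cdot \frac{\mathrm{conv}(x, W) - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
  = \mathrm{conv}\!\left(x,\; W \cdot \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}\right)
    + \left(\beta - \mu \cdot \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}\right)

The first term is the convolution with the rescaled weight (the per-output-feature multipliers applied to each tile of W), and the second term is the new bias that the pattern materializes as a constant and broadcasts along the feature dimension before the final add.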
