
Commit e70f089

Fuse batch normalization into convolution weights
1 parent 4598975 commit e70f089

File tree

2 files changed: +211 -0 lines changed

stablehlo/testdata/bn_conv_fuse_float32.mlir

Lines changed: 60 additions & 0 deletions
Large diffs are not rendered by default.
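
The test data itself is not rendered in this view. As a rough, hypothetical sketch (the function name, shapes, constant values, and epsilon below are assumptions, not the actual contents of bn_conv_fuse_float32.mlir), the pattern added in this commit targets StableHLO IR of this form, where a convolution with a constant weight feeds a batch_norm_inference:

// Hypothetical input IR, for illustration only.
func.func @bn_conv_fuse(%input: tensor<1x3x8x8xf32>) -> tensor<1x4x6x6xf32> {
  %weight = stablehlo.constant dense<1.0> : tensor<4x3x3x3xf32>
  %scale = stablehlo.constant dense<2.0> : tensor<4xf32>
  %offset = stablehlo.constant dense<0.5> : tensor<4xf32>
  %mean = stablehlo.constant dense<1.0> : tensor<4xf32>
  %variance = stablehlo.constant dense<4.0> : tensor<4xf32>
  // [b, f, ...]x[o, i, ...]->[b, f, ...] layout, no feature/batch grouping,
  // as required by the pattern below.
  %conv = stablehlo.convolution(%input, %weight)
      dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1],
      window = {stride = [1, 1], pad = [[0, 0], [0, 0]]}
      {batch_group_count = 1 : i64, feature_group_count = 1 : i64}
      : (tensor<1x3x8x8xf32>, tensor<4x3x3x3xf32>) -> tensor<1x4x6x6xf32>
  %result = "stablehlo.batch_norm_inference"(%conv, %scale, %offset, %mean, %variance)
      {epsilon = 1.000000e-05 : f32, feature_index = 1 : i64}
      : (tensor<1x4x6x6xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>)
      -> tensor<1x4x6x6xf32>
  return %result : tensor<1x4x6x6xf32>
}

After simplification, the batch_norm_inference should fold away: the convolution is re-created with per-output-feature rescaled constant weights, and its result is added to a broadcast_in_dim of the folded per-feature bias, as implemented in the pattern below.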

stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp

Lines changed: 151 additions & 0 deletions
@@ -23,6 +23,7 @@
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Builders.h"
@@ -1476,6 +1477,154 @@ struct ReorderElementwiseAndShapeOp final
   }
 };
 
+// Fuses batch normalization operation with convolution weight:
+// X = conv(input, weight)
+// Y = batch_norm_inference(X, ...)
+// into ->
+// X = conv(input, weight(new))
+// Y = add(X, broadcast_in_dim(Bias(new)))
+//
+struct FuseConvolutionBatchNormalization final
+    : OpRewritePattern<BatchNormInferenceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BatchNormInferenceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto bnOperandType = op.getOperand().getType();
+    auto bnOperandShape = bnOperandType.getShape();
+    auto bnResultType = op.getResult().getType();
+    uint64_t bnFeatureIndex = op.getFeatureIndex();
+
+    auto convOp = op.getOperand().getDefiningOp<ConvolutionOp>();
+    if (!convOp) return failure();
+
+    auto convWeight = convOp.getRhs();
+    auto convWeightType = convWeight.getType();
+    auto convWeightShape = convWeightType.getShape();
+
+    auto dimNumbers = convOp.getDimensionNumbers();
+    if (dimNumbers.getInputBatchDimension() != 0 ||
+        dimNumbers.getInputFeatureDimension() != 1 ||
+        dimNumbers.getOutputBatchDimension() != 0 ||
+        dimNumbers.getOutputFeatureDimension() != 1 ||
+        dimNumbers.getKernelOutputFeatureDimension() != 0 ||
+        dimNumbers.getKernelInputFeatureDimension() != 1)
+      return rewriter.notifyMatchFailure(convOp,
+                                         "Only [b, f, ...]x[o, i, ...]->[b, f, "
+                                         "...] configuration is supported");
+
+    if (convOp.getFeatureGroupCount() > 1 || convOp.getBatchGroupCount() > 1)
+      return rewriter.notifyMatchFailure(
+          convOp, "feature or batch grouping is not supported");
+
+    if (bnOperandShape[bnFeatureIndex] != convWeightShape.front())
+      return failure();
+
+    DenseFPElementsAttr convWeightElems;
+    DenseFPElementsAttr scaleElems;
+    DenseFPElementsAttr offsetElems;
+    DenseFPElementsAttr meanElems;
+    DenseFPElementsAttr varianceElems;
+
+    auto epsilon = op.getEpsilon();
+
+    if (!matchPattern(convWeight, m_Constant(&convWeightElems)))
+      return rewriter.notifyMatchFailure(
+          op, "expected constant convolution weight");
+
+    if (!matchPattern(op.getScale(), m_Constant(&scaleElems)) ||
+        !matchPattern(op.getOffset(), m_Constant(&offsetElems)) ||
+        !matchPattern(op.getMean(), m_Constant(&meanElems)) ||
+        !matchPattern(op.getVariance(), m_Constant(&varianceElems)))
+      return failure();
+
+    const auto &convWeightSemantics =
+        cast<FloatType>(convWeightType.getElementType()).getFloatSemantics();
+
+    // W(new) = W(old) * gamma * rsqrt(variance + epsilon)
+    // B(new) = (B(old) - mean) * rsqrt(variance + epsilon) * gamma + beta
+    // where: gamma - scaling factor
+    //        beta  - shifting factor
+    //        rsqrt - reciprocal square root function
+    //        W     - weight
+    //        B     - bias
+    //
+    const SmallVector<double> multipliers = llvm::map_to_vector(
+        llvm::zip_equal(varianceElems, scaleElems),
+        [&epsilon](const std::tuple<APFloat, APFloat> &pack) -> double {
+          const auto &[variance, scale] = pack;
+          auto varEps = (variance + epsilon).convertToDouble();
+          auto rsqrt = 1.0 / std::sqrt(varEps);
+          return rsqrt * scale.convertToDouble();
+        });
+
+    SmallVector<APFloat> newWeight;
+    newWeight.reserve(convWeightType.getNumElements());
+
+    const size_t outFeatureTileSize =
+        computeProduct(convWeightShape.drop_front());
+    auto it = convWeightElems.begin();
+    for (const auto &multiplier : multipliers) {
+      for (size_t i = 0; i < outFeatureTileSize; ++i) {
+        double v = (*it).convertToDouble() * multiplier;
+        APFloat result(v);
+        bool losesInfo;
+        if (APFloat::opStatus::opInvalidOp ==
+            result.convert(convWeightSemantics, APFloat::rmNearestTiesToEven,
+                           &losesInfo))
+          return failure();
+        newWeight.push_back(result);
+        ++it;
+      }
+    }
+
+    SmallVector<APFloat> biasValues;
+    biasValues.reserve(multipliers.size());
+
+    for (const auto &[off, multiplier, mean] :
+         llvm::zip_equal(offsetElems, multipliers, meanElems)) {
+      // stablehlo convolution operation doesn't have a builtin bias
+      double convBias = 0;
+      double v = (convBias - mean.convertToDouble()) * multiplier +
+                 off.convertToDouble();
+      APFloat result(v);
+
+      bool losesInfo;
+      if (APFloat::opStatus::opInvalidOp ==
+          result.convert(convWeightSemantics, APFloat::rmNearestTiesToEven,
+                         &losesInfo))
+        return failure();
+
+      biasValues.push_back(result);
+    }
+
+    rewriter.setInsertionPoint(op);
+    auto newConvWeight = rewriter.create<ConstantOp>(
+        convWeight.getLoc(), convWeightType,
+        DenseFPElementsAttr::get(convWeightType, newWeight));
+
+    // Keep old convolution as it might have other users
+    auto newConvOp = rewriter.create<ConvolutionOp>(
+        convOp.getLoc(), convOp->getResultTypes(),
+        ValueRange{convOp.getLhs(), newConvWeight}, convOp->getAttrs());
+
+    SmallVector<int64_t> biasShape{static_cast<int64_t>(biasValues.size())};
+    auto biasType =
+        convWeightType.cloneWith(biasShape, convWeightType.getElementType());
+    auto bias = rewriter.create<ConstantOp>(
+        op.getLoc(), biasType, DenseFPElementsAttr::get(biasType, biasValues));
+
+    auto indices =
+        rewriter.getDenseI64ArrayAttr({static_cast<int64_t>(bnFeatureIndex)});
+    auto bcast = rewriter.create<BroadcastInDimOp>(op.getLoc(), bnResultType,
+                                                   bias, indices);
+    auto add = rewriter.create<AddOp>(op.getLoc(), newConvOp, bcast);
+
+    rewriter.replaceOp(op, add);
+    return success();
+  }
+};
+
 struct StablehloAggressiveSimplificationPass final
     : impl::StablehloAggressiveSimplificationPassBase<
           StablehloAggressiveSimplificationPass> {
@@ -1526,6 +1675,8 @@ void populateStablehloCanonicalizationPatterns(MLIRContext *context,
   patterns
       ->add<GetDimensionSizeOpCanon, DynamicBroadcastInDimOpNotActuallyDynamic,
             DynamicReshapeOpIsStatic, DynamicIotaIsStatic>(context);
+
+  patterns->add<FuseConvolutionBatchNormalization>(context);
 }
 
 std::unique_ptr<Pass> createStablehloAggressiveSimplificationPass(
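
As a side note (not part of the diff), the W(new)/B(new) formulas in the pattern's comment follow from applying the batch-norm inference equation to the convolution output and using the fact that convolution is linear in its weight. Writing gamma for scale, beta for offset, mu for mean, and sigma^2 for variance (this notation is mine, chosen to match the comment's symbols), per output feature:

\begin{aligned}
y &= \gamma \cdot \frac{\mathrm{conv}(x, W) - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta
   = \mathrm{conv}\!\left(x,\; W \cdot \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}\right)
     + \left(\beta - \mu \cdot \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}\right), \\
W_{\mathrm{new}} &= W \cdot \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}},
\qquad
B_{\mathrm{new}} = (0 - \mu) \cdot \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}} + \beta,
\end{aligned}

where the 0 stands for the convolution's own bias, which stablehlo.convolution does not have. This is exactly what the pattern computes element by element: multiplier = scale * rsqrt(variance + epsilon), each weight tile belonging to an output feature is scaled by its multiplier, and the bias becomes (0 - mean) * multiplier + offset.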
