 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Builders.h"
@@ -1467,6 +1468,154 @@ struct ReorderElementwiseAndShapeOp final |
   }
 };

+// Fuses a batch normalization operation into the weights of the preceding
+// convolution:
+//   X = conv(input, weight)
+//   Y = batch_norm_inference(X, ...)
+// into:
+//   X = conv(input, weight(new))
+//   Y = add(X, broadcast_in_dim(bias(new)))
+//
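+// For illustration only (the shapes and value names below are an arbitrary
+// example, not something the pattern requires), IR of roughly this form:
+//   %0 = stablehlo.convolution(%input, %weight) ... -> tensor<1x8x14x14xf32>
+//   %1 = stablehlo.batch_norm_inference %0, %scale, %offset, %mean, %variance
+//        {epsilon = 1.0e-5, feature_index = 1} -> tensor<1x8x14x14xf32>
+// is rewritten to:
+//   %w = stablehlo.constant ...  // weight folded with scale/variance
+//   %b = stablehlo.constant ...  // per-feature bias
+//   %2 = stablehlo.convolution(%input, %w) ... -> tensor<1x8x14x14xf32>
+//   %3 = stablehlo.broadcast_in_dim %b, dims = [1] -> tensor<1x8x14x14xf32>
+//   %4 = stablehlo.add %2, %3
+//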
+struct FuseConvolutionBatchNormalization final
+    : OpRewritePattern<BatchNormInferenceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BatchNormInferenceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto bnOperandType = op.getOperand().getType();
+    auto bnOperandShape = bnOperandType.getShape();
+    auto bnResultType = op.getResult().getType();
+    uint64_t bnFeatureIndex = op.getFeatureIndex();
+
+    auto convOp = op.getOperand().getDefiningOp<ConvolutionOp>();
+    if (!convOp) return failure();
+
+    auto convWeight = convOp.getRhs();
+    auto convWeightType = convWeight.getType();
+    auto convWeightShape = convWeightType.getShape();
+
+    auto dimNumbers = convOp.getDimensionNumbers();
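+    // Restrict to the [b, f, ...] x [o, i, ...] -> [b, f, ...] layout so that
+    // the kernel's output-feature dimension is outermost; the folding below
+    // relies on this to rescale contiguous slices of the weight.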
+    if (dimNumbers.getInputBatchDimension() != 0 ||
+        dimNumbers.getInputFeatureDimension() != 1 ||
+        dimNumbers.getOutputBatchDimension() != 0 ||
+        dimNumbers.getOutputFeatureDimension() != 1 ||
+        dimNumbers.getKernelOutputFeatureDimension() != 0 ||
+        dimNumbers.getKernelInputFeatureDimension() != 1)
+      return rewriter.notifyMatchFailure(convOp,
+                                         "Only [b, f, ...]x[o, i, ...]->[b, f, "
+                                         "...] configuration is supported");
+
+    if (convOp.getFeatureGroupCount() > 1 || convOp.getBatchGroupCount() > 1)
+      return rewriter.notifyMatchFailure(
+          convOp, "feature or batch grouping is not supported");
+
+    // The per-feature factors are folded into the kernel's output-feature
+    // dimension, so the batch norm must normalize along the convolution's
+    // output feature dimension (dimension 1 after the check above) and the
+    // feature counts must match.
+    if (bnFeatureIndex != 1 ||
+        bnOperandShape[bnFeatureIndex] != convWeightShape.front())
+      return failure();
+
+    DenseFPElementsAttr convWeightElems;
+    DenseFPElementsAttr scaleElems;
+    DenseFPElementsAttr offsetElems;
+    DenseFPElementsAttr meanElems;
+    DenseFPElementsAttr varianceElems;
+
+    auto epsilon = op.getEpsilon();
+
+    if (!matchPattern(convWeight, m_Constant(&convWeightElems)))
+      return rewriter.notifyMatchFailure(
+          op, "expected constant convolution weight");
+
+    if (!matchPattern(op.getScale(), m_Constant(&scaleElems)) ||
+        !matchPattern(op.getOffset(), m_Constant(&offsetElems)) ||
+        !matchPattern(op.getMean(), m_Constant(&meanElems)) ||
+        !matchPattern(op.getVariance(), m_Constant(&varianceElems)))
+      return failure();
+
+    const auto &convWeightSemantics =
+        cast<FloatType>(convWeightType.getElementType()).getFloatSemantics();
+
+    // W(new) = W(old) * gamma * rsqrt(variance + epsilon)
+    // B(new) = (B(old) - mean) * rsqrt(variance + epsilon) * gamma + beta
+    // where: gamma - scaling factor
+    //        beta  - shifting factor
+    //        rsqrt - reciprocal square root function
+    //        W     - weight
+    //        B     - bias
+    //
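+    // These formulas follow from the definition of batch_norm_inference
+    // applied to the convolution output, per output feature:
+    //   Y = (conv(X, W) + B - mean) * rsqrt(variance + epsilon) * gamma + beta
+    //     = conv(X, W * gamma * rsqrt(variance + epsilon))
+    //         + (B - mean) * rsqrt(variance + epsilon) * gamma + beta
+    // since scaling an output feature is equivalent to scaling the
+    // corresponding slice of the kernel.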
+    const SmallVector<double> multipliers = llvm::map_to_vector(
+        llvm::zip_equal(varianceElems, scaleElems),
+        [&epsilon](const std::tuple<APFloat, APFloat> &pack) -> double {
+          const auto &[variance, scale] = pack;
+          // Do the arithmetic in double: epsilon is always an f32 attribute,
+          // while the variance may use different float semantics.
+          double varEps =
+              variance.convertToDouble() + epsilon.convertToDouble();
+          double rsqrt = 1.0 / std::sqrt(varEps);
+          return rsqrt * scale.convertToDouble();
+        });
+
+    SmallVector<APFloat> newWeight;
+    newWeight.reserve(convWeightType.getNumElements());
+
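+    // Rescale the weight one output feature at a time. The elements are
+    // stored in row-major order with the output-feature dimension outermost,
+    // so each multiplier applies to one contiguous tile of
+    // `outFeatureTileSize` elements.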
+    const size_t outFeatureTileSize =
+        computeProduct(convWeightShape.drop_front());
+    auto it = convWeightElems.begin();
+    for (const auto &multiplier : multipliers) {
+      for (size_t i = 0; i < outFeatureTileSize; ++i) {
+        double v = (*it).convertToDouble() * multiplier;
+        APFloat result(v);
+        bool losesInfo;
+        if (APFloat::opStatus::opInvalidOp ==
+            result.convert(convWeightSemantics, APFloat::rmNearestTiesToEven,
+                           &losesInfo))
+          return failure();
+        newWeight.push_back(result);
+        ++it;
+      }
+    }
| 1571 | + |
| 1572 | + SmallVector<APFloat> biasValues; |
| 1573 | + biasValues.reserve(multipliers.size()); |
| 1574 | + |
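+    // Fold the normalization of the (implicitly zero) convolution bias into a
+    // per-feature bias that is added back after the convolution.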
+    for (const auto &[off, multiplier, mean] :
+         llvm::zip_equal(offsetElems, multipliers, meanElems)) {
+      // stablehlo.convolution has no built-in bias, so the pre-fusion bias is
+      // implicitly zero.
+      double convBias = 0;
+      double v = (convBias - mean.convertToDouble()) * multiplier +
+                 off.convertToDouble();
+      APFloat result(v);
+
+      bool losesInfo;
+      if (APFloat::opStatus::opInvalidOp ==
+          result.convert(convWeightSemantics, APFloat::rmNearestTiesToEven,
+                         &losesInfo))
+        return failure();
+
+      biasValues.push_back(result);
+    }
+
+    rewriter.setInsertionPoint(op);
+    auto newConvWeight = rewriter.create<ConstantOp>(
+        convWeight.getLoc(), convWeightType,
+        DenseFPElementsAttr::get(convWeightType, newWeight));
+
+    // Keep the original convolution in place: it may have other users, and if
+    // not it simply becomes dead.
+    auto newConvOp = rewriter.create<ConvolutionOp>(
+        convOp.getLoc(), convOp->getResultTypes(),
+        ValueRange{convOp.getLhs(), newConvWeight}, convOp->getAttrs());
+
+    SmallVector<int64_t> biasShape{static_cast<int64_t>(biasValues.size())};
+    auto biasType =
+        convWeightType.cloneWith(biasShape, convWeightType.getElementType());
+    auto bias = rewriter.create<ConstantOp>(
+        op.getLoc(), biasType, DenseFPElementsAttr::get(biasType, biasValues));
+
+    auto indices =
+        rewriter.getDenseI64ArrayAttr({static_cast<int64_t>(bnFeatureIndex)});
+    auto bcast = rewriter.create<BroadcastInDimOp>(op.getLoc(), bnResultType,
+                                                   bias, indices);
+    auto add = rewriter.create<AddOp>(op.getLoc(), newConvOp, bcast);
+
+    rewriter.replaceOp(op, add);
+    return success();
+  }
+};
+
 struct StablehloAggressiveSimplificationPass final
     : impl::StablehloAggressiveSimplificationPassBase<
           StablehloAggressiveSimplificationPass> {
@@ -1513,6 +1662,8 @@ void populateStablehloCanonicalizationPatterns(MLIRContext *context, |
   patterns
       ->add<GetDimensionSizeOpCanon, DynamicBroadcastInDimOpNotActuallyDynamic,
             DynamicReshapeOpIsStatic, DynamicIotaIsStatic>(context);
+
+  patterns->add<FuseConvolutionBatchNormalization>(context);
 }

 } // namespace stablehlo