@@ -1477,12 +1477,12 @@ struct ReorderElementwiseAndShapeOp final
   }
 };
 
-// Fuses batch normalization operation with convolution weight:
-// X = conv(input, weight)
+// Fuses batch normalization operation with convolution kernel:
+// X = conv(input, kernel.old)
 // Y = batch_norm_inference(X, ...)
 // into ->
-// X = conv(input, weight(new))
-// Y = add(X, broadcast_in_dim(Bias(new)))
+// X = conv(input, kernel.new)
+// Y = add(X, broadcast_in_dim(bias.new))
 //
 struct FuseConvolutionBatchNormalization final
     : OpRewritePattern<BatchNormInferenceOp> {
@@ -1498,55 +1498,57 @@ struct FuseConvolutionBatchNormalization final
     auto convOp = op.getOperand().getDefiningOp<ConvolutionOp>();
     if (!convOp) return failure();
 
-    auto convWeight = convOp.getRhs();
-    auto convWeightType = convWeight.getType();
-    auto convWeightShape = convWeightType.getShape();
+    auto convKernel = convOp.getRhs();
+    auto convKernelType = convKernel.getType();
+    auto convKernelShape = convKernelType.getShape();
 
     auto dimNumbers = convOp.getDimensionNumbers();
     if (dimNumbers.getInputBatchDimension() != 0 ||
         dimNumbers.getInputFeatureDimension() != 1 ||
         dimNumbers.getOutputBatchDimension() != 0 ||
         dimNumbers.getOutputFeatureDimension() != 1 ||
         dimNumbers.getKernelOutputFeatureDimension() != 0 ||
-        dimNumbers.getKernelInputFeatureDimension() != 1)
-      return rewriter.notifyMatchFailure(convOp,
-                                         "Only [b, f, ...]x[o, i, ...]->[b, f, "
-                                         "...] configuration is supported");
+        dimNumbers.getKernelInputFeatureDimension() != 1) {
+      constexpr StringLiteral msg =
+          "Only [b, f, ...]x[o, i, ...]->[b, f, ...] configuration is "
+          "supported";
+      return rewriter.notifyMatchFailure(convOp, msg);
+    }
 
     if (convOp.getFeatureGroupCount() > 1 || convOp.getBatchGroupCount() > 1)
      return rewriter.notifyMatchFailure(
          convOp, "feature or batch grouping is not supported");
 
-    if (bnOperandShape[bnFeatureIndex] != convWeightShape.front())
+    if (bnOperandShape[bnFeatureIndex] != convKernelShape.front())
       return failure();
 
-    DenseFPElementsAttr convWeightElems;
+    DenseFPElementsAttr convKernelElems;
     DenseFPElementsAttr scaleElems;
     DenseFPElementsAttr offsetElems;
     DenseFPElementsAttr meanElems;
     DenseFPElementsAttr varianceElems;
 
-    auto epsilon = op.getEpsilon();
+    const auto epsilon = op.getEpsilon();
 
-    if (!matchPattern(convWeight, m_Constant(&convWeightElems)))
+    if (!matchPattern(convKernel, m_Constant(&convKernelElems)))
       return rewriter.notifyMatchFailure(
-          op, "expected constant convolution weight");
+          op, "expected constant convolution kernel");
 
     if (!matchPattern(op.getScale(), m_Constant(&scaleElems)) ||
         !matchPattern(op.getOffset(), m_Constant(&offsetElems)) ||
         !matchPattern(op.getMean(), m_Constant(&meanElems)) ||
         !matchPattern(op.getVariance(), m_Constant(&varianceElems)))
       return failure();
 
-    const auto &convWeightSemantics =
-        cast<FloatType>(convWeightType.getElementType()).getFloatSemantics();
+    const auto &convKernelSemantics =
+        cast<FloatType>(convKernelType.getElementType()).getFloatSemantics();
 
-    // W(new) = W(old) * gamma * rsqrt(variance + epsilon)
-    // B(new) = (B(old) - mean) * rsqrt(variance + epsilon) * gamma + betta
+    // K.new = K.old * gamma * rsqrt(variance + epsilon)
+    // B.new = (B.old - mean) * rsqrt(variance + epsilon) * gamma + beta
     // where: gamma - scaling factor
-    //        betta - shifting factor
+    //        beta - shifting factor
     //        rsqrt - reciprocal square root function
-    //        W - weight
+    //        K - kernel (a.k.a. weight)
     //        B - bias
     //
     const SmallVector<double> multipliers = llvm::map_to_vector(
@@ -1558,22 +1560,22 @@ struct FuseConvolutionBatchNormalization final
           return rsqrt * scale.convertToDouble();
         });
 
-    SmallVector<APFloat> newWeight;
-    newWeight.reserve(convWeightType.getNumElements());
+    SmallVector<APFloat> newKernel;
+    newKernel.reserve(convKernelType.getNumElements());
 
     const size_t outFeatureTileSize =
-        computeProduct(convWeightShape.drop_front());
-    auto it = convWeightElems.begin();
+        computeProduct(convKernelShape.drop_front());
+    auto it = convKernelElems.begin();
     for (const auto &multiplier : multipliers) {
       for (size_t i = 0; i < outFeatureTileSize; ++i) {
         double v = (*it).convertToDouble() * multiplier;
         APFloat result(v);
         bool losesInfo;
         if (APFloat::opStatus::opInvalidOp ==
-            result.convert(convWeightSemantics, APFloat::rmNearestTiesToEven,
+            result.convert(convKernelSemantics, APFloat::rmNearestTiesToEven,
                            &losesInfo))
           return failure();
-        newWeight.push_back(result);
+        newKernel.push_back(result);
         ++it;
       }
     }
@@ -1591,26 +1593,26 @@ struct FuseConvolutionBatchNormalization final
 
       bool losesInfo;
      if (APFloat::opStatus::opInvalidOp ==
-          result.convert(convWeightSemantics, APFloat::rmNearestTiesToEven,
+          result.convert(convKernelSemantics, APFloat::rmNearestTiesToEven,
                          &losesInfo))
         return failure();
 
       biasValues.push_back(result);
     }
 
     rewriter.setInsertionPoint(op);
-    auto newConvWeight = rewriter.create<ConstantOp>(
-        convWeight.getLoc(), convWeightType,
-        DenseFPElementsAttr::get(convWeightType, newWeight));
+    auto newConvKernel = rewriter.create<ConstantOp>(
+        convKernel.getLoc(), convKernelType,
+        DenseFPElementsAttr::get(convKernelType, newKernel));
 
     // Keep old convolution as it might have other users
     auto newConvOp = rewriter.create<ConvolutionOp>(
         convOp.getLoc(), convOp->getResultTypes(),
-        ValueRange{convOp.getLhs(), newConvWeight}, convOp->getAttrs());
+        ValueRange{convOp.getLhs(), newConvKernel}, convOp->getAttrs());
 
     SmallVector<int64_t> biasShape{static_cast<int64_t>(biasValues.size())};
     auto biasType =
-        convWeightType.cloneWith(biasShape, convWeightType.getElementType());
+        convKernelType.cloneWith(biasShape, convKernelType.getElementType());
     auto bias = rewriter.create<ConstantOp>(
         op.getLoc(), biasType, DenseFPElementsAttr::get(biasType, biasValues));
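For intuition, the arithmetic this pattern folds can be checked in isolation. Below is a minimal standalone sketch, in plain C++ with made-up per-feature values, independent of the MLIR pattern and its APIs. It applies K.new = K.old * gamma * rsqrt(variance + epsilon) and B.new = (B.old - mean) * rsqrt(variance + epsilon) * gamma + beta per output feature, with B.old taken as zero because the matched convolution carries no bias of its own.

// Standalone numeric sketch of the fold above. All names, shapes, and
// values here are illustrative, not part of the MLIR pattern.
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  // A tiny [o=2, i=1, h=1, w=1] kernel: one element per output feature.
  std::vector<double> kernel = {2.0, -3.0};
  std::vector<double> gamma = {0.5, 1.5};   // batch-norm scale
  std::vector<double> beta = {0.1, -0.2};   // batch-norm offset
  std::vector<double> mean = {4.0, -1.0};
  std::vector<double> variance = {0.25, 1.0};
  const double epsilon = 1e-5;

  for (std::size_t f = 0; f < kernel.size(); ++f) {
    // multiplier = gamma * rsqrt(variance + epsilon); this plays the role
    // of one entry of `multipliers` in the pattern.
    double multiplier = gamma[f] / std::sqrt(variance[f] + epsilon);
    // K.new = K.old * multiplier; in the pattern, every kernel element of
    // output feature f is scaled by the same multiplier (the
    // `outFeatureTileSize` inner loop).
    double newKernel = kernel[f] * multiplier;
    // B.new = (B.old - mean) * multiplier + beta, with B.old == 0.
    double newBias = (0.0 - mean[f]) * multiplier + beta[f];
    std::printf("f=%zu  K.new=%f  B.new=%f\n", f, newKernel, newBias);
  }
  return 0;
}

Feeding any input x through conv(x, K.new) followed by the broadcast add of B.new gives the same result (up to rounding) as conv(x, K.old) followed by batch_norm_inference, which is exactly the equivalence the rewrite relies on.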