@@ -1468,12 +1468,12 @@ struct ReorderElementwiseAndShapeOp final
   }
 };
 
-// Fuses batch normalization operation with convolution weight:
-// X = conv(input, weight)
+// Fuses batch normalization operation with convolution kernel:
+// X = conv(input, kernel.old)
 // Y = batch_norm_inference(X, ...)
 // into ->
-// X = conv(input, weight(new))
-// Y = add(X, broadcast_in_dim(Bias(new)))
+// X = conv(input, kernel.new)
+// Y = add(X, broadcast_in_dim(bias.new))
 //
 struct FuseConvolutionBatchNormalization final
     : OpRewritePattern<BatchNormInferenceOp> {
@@ -1489,55 +1489,57 @@ struct FuseConvolutionBatchNormalization final
     auto convOp = op.getOperand().getDefiningOp<ConvolutionOp>();
     if (!convOp) return failure();
 
-    auto convWeight = convOp.getRhs();
-    auto convWeightType = convWeight.getType();
-    auto convWeightShape = convWeightType.getShape();
+    auto convKernel = convOp.getRhs();
+    auto convKernelType = convKernel.getType();
+    auto convKernelShape = convKernelType.getShape();
 
     auto dimNumbers = convOp.getDimensionNumbers();
     if (dimNumbers.getInputBatchDimension() != 0 ||
         dimNumbers.getInputFeatureDimension() != 1 ||
         dimNumbers.getOutputBatchDimension() != 0 ||
         dimNumbers.getOutputFeatureDimension() != 1 ||
         dimNumbers.getKernelOutputFeatureDimension() != 0 ||
-        dimNumbers.getKernelInputFeatureDimension() != 1)
-      return rewriter.notifyMatchFailure(convOp,
-                                         "Only [b, f, ...]x[o, i, ...]->[b, f, "
-                                         "...] configuration is supported");
+        dimNumbers.getKernelInputFeatureDimension() != 1) {
+      constexpr StringLiteral msg =
+          "Only [b, f, ...]x[o, i, ...]->[b, f, ...] configuration is "
+          "supported";
+      return rewriter.notifyMatchFailure(convOp, msg);
+    }
 
     if (convOp.getFeatureGroupCount() > 1 || convOp.getBatchGroupCount() > 1)
       return rewriter.notifyMatchFailure(
          convOp, "feature or batch grouping is not supported");
 
-    if (bnOperandShape[bnFeatureIndex] != convWeightShape.front())
+    if (bnOperandShape[bnFeatureIndex] != convKernelShape.front())
       return failure();
 
-    DenseFPElementsAttr convWeightElems;
+    DenseFPElementsAttr convKernelElems;
     DenseFPElementsAttr scaleElems;
     DenseFPElementsAttr offsetElems;
     DenseFPElementsAttr meanElems;
     DenseFPElementsAttr varianceElems;
 
-    auto epsilon = op.getEpsilon();
+    const auto epsilon = op.getEpsilon();
 
-    if (!matchPattern(convWeight, m_Constant(&convWeightElems)))
+    if (!matchPattern(convKernel, m_Constant(&convKernelElems)))
       return rewriter.notifyMatchFailure(
-          op, "expected constant convolution weight");
+          op, "expected constant convolution kernel");
 
     if (!matchPattern(op.getScale(), m_Constant(&scaleElems)) ||
        !matchPattern(op.getOffset(), m_Constant(&offsetElems)) ||
        !matchPattern(op.getMean(), m_Constant(&meanElems)) ||
        !matchPattern(op.getVariance(), m_Constant(&varianceElems)))
      return failure();
 
-    const auto &convWeightSemantics =
-        cast<FloatType>(convWeightType.getElementType()).getFloatSemantics();
+    const auto &convKernelSemantics =
+        cast<FloatType>(convKernelType.getElementType()).getFloatSemantics();
 
-    // W(new) = W(old) * gamma * rsqrt(variance + epsilon)
-    // B(new) = (B(old) - mean) * rsqrt(variance + epsilon) * gamma + betta
+    // K.new = K.old * gamma * rsqrt(variance + epsilon)
+    // B.new = (B.old - mean) * rsqrt(variance + epsilon) * gamma + beta
     // where: gamma - scaling factor
-    //        betta - shifting factor
+    //        beta - shifting factor
     //        rsqrt - reciprocal square root function
-    //        W - weight
+    //        K - kernel (a.k.a. weight)
     //        B - bias
     //
     const SmallVector<double> multipliers = llvm::map_to_vector(
@@ -1549,22 +1551,22 @@ struct FuseConvolutionBatchNormalization final
           return rsqrt * scale.convertToDouble();
         });
 
-    SmallVector<APFloat> newWeight;
-    newWeight.reserve(convWeightType.getNumElements());
+    SmallVector<APFloat> newKernel;
+    newKernel.reserve(convKernelType.getNumElements());
 
     const size_t outFeatureTileSize =
-        computeProduct(convWeightShape.drop_front());
-    auto it = convWeightElems.begin();
+        computeProduct(convKernelShape.drop_front());
+    auto it = convKernelElems.begin();
     for (const auto &multiplier : multipliers) {
       for (size_t i = 0; i < outFeatureTileSize; ++i) {
         double v = (*it).convertToDouble() * multiplier;
         APFloat result(v);
         bool losesInfo;
         if (APFloat::opStatus::opInvalidOp ==
-            result.convert(convWeightSemantics, APFloat::rmNearestTiesToEven,
+            result.convert(convKernelSemantics, APFloat::rmNearestTiesToEven,
                            &losesInfo))
           return failure();
-        newWeight.push_back(result);
+        newKernel.push_back(result);
         ++it;
       }
     }
@@ -1582,26 +1584,26 @@ struct FuseConvolutionBatchNormalization final
 
       bool losesInfo;
       if (APFloat::opStatus::opInvalidOp ==
-          result.convert(convWeightSemantics, APFloat::rmNearestTiesToEven,
+          result.convert(convKernelSemantics, APFloat::rmNearestTiesToEven,
                          &losesInfo))
         return failure();
 
       biasValues.push_back(result);
     }
 
     rewriter.setInsertionPoint(op);
-    auto newConvWeight = rewriter.create<ConstantOp>(
-        convWeight.getLoc(), convWeightType,
-        DenseFPElementsAttr::get(convWeightType, newWeight));
+    auto newConvKernel = rewriter.create<ConstantOp>(
+        convKernel.getLoc(), convKernelType,
+        DenseFPElementsAttr::get(convKernelType, newKernel));
 
     // Keep old convolution as it might have other users
     auto newConvOp = rewriter.create<ConvolutionOp>(
         convOp.getLoc(), convOp->getResultTypes(),
-        ValueRange{convOp.getLhs(), newConvWeight}, convOp->getAttrs());
+        ValueRange{convOp.getLhs(), newConvKernel}, convOp->getAttrs());
 
     SmallVector<int64_t> biasShape{static_cast<int64_t>(biasValues.size())};
     auto biasType =
-        convWeightType.cloneWith(biasShape, convWeightType.getElementType());
+        convKernelType.cloneWith(biasShape, convKernelType.getElementType());
     auto bias = rewriter.create<ConstantOp>(
         op.getLoc(), biasType, DenseFPElementsAttr::get(biasType, biasValues));
 
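Why the rewrite is numerically sound: batch_norm_inference applies an affine transform per output feature, and a convolution is linear in its kernel, so the per-feature factor gamma * rsqrt(variance + epsilon) can be folded into the kernel while the remaining shift becomes a bias add. Below is a minimal standalone sketch (editor's addition, plain C++ with made-up values, not part of the commit) that checks the K.new / B.new identities from the comment above on a toy 1x1 convolution:

// Editor's illustration (hypothetical, not part of the commit): verifies the
// K.new / B.new folding identities used by FuseConvolutionBatchNormalization.
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Toy layout: 2 output features, 2 input features, 1x1 spatial kernel and a
  // single input "pixel", so the convolution degenerates to a matrix-vector
  // product. All numeric values are made up for the check.
  const std::vector<std::vector<double>> kernel = {{1.0, 2.0}, {-0.5, 3.0}};
  const std::vector<double> input = {0.7, -1.3};

  // Per-output-feature batch_norm_inference parameters.
  const std::vector<double> gamma = {1.5, 0.8};  // scale
  const std::vector<double> beta = {0.1, -0.2};  // offset
  const std::vector<double> mean = {0.4, 1.1};
  const std::vector<double> variance = {0.9, 2.0};
  const double epsilon = 1e-5;

  for (size_t o = 0; o < kernel.size(); ++o) {
    // Reference path: X = conv(input, kernel); Y = batch_norm_inference(X).
    double x = 0.0;
    for (size_t i = 0; i < input.size(); ++i) x += kernel[o][i] * input[i];
    const double rsqrt = 1.0 / std::sqrt(variance[o] + epsilon);
    const double reference = (x - mean[o]) * rsqrt * gamma[o] + beta[o];

    // Folded path: scale the whole output-feature tile of the kernel by the
    // per-feature multiplier, then add the new bias. The conv has no bias,
    // so B.old = 0 and B.new = (0 - mean) * rsqrt * gamma + beta.
    const double multiplier = gamma[o] * rsqrt;
    double folded = 0.0;
    for (size_t i = 0; i < input.size(); ++i)
      folded += (kernel[o][i] * multiplier) * input[i];
    folded += (0.0 - mean[o]) * rsqrt * gamma[o] + beta[o];

    assert(std::abs(reference - folded) < 1e-9);
    std::printf("feature %zu: reference=%g folded=%g\n", o, reference, folded);
  }
  return 0;
}

The check also mirrors why the pattern walks the kernel in output-feature-major order, outFeatureTileSize elements at a time: every element of one output-feature tile is scaled by the same multiplier.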