
Commit c54e3de

fix naming, comments

1 parent: e70f089

2 files changed: 43 additions & 37 deletions

stablehlo/testdata/bn_conv_fuse_float32.mlir

Lines changed: 7 additions & 3 deletions
@@ -6,13 +6,17 @@ module @jit_main attributes {torch.debug_module_name = "ResNet"} {
     %cst_1 = stablehlo.constant dense<[1.01694489, 3.71674347, 5.81334356E-11, 3.28254271, 1.71074404E-13, 0.658226967, 4.37006235, 6.60045282E-12, 0.915522992, 1.93175254E-9, 4.12558556, 2.74399233, 2.8390913, 4.79658588E-8, 11.0722713, 0.500745952, 2.23128176, 4.82570696, 2.69861364, 9.36995506, 3.73391747, 5.48429585, 5.7126689, 0.445444882, 0.436275303, 7.15633583, 13.7179089, 5.25117493, 6.81737518, 1.67235756, 1.65343034, 1.23245978, 4.90762854, 3.07305121, 4.23838568, 4.99363518, 1.44646307E-12, 1.52116203, 1.03519833E-13, 0.351344079, 0.17024748, 1.42054474, 1.90848303, 2.15124035, 2.66084933, 4.84443378, 1.92971194, 1.49994361, 2.94806145E-13, 1.53064024, 0.365027189, 2.93755412, 5.46641159, 0.707924544, 3.33150721, 0.771802961, 2.40678358, 6.5213666, 4.12625027, 1.05063522, 2.95303202, 11.3656216, 4.76904678, 1.65587807]> : tensor<64xf32>
     %cst_2 = stablehlo.constant dense<[0.234872743, 0.266257942, -5.10959595E-8, 0.518699706, 3.44040196E-9, 0.222385287, 0.422887057, 1.31532403E-7, 0.25093165, 1.5152026E-6, 0.316871643, 0.250491828, 0.378926098, 1.08618351E-5, 2.752640e-01, 0.236741036, 0.242021769, 0.395314813, 0.469346285, 0.2908957, 0.272684187, 0.27802828, 0.290692091, 0.206927493, 0.258990377, 0.278710574, 0.291149527, 0.316013753, 0.388891488, 0.304111898, 0.267757207, 0.210925162, 0.287084132, 0.332426429, 0.42672804, 0.373260558, 7.48037578E-8, 0.19067812, 1.47401256E-8, 0.223029822, 0.179079413, 0.248600766, 0.27399528, 0.259228647, 0.294202209, 0.299236417, 0.223688841, 0.262799472, 2.20011476E-8, 0.266098082, 0.220890298, 0.284285516, 0.330723315, 0.226809531, 0.365380913, 0.21229881, 0.239653021, 0.24949576, 0.525830686, 0.248247579, 0.295652747, 0.258776665, 0.4832564, 0.26670444]> : tensor<64xf32>
     %cst_3 = stablehlo.constant dense<[0.230717152, 0.253822476, -1.05429808E-6, -0.664388895, -1.65705547E-8, 0.161521927, 0.454503953, -4.301950e-07, 0.300513744, -8.005240e-06, 0.349418074, 0.311480612, -0.249529764, -3.474890e-05, 0.107726313, 0.218970656, 0.381412596, -0.529882133, -0.628644109, 0.571398079, 0.299846917, 0.584303737, 0.48202154, 0.328526348, 0.196717009, 0.194961801, 0.152145416, 0.085522361, 0.513142824, 0.0152367353, 0.166441768, 0.332394391, 0.249211237, 0.443366677, -0.280169278, -0.0203848016, -2.45068748E-7, 0.321340501, -4.9151744E-8, 0.237767309, 0.232907727, 0.315274626, 0.427762389, 0.293127537, 0.263794243, 0.675975859, 0.429100394, 0.345662743, -8.69090186E-8, 0.247294366, 0.303160846, 0.615772783, 0.39834857, 0.332067341, -0.412187815, 0.378069043, 0.178953409, 0.25747788, -0.449079722, 0.213058949, 0.569339037, 5.727430e-01, -0.402383476, 0.23406373]> : tensor<64xf32>
+
+    // Inputs/expected represent the input and output of the first Conv operation in the ResNet model,
+    // obtained by passing a random image through ONNX Runtime compiled with debug flags
+    // to capture intermediate tensor shapes and data.
     %0 = call @inputs() : () -> tensor<1x3x224x224xf32>
     %1 = call @expected() : () -> tensor<1x64x112x112xf32>
 
-    // Slicing the weight to reduce CPU cycles spend in interpreter.
+    // Slicing the kernel to reduce CPU cycles spent in the interpreter.
     // Calculating just a couple of layers already takes ~10s to complete.
-    %weight_slice = stablehlo.slice %cst [30:32, 0:3, 0:7, 0:7] : (tensor<64x3x7x7xf32>) -> tensor<2x3x7x7xf32>
-    %2 = stablehlo.convolution(%0, %weight_slice)
+    %kernel_slice = stablehlo.slice %cst [30:32, 0:3, 0:7, 0:7] : (tensor<64x3x7x7xf32>) -> tensor<2x3x7x7xf32>
+    %2 = stablehlo.convolution(%0, %kernel_slice)
       dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1],
       window = {stride = [2, 2], pad = [[3, 3], [3, 3]], rhs_dilate = [1, 1]}
       {batch_group_count = 1 : i64, feature_group_count = 1 : i64}
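
The slicing above is sound because each convolution output feature depends only on its own slice of the kernel along the o dimension: convolving with kernel rows [30:32] yields exactly channels 30 and 31 of the full result. A minimal C++ sketch of that property (an illustrative toy, not from the repository; a 1x1 convolution over a single pixel reduces to per-channel dot products):

#include <array>
#include <cassert>
#include <cmath>

int main() {
  constexpr int kOut = 4, kIn = 3;
  // kernel[o][i]: a 1x1 convolution over one pixel is a dot product per o.
  const std::array<std::array<double, kIn>, kOut> kernel = {
      {{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}, {0.7, 0.8, 0.9}, {1.0, 1.1, 1.2}}};
  const std::array<double, kIn> input = {1.0, 2.0, 3.0};

  // Full convolution: all kOut output features.
  std::array<double, kOut> full{};
  for (int o = 0; o < kOut; ++o)
    for (int i = 0; i < kIn; ++i) full[o] += kernel[o][i] * input[i];

  // Convolution with the kernel sliced to rows [1:3] (as stablehlo.slice
  // would produce): computes exactly output channels 1 and 2.
  std::array<double, 2> sliced{};
  for (int o = 1; o < 3; ++o)
    for (int i = 0; i < kIn; ++i) sliced[o - 1] += kernel[o][i] * input[i];

  assert(std::abs(sliced[0] - full[1]) < 1e-12);
  assert(std::abs(sliced[1] - full[2]) < 1e-12);
  return 0;
}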

stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp

Lines changed: 36 additions & 34 deletions
@@ -1477,12 +1477,12 @@ struct ReorderElementwiseAndShapeOp final
   }
 };
 
-// Fuses batch normalization operation with convolution weight:
-//   X = conv(input, weight)
+// Fuses batch normalization operation with convolution kernel:
+//   X = conv(input, kernel.old)
 //   Y = batch_norm_inference(X, ...)
 // into ->
-//   X = conv(input, weight(new))
-//   Y = add(X, broadcast_in_dim(Bias(new)))
+//   X = conv(input, kernel.new)
+//   Y = add(X, broadcast_in_dim(bias.new))
 //
 struct FuseConvolutionBatchNormalization final
     : OpRewritePattern<BatchNormInferenceOp> {
@@ -1498,55 +1498,57 @@ struct FuseConvolutionBatchNormalization final
     auto convOp = op.getOperand().getDefiningOp<ConvolutionOp>();
     if (!convOp) return failure();
 
-    auto convWeight = convOp.getRhs();
-    auto convWeightType = convWeight.getType();
-    auto convWeightShape = convWeightType.getShape();
+    auto convKernel = convOp.getRhs();
+    auto convKernelType = convKernel.getType();
+    auto convKernelShape = convKernelType.getShape();
 
     auto dimNumbers = convOp.getDimensionNumbers();
     if (dimNumbers.getInputBatchDimension() != 0 ||
         dimNumbers.getInputFeatureDimension() != 1 ||
         dimNumbers.getOutputBatchDimension() != 0 ||
         dimNumbers.getOutputFeatureDimension() != 1 ||
         dimNumbers.getKernelOutputFeatureDimension() != 0 ||
-        dimNumbers.getKernelInputFeatureDimension() != 1)
-      return rewriter.notifyMatchFailure(convOp,
-                                         "Only [b, f, ...]x[o, i, ...]->[b, f, "
-                                         "...] configuration is supported");
+        dimNumbers.getKernelInputFeatureDimension() != 1) {
+      constexpr StringLiteral msg =
+          "Only [b, f, ...]x[o, i, ...]->[b, f, ...] configuration is "
+          "supported";
+      return rewriter.notifyMatchFailure(convOp, msg);
+    }
 
     if (convOp.getFeatureGroupCount() > 1 || convOp.getBatchGroupCount() > 1)
       return rewriter.notifyMatchFailure(
           convOp, "feature or batch grouping is not supported");
 
-    if (bnOperandShape[bnFeatureIndex] != convWeightShape.front())
+    if (bnOperandShape[bnFeatureIndex] != convKernelShape.front())
       return failure();
 
-    DenseFPElementsAttr convWeightElems;
+    DenseFPElementsAttr convKernelElems;
     DenseFPElementsAttr scaleElems;
     DenseFPElementsAttr offsetElems;
     DenseFPElementsAttr meanElems;
     DenseFPElementsAttr varianceElems;
 
-    auto epsilon = op.getEpsilon();
+    const auto epsilon = op.getEpsilon();
 
-    if (!matchPattern(convWeight, m_Constant(&convWeightElems)))
+    if (!matchPattern(convKernel, m_Constant(&convKernelElems)))
       return rewriter.notifyMatchFailure(
-          op, "expected constant convolution weight");
+          op, "expected constant convolution kernel");
 
     if (!matchPattern(op.getScale(), m_Constant(&scaleElems)) ||
         !matchPattern(op.getOffset(), m_Constant(&offsetElems)) ||
         !matchPattern(op.getMean(), m_Constant(&meanElems)) ||
        !matchPattern(op.getVariance(), m_Constant(&varianceElems)))
       return failure();
 
-    const auto &convWeightSemantics =
-        cast<FloatType>(convWeightType.getElementType()).getFloatSemantics();
+    const auto &convKernelSemantics =
+        cast<FloatType>(convKernelType.getElementType()).getFloatSemantics();
 
-    // W(new) = W(old) * gamma * rsqrt(variance + epsilon)
-    // B(new) = (B(old) - mean) * rsqrt(variance + epsilon) * gamma + betta
+    // K.new = K.old * gamma * rsqrt(variance + epsilon)
+    // B.new = (B.old - mean) * rsqrt(variance + epsilon) * gamma + beta
     // where: gamma - scaling factor
-    //        betta - shifting factor
+    //        beta - shifting factor
     //        rsqrt - reciprocal square root function
-    //        W - weight
+    //        K - kernel (a.k.a. weight)
     //        B - bias
     //
     const SmallVector<double> multipliers = llvm::map_to_vector(
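
The fold encoded by the K.new/B.new comment works because a convolution is linear in its kernel: scaling output feature o's kernel by m = gamma * rsqrt(variance + epsilon) scales that output channel by m, and the remaining affine part of batch_norm_inference becomes a per-channel bias. A minimal scalar check of the algebra (an illustrative sketch, not from the repository; B.old is taken as 0 because StableHLO convolutions carry no bias):

#include <cassert>
#include <cmath>

int main() {
  const double x = 0.75;                        // one input value
  const double k = -1.3;                        // one kernel value
  const double gamma = 1.7, beta = 0.4;         // BN scale / offset
  const double mean = 0.2, variance = 0.9, epsilon = 1e-5;

  const double m = gamma / std::sqrt(variance + epsilon);

  const double conv = x * k;                    // toy 1x1 "convolution"
  const double bn = (conv - mean) * m + beta;   // batch_norm_inference

  const double kNew = k * m;                    // fused kernel
  const double bNew = (0.0 - mean) * m + beta;  // fused bias (B.old = 0)
  const double fused = x * kNew + bNew;         // conv + broadcast add

  assert(std::abs(bn - fused) < 1e-12);
  return 0;
}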
@@ -1558,22 +1560,22 @@ struct FuseConvolutionBatchNormalization final
           return rsqrt * scale.convertToDouble();
         });
 
-    SmallVector<APFloat> newWeight;
-    newWeight.reserve(convWeightType.getNumElements());
+    SmallVector<APFloat> newKernel;
+    newKernel.reserve(convKernelType.getNumElements());
 
     const size_t outFeatureTileSize =
-        computeProduct(convWeightShape.drop_front());
-    auto it = convWeightElems.begin();
+        computeProduct(convKernelShape.drop_front());
+    auto it = convKernelElems.begin();
     for (const auto &multiplier : multipliers) {
       for (size_t i = 0; i < outFeatureTileSize; ++i) {
         double v = (*it).convertToDouble() * multiplier;
         APFloat result(v);
         bool losesInfo;
         if (APFloat::opStatus::opInvalidOp ==
-            result.convert(convWeightSemantics, APFloat::rmNearestTiesToEven,
+            result.convert(convKernelSemantics, APFloat::rmNearestTiesToEven,
                            &losesInfo))
           return failure();
-        newWeight.push_back(result);
+        newKernel.push_back(result);
         ++it;
       }
     }
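
The nested loop walks the kernel constant in flat [o, i, ...] element order, where the computeProduct(convKernelShape.drop_front()) elements of one output feature are contiguous; for the 64x3x7x7 kernel in the testdata file the tile size is 3 * 7 * 7 = 147. A small self-contained sketch of the same per-tile scaling (scalePerOutFeature is a hypothetical helper, not from the repository):

#include <cassert>
#include <cstddef>
#include <vector>

// One multiplier per output feature, applied to a contiguous tile of
// outFeatureTileSize elements of a kernel stored flat in [o, i, h, w] order.
std::vector<double> scalePerOutFeature(const std::vector<double> &kernel,
                                       const std::vector<double> &multipliers,
                                       std::size_t outFeatureTileSize) {
  std::vector<double> result;
  result.reserve(kernel.size());
  auto it = kernel.begin();
  for (double multiplier : multipliers)
    for (std::size_t i = 0; i < outFeatureTileSize; ++i)
      result.push_back(*it++ * multiplier);
  return result;
}

int main() {
  // Two output features, tile size 3: layout [o0 o0 o0 | o1 o1 o1].
  const std::vector<double> scaled =
      scalePerOutFeature({1, 1, 1, 2, 2, 2}, {10.0, 0.5}, 3);
  assert((scaled == std::vector<double>{10, 10, 10, 1, 1, 1}));
  return 0;
}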
@@ -1591,26 +1593,26 @@ struct FuseConvolutionBatchNormalization final
 
       bool losesInfo;
       if (APFloat::opStatus::opInvalidOp ==
-          result.convert(convWeightSemantics, APFloat::rmNearestTiesToEven,
+          result.convert(convKernelSemantics, APFloat::rmNearestTiesToEven,
                          &losesInfo))
         return failure();
 
       biasValues.push_back(result);
     }
 
     rewriter.setInsertionPoint(op);
-    auto newConvWeight = rewriter.create<ConstantOp>(
-        convWeight.getLoc(), convWeightType,
-        DenseFPElementsAttr::get(convWeightType, newWeight));
+    auto newConvKernel = rewriter.create<ConstantOp>(
+        convKernel.getLoc(), convKernelType,
+        DenseFPElementsAttr::get(convKernelType, newKernel));
 
     // Keep old convolution as it might have other users
     auto newConvOp = rewriter.create<ConvolutionOp>(
         convOp.getLoc(), convOp->getResultTypes(),
-        ValueRange{convOp.getLhs(), newConvWeight}, convOp->getAttrs());
+        ValueRange{convOp.getLhs(), newConvKernel}, convOp->getAttrs());
 
     SmallVector<int64_t> biasShape{static_cast<int64_t>(biasValues.size())};
     auto biasType =
-        convWeightType.cloneWith(biasShape, convWeightType.getElementType());
+        convKernelType.cloneWith(biasShape, convKernelType.getElementType());
     auto bias = rewriter.create<ConstantOp>(
         op.getLoc(), biasType, DenseFPElementsAttr::get(biasType, biasValues));
