Commit 147c226

fix naming, comments

committed · 1 parent 3ac0c23 · commit 147c226

File tree

2 files changed: +43 −37 lines


stablehlo/testdata/bn_conv_fuse_float32.mlir

Lines changed: 7 additions & 3 deletions
@@ -6,13 +6,17 @@ module @jit_main attributes {torch.debug_module_name = "ResNet"} {
     %cst_1 = stablehlo.constant dense<[1.01694489, 3.71674347, 5.81334356E-11, 3.28254271, 1.71074404E-13, 0.658226967, 4.37006235, 6.60045282E-12, 0.915522992, 1.93175254E-9, 4.12558556, 2.74399233, 2.8390913, 4.79658588E-8, 11.0722713, 0.500745952, 2.23128176, 4.82570696, 2.69861364, 9.36995506, 3.73391747, 5.48429585, 5.7126689, 0.445444882, 0.436275303, 7.15633583, 13.7179089, 5.25117493, 6.81737518, 1.67235756, 1.65343034, 1.23245978, 4.90762854, 3.07305121, 4.23838568, 4.99363518, 1.44646307E-12, 1.52116203, 1.03519833E-13, 0.351344079, 0.17024748, 1.42054474, 1.90848303, 2.15124035, 2.66084933, 4.84443378, 1.92971194, 1.49994361, 2.94806145E-13, 1.53064024, 0.365027189, 2.93755412, 5.46641159, 0.707924544, 3.33150721, 0.771802961, 2.40678358, 6.5213666, 4.12625027, 1.05063522, 2.95303202, 11.3656216, 4.76904678, 1.65587807]> : tensor<64xf32>
     %cst_2 = stablehlo.constant dense<[0.234872743, 0.266257942, -5.10959595E-8, 0.518699706, 3.44040196E-9, 0.222385287, 0.422887057, 1.31532403E-7, 0.25093165, 1.5152026E-6, 0.316871643, 0.250491828, 0.378926098, 1.08618351E-5, 2.752640e-01, 0.236741036, 0.242021769, 0.395314813, 0.469346285, 0.2908957, 0.272684187, 0.27802828, 0.290692091, 0.206927493, 0.258990377, 0.278710574, 0.291149527, 0.316013753, 0.388891488, 0.304111898, 0.267757207, 0.210925162, 0.287084132, 0.332426429, 0.42672804, 0.373260558, 7.48037578E-8, 0.19067812, 1.47401256E-8, 0.223029822, 0.179079413, 0.248600766, 0.27399528, 0.259228647, 0.294202209, 0.299236417, 0.223688841, 0.262799472, 2.20011476E-8, 0.266098082, 0.220890298, 0.284285516, 0.330723315, 0.226809531, 0.365380913, 0.21229881, 0.239653021, 0.24949576, 0.525830686, 0.248247579, 0.295652747, 0.258776665, 0.4832564, 0.26670444]> : tensor<64xf32>
     %cst_3 = stablehlo.constant dense<[0.230717152, 0.253822476, -1.05429808E-6, -0.664388895, -1.65705547E-8, 0.161521927, 0.454503953, -4.301950e-07, 0.300513744, -8.005240e-06, 0.349418074, 0.311480612, -0.249529764, -3.474890e-05, 0.107726313, 0.218970656, 0.381412596, -0.529882133, -0.628644109, 0.571398079, 0.299846917, 0.584303737, 0.48202154, 0.328526348, 0.196717009, 0.194961801, 0.152145416, 0.085522361, 0.513142824, 0.0152367353, 0.166441768, 0.332394391, 0.249211237, 0.443366677, -0.280169278, -0.0203848016, -2.45068748E-7, 0.321340501, -4.9151744E-8, 0.237767309, 0.232907727, 0.315274626, 0.427762389, 0.293127537, 0.263794243, 0.675975859, 0.429100394, 0.345662743, -8.69090186E-8, 0.247294366, 0.303160846, 0.615772783, 0.39834857, 0.332067341, -0.412187815, 0.378069043, 0.178953409, 0.25747788, -0.449079722, 0.213058949, 0.569339037, 5.727430e-01, -0.402383476, 0.23406373]> : tensor<64xf32>
+
+    // Inputs/expected represent the input and output of the first Conv operation in the ResNet model,
+    // obtained by passing a random image through the ONNX Runtime compiled with debug flags
+    // to capture intermediate tensor shapes and data.
     %0 = call @inputs() : () -> tensor<1x3x224x224xf32>
     %1 = call @expected() : () -> tensor<1x64x112x112xf32>

-    // Slicing the weight to reduce CPU cycles spend in interpreter.
+    // Slicing the kernel to reduce CPU cycles spent in the interpreter.
     // Calculating just a couple of layers already takes ~10s to complete.
-    %weight_slice = stablehlo.slice %cst [30:32, 0:3, 0:7, 0:7] : (tensor<64x3x7x7xf32>) -> tensor<2x3x7x7xf32>
-    %2 = stablehlo.convolution(%0, %weight_slice)
+    %kernel_slice = stablehlo.slice %cst [30:32, 0:3, 0:7, 0:7] : (tensor<64x3x7x7xf32>) -> tensor<2x3x7x7xf32>
+    %2 = stablehlo.convolution(%0, %kernel_slice)
       dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1],
       window = {stride = [2, 2], pad = [[3, 3], [3, 3]], rhs_dilate = [1, 1]}
       {batch_group_count = 1 : i64, feature_group_count = 1 : i64}
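For readers checking the shapes: with the window attributes above (7×7 kernel, stride 2, padding 3 on each side) the 224×224 input maps to the 112×112 spatial output. A minimal C++ sketch of the standard convolution shape arithmetic assumed here (not part of the test data):

#include <cstdio>

// Output spatial size for one dimension, assuming unit rhs dilation
// (as in the window attributes above).
int convOutputSize(int input, int kernel, int stride, int padLo, int padHi) {
  return (input + padLo + padHi - kernel) / stride + 1;
}

int main() {
  // (224 + 3 + 3 - 7) / 2 + 1 == 112, matching tensor<1x64x112x112xf32>;
  // the sliced kernel computes only 2 of the 64 output features.
  std::printf("%d\n", convOutputSize(224, 7, 2, 3, 3));
  return 0;
}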

stablehlo/transforms/StablehloAggressiveSimplification.cpp

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1468,12 +1468,12 @@ struct ReorderElementwiseAndShapeOp final
   }
 };

-// Fuses batch normalization operation with convolution weight:
-// X = conv(input, weight)
+// Fuses batch normalization operation with convolution kernel:
+// X = conv(input, kernel.old)
 // Y = batch_norm_inference(X, ...)
 // into ->
-// X = conv(input, weight(new))
-// Y = add(X, broadcast_in_dim(Bias(new)))
+// X = conv(input, kernel.new)
+// Y = add(X, broadcast_in_dim(bias.new))
 //
 struct FuseConvolutionBatchNormalization final
     : OpRewritePattern<BatchNormInferenceOp> {
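The rewrite is valid because convolution is linear in the kernel and the batch-norm statistics are per-output-feature constants. A sketch of the algebra (gamma, beta, mu, sigma² as in batch_norm_inference):

\[
\begin{aligned}
\mathrm{BN}(\mathrm{conv}(x,K)) &= \gamma\cdot\frac{\mathrm{conv}(x,K)-\mu}{\sqrt{\sigma^{2}+\epsilon}}+\beta \\
&= \mathrm{conv}\!\left(x,\,K\cdot\frac{\gamma}{\sqrt{\sigma^{2}+\epsilon}}\right)+\left(\beta-\frac{\mu\,\gamma}{\sqrt{\sigma^{2}+\epsilon}}\right) \\
&= \mathrm{conv}(x,\,K_{\mathrm{new}})+\mathrm{bias}_{\mathrm{new}}
\end{aligned}
\]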
@@ -1489,55 +1489,57 @@ struct FuseConvolutionBatchNormalization final
     auto convOp = op.getOperand().getDefiningOp<ConvolutionOp>();
     if (!convOp) return failure();

-    auto convWeight = convOp.getRhs();
-    auto convWeightType = convWeight.getType();
-    auto convWeightShape = convWeightType.getShape();
+    auto convKernel = convOp.getRhs();
+    auto convKernelType = convKernel.getType();
+    auto convKernelShape = convKernelType.getShape();

     auto dimNumbers = convOp.getDimensionNumbers();
     if (dimNumbers.getInputBatchDimension() != 0 ||
         dimNumbers.getInputFeatureDimension() != 1 ||
         dimNumbers.getOutputBatchDimension() != 0 ||
         dimNumbers.getOutputFeatureDimension() != 1 ||
         dimNumbers.getKernelOutputFeatureDimension() != 0 ||
-        dimNumbers.getKernelInputFeatureDimension() != 1)
-      return rewriter.notifyMatchFailure(convOp,
-                                         "Only [b, f, ...]x[o, i, ...]->[b, f, "
-                                         "...] configuration is supported");
+        dimNumbers.getKernelInputFeatureDimension() != 1) {
+      constexpr StringLiteral msg =
+          "Only [b, f, ...]x[o, i, ...]->[b, f, ...] configuration is "
+          "supported";
+      return rewriter.notifyMatchFailure(convOp, msg);
+    }

     if (convOp.getFeatureGroupCount() > 1 || convOp.getBatchGroupCount() > 1)
       return rewriter.notifyMatchFailure(
           convOp, "feature or batch grouping is not supported");

-    if (bnOperandShape[bnFeatureIndex] != convWeightShape.front())
+    if (bnOperandShape[bnFeatureIndex] != convKernelShape.front())
       return failure();

-    DenseFPElementsAttr convWeightElems;
+    DenseFPElementsAttr convKernelElems;
     DenseFPElementsAttr scaleElems;
     DenseFPElementsAttr offsetElems;
     DenseFPElementsAttr meanElems;
     DenseFPElementsAttr varianceElems;

-    auto epsilon = op.getEpsilon();
+    const auto epsilon = op.getEpsilon();

-    if (!matchPattern(convWeight, m_Constant(&convWeightElems)))
+    if (!matchPattern(convKernel, m_Constant(&convKernelElems)))
       return rewriter.notifyMatchFailure(
-          op, "expected constant convolution weight");
+          op, "expected constant convolution kernel");

     if (!matchPattern(op.getScale(), m_Constant(&scaleElems)) ||
         !matchPattern(op.getOffset(), m_Constant(&offsetElems)) ||
         !matchPattern(op.getMean(), m_Constant(&meanElems)) ||
         !matchPattern(op.getVariance(), m_Constant(&varianceElems)))
       return failure();

-    const auto &convWeightSemantics =
-        cast<FloatType>(convWeightType.getElementType()).getFloatSemantics();
+    const auto &convKernelSemantics =
+        cast<FloatType>(convKernelType.getElementType()).getFloatSemantics();

-    // W(new) = W(old) * gamma * rsqrt(variance + epsilon)
-    // B(new) = (B(old) - mean) * rsqrt(variance + epsilon) * gamma + betta
+    // K.new = K.old * gamma * rsqrt(variance + epsilon)
+    // B.new = (B.old - mean) * rsqrt(variance + epsilon) * gamma + beta
     // where: gamma - scaling factor
-    //        betta - shifting factor
+    //        beta - shifting factor
     //        rsqrt - reciprocal square root function
-    //        W - weight
+    //        K - kernel (a.k.a. weight)
     //        B - bias
     //
     const SmallVector<double> multipliers = llvm::map_to_vector(
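The K.new / B.new formulas above can be sanity-checked with plain scalars. A minimal standalone sketch using double arithmetic and made-up values (the pattern itself works element-wise on APFloat data):

#include <cmath>
#include <cstdio>

int main() {
  // Made-up per-feature constants; the pattern reads these from attributes.
  const double gamma = 0.5, beta = 0.1;     // batch-norm scale / offset
  const double mean = 2.0, variance = 4.0;  // batch-norm statistics
  const double epsilon = 1e-5;
  const double kOld = 3.0;  // one kernel element; the conv has no bias, so B.old = 0

  const double rsqrt = 1.0 / std::sqrt(variance + epsilon);
  const double kNew = kOld * gamma * rsqrt;                 // K.new
  const double bNew = (0.0 - mean) * rsqrt * gamma + beta;  // B.new

  // For any input x, batch norm applied after the old conv equals the new
  // conv plus the new bias:
  const double x = 1.7;
  std::printf("%f == %f\n", gamma * (x * kOld - mean) * rsqrt + beta,
              x * kNew + bNew);
  return 0;
}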
@@ -1549,22 +1551,22 @@ struct FuseConvolutionBatchNormalization final
           return rsqrt * scale.convertToDouble();
         });

-    SmallVector<APFloat> newWeight;
-    newWeight.reserve(convWeightType.getNumElements());
+    SmallVector<APFloat> newKernel;
+    newKernel.reserve(convKernelType.getNumElements());

     const size_t outFeatureTileSize =
-        computeProduct(convWeightShape.drop_front());
-    auto it = convWeightElems.begin();
+        computeProduct(convKernelShape.drop_front());
+    auto it = convKernelElems.begin();
     for (const auto &multiplier : multipliers) {
       for (size_t i = 0; i < outFeatureTileSize; ++i) {
         double v = (*it).convertToDouble() * multiplier;
         APFloat result(v);
         bool losesInfo;
         if (APFloat::opStatus::opInvalidOp ==
-            result.convert(convWeightSemantics, APFloat::rmNearestTiesToEven,
+            result.convert(convKernelSemantics, APFloat::rmNearestTiesToEven,
                            &losesInfo))
           return failure();
-        newWeight.push_back(result);
+        newKernel.push_back(result);
         ++it;
       }
     }
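The loop above relies on the [o, i, ...] kernel layout checked earlier: in the row-major flattening, the elements of one output feature form one contiguous tile of computeProduct(shape.drop_front()) values, so each multiplier scales exactly one tile. A plain-double sketch of the same tiling (hypothetical helper, not the pattern's APFloat code):

#include <cstddef>
#include <vector>

// kernel is a row-major flattening of [o, i, h, w]; tileSize == i * h * w.
void scaleKernelPerOutputFeature(std::vector<double> &kernel,
                                 const std::vector<double> &multipliers,
                                 std::size_t tileSize) {
  std::size_t idx = 0;
  for (double m : multipliers)  // one multiplier per output feature
    for (std::size_t i = 0; i < tileSize; ++i) kernel[idx++] *= m;
}

// e.g. for tensor<64x3x7x7xf32>: 64 multipliers, tileSize == 3 * 7 * 7 == 147.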
@@ -1582,26 +1584,26 @@ struct FuseConvolutionBatchNormalization final

     bool losesInfo;
     if (APFloat::opStatus::opInvalidOp ==
-        result.convert(convWeightSemantics, APFloat::rmNearestTiesToEven,
+        result.convert(convKernelSemantics, APFloat::rmNearestTiesToEven,
                        &losesInfo))
       return failure();

     biasValues.push_back(result);
   }

   rewriter.setInsertionPoint(op);
-  auto newConvWeight = rewriter.create<ConstantOp>(
-      convWeight.getLoc(), convWeightType,
-      DenseFPElementsAttr::get(convWeightType, newWeight));
+  auto newConvKernel = rewriter.create<ConstantOp>(
+      convKernel.getLoc(), convKernelType,
+      DenseFPElementsAttr::get(convKernelType, newKernel));

   // Keep old convolution as it might have other users
   auto newConvOp = rewriter.create<ConvolutionOp>(
       convOp.getLoc(), convOp->getResultTypes(),
-      ValueRange{convOp.getLhs(), newConvWeight}, convOp->getAttrs());
+      ValueRange{convOp.getLhs(), newConvKernel}, convOp->getAttrs());

   SmallVector<int64_t> biasShape{static_cast<int64_t>(biasValues.size())};
   auto biasType =
-      convWeightType.cloneWith(biasShape, convWeightType.getElementType());
+      convKernelType.cloneWith(biasShape, convKernelType.getElementType());
   auto bias = rewriter.create<ConstantOp>(
       op.getLoc(), biasType, DenseFPElementsAttr::get(biasType, biasValues));
