Skip to content

Commit de0094e

Browse files
authored
[mlir][tosa] Introduce accumulator type for reduce_sum on bf16 (#158389)
TOSA requires that `reduce_sum` operations on bf16 accumulate into fp32. This change updates the `linalg` legalization by introducing an explicit accumulator type to ensure compliance with the specification. --------- Signed-off-by: Georgios Pinitas <[email protected]>
1 parent 50bcf68 commit de0094e

File tree

2 files changed

+74
-15
lines changed

2 files changed

+74
-15
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1160,6 +1160,12 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
11601160
auto elementTy = resultTy.getElementType();
11611161
Value input = op->getOperand(0);
11621162

1163+
// Figure out the accType if needed
1164+
bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
1165+
isa<FloatType>(elementTy) &&
1166+
cast<FloatType>(elementTy).isBF16();
1167+
Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;
1168+
11631169
SmallVector<int64_t> reduceShape;
11641170
SmallVector<Value> dynDims;
11651171
for (unsigned i = 0; i < inputTy.getRank(); i++) {
@@ -1174,11 +1180,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
11741180
inputs.push_back(input);
11751181

11761182
// First fill the output buffer with the init value.
1177-
auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, reduceShape,
1178-
resultTy.getElementType(), dynDims)
1179-
.getResult();
1183+
auto emptyTensor =
1184+
tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
1185+
.getResult();
11801186

1181-
auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
1187+
auto fillValueAttr = createInitialValueForReduceOp(op, accTy, rewriter);
11821188
if (!fillValueAttr)
11831189
return rewriter.notifyMatchFailure(
11841190
op, "No initial value found for reduction operation");
@@ -1231,8 +1237,14 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
12311237
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
12321238
std::array<Value, 2> binaryArgs{
12331239
blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
1234-
auto result = createLinalgBodyCalculationForReduceOp(
1235-
op, binaryArgs, elementTy, rewriter);
1240+
1241+
// If reduction type differs then extend (applicable to reduce_sum)
1242+
if (binaryArgs[0].getType() != accTy)
1243+
binaryArgs[0] = arith::ExtFOp::create(nestedBuilder, nestedLoc, accTy,
1244+
binaryArgs[0]);
1245+
1246+
auto result = createLinalgBodyCalculationForReduceOp(op, binaryArgs,
1247+
accTy, rewriter);
12361248
if (result)
12371249
didEncounterError = true;
12381250

@@ -1273,12 +1285,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
12731285

12741286
// Create a tensor full of NaNs.
12751287
auto nanValueAttr = rewriter.getFloatAttr(
1276-
elementTy,
1288+
accTy,
12771289
APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
12781290
auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
12791291
auto emptyNanTensor =
1280-
tensor::EmptyOp::create(rewriter, loc, reduceShape,
1281-
resultTy.getElementType(), dynDims)
1292+
tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
12821293
.getResult();
12831294
auto nanFilledTensor =
12841295
linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
@@ -1288,8 +1299,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
12881299
// Create an empty tensor, no need to fill this since it will be
12891300
// overwritten by the select.
12901301
auto finalEmptyTensor =
1291-
tensor::EmptyOp::create(rewriter, loc, reduceShape,
1292-
resultTy.getElementType(), dynDims)
1302+
tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
12931303
.getResult();
12941304

12951305
// Do a selection between the tensors akin to:
@@ -1304,9 +1314,32 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
13041314
linalgOp = linalgSelect;
13051315
}
13061316

1317+
// Truncate back to resultTy if needed
1318+
Value reducedRes = linalgOp->getResult(0);
1319+
if (widenAccTy) {
1320+
auto resEmptyOp =
1321+
tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
1322+
.getResult();
1323+
1324+
const unsigned reducedRank =
1325+
cast<ShapedType>(reducedRes.getType()).getRank();
1326+
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
1327+
reducedRes =
1328+
linalg::GenericOp::create(
1329+
rewriter, loc, resEmptyOp.getType(), ValueRange{reducedRes},
1330+
ValueRange{resEmptyOp},
1331+
ArrayRef<AffineMap>{identityMap, identityMap},
1332+
getNParallelLoopsAttrs(reducedRank),
1333+
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1334+
Value truncf = arith::TruncFOp::create(nestedBuilder, nestedLoc,
1335+
elementTy, args[0]);
1336+
linalg::YieldOp::create(nestedBuilder, nestedLoc, truncf);
1337+
})
1338+
.getResults()[0];
1339+
}
1340+
13071341
SmallVector<ReassociationExprs, 4> reassociationMap;
1308-
uint64_t expandInputRank =
1309-
cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
1342+
uint64_t expandInputRank = cast<ShapedType>(reducedRes.getType()).getRank();
13101343
reassociationMap.resize(expandInputRank);
13111344

13121345
for (uint64_t i = 0; i < expandInputRank; i++) {
@@ -1324,8 +1357,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
13241357
// since here we know which dimension to expand, and `tosa::ReshapeOp` would
13251358
// not have access to such information. This matters when handling dynamically
13261359
// sized tensors.
1327-
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1328-
op, resultTy, linalgOp->getResults()[0], reassociationMap);
1360+
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(op, resultTy, reducedRes,
1361+
reassociationMap);
13291362
return success();
13301363
}
13311364

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,32 @@ func.func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<
912912

913913
// -----
914914

915+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
916+
// CHECK-LABEL: @reduce_bf16
917+
// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xbf16>
918+
func.func @reduce_bf16(%arg0: tensor<5x4xbf16>) -> () {
919+
// CHECK: [[INIT:%.+]] = tensor.empty() : tensor<4xf32>
920+
// CHECK: [[CST0:%.+]] = arith.constant 0.0
921+
// CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
922+
// CHECK: [[REDUCE:%.+]] = linalg.reduce ins([[ARG0]] : tensor<5x4xbf16>) outs([[FILL]] : tensor<4xf32>) dimensions = [0]
923+
// CHECK: (%[[ARG1:.*]]: bf16, %[[ARG2:.*]]: f32) {
924+
// CHECK: [[EXTF:%.+]] = arith.extf %[[ARG1]] : bf16 to f32
925+
// CHECK: [[ACC:%.+]] = arith.addf [[EXTF]], %[[ARG2]] : f32
926+
// CHECK: linalg.yield [[ACC]] : f32
927+
// CHECK: }
928+
// CHECK: [[INIT_RES:%.+]] = tensor.empty() : tensor<4xbf16>
929+
// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[REDUCE]] : tensor<4xf32>) outs([[INIT_RES]] : tensor<4xbf16>)
930+
// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: bf16):
931+
// CHECK: [[TRUNCF:%.+]] = arith.truncf %[[IN]] : f32 to bf16
932+
// CHECK: linalg.yield [[TRUNCF]] : bf16
933+
// CHECK: }
934+
// CHECK: tensor.expand_shape [[RES]] {{\[}}[0, 1]] output_shape [1, 4] : tensor<4xbf16> into tensor<1x4xbf16>
935+
%0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<5x4xbf16>) -> tensor<1x4xbf16>
936+
return
937+
}
938+
939+
// -----
940+
915941
// CHECK-LABEL: @reduce_float
916942
// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
917943
func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {

0 commit comments

Comments
 (0)