63 changes: 48 additions & 15 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1160,6 +1160,12 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
auto elementTy = resultTy.getElementType();
Value input = op->getOperand(0);

// Figure out the accumulator type; reduce_sum on bf16 accumulates in f32
bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
isa<FloatType>(elementTy) &&
cast<FloatType>(elementTy).isBF16();
Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;

SmallVector<int64_t> reduceShape;
SmallVector<Value> dynDims;
for (unsigned i = 0; i < inputTy.getRank(); i++) {
@@ -1174,11 +1180,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
inputs.push_back(input);

// First fill the output buffer with the init value.
auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, reduceShape,
resultTy.getElementType(), dynDims)
.getResult();
auto emptyTensor =
tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
.getResult();

auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
auto fillValueAttr = createInitialValueForReduceOp(op, accTy, rewriter);
if (!fillValueAttr)
return rewriter.notifyMatchFailure(
op, "No initial value found for reduction operation");
@@ -1231,8 +1237,14 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
std::array<Value, 2> binaryArgs{
blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
auto result = createLinalgBodyCalculationForReduceOp(
op, binaryArgs, elementTy, rewriter);

// If the accumulator type differs, extend the operand (applicable to reduce_sum)
[Review thread]
Contributor: Do you need an opType guard here if anything later mandates different handling of other reductions?

@GeorgeARM (Author, Sep 14, 2025): It is already effectively guarded by OpTy, since the widening is only enabled when the operation is a reduce_sum. We might need to rework this altogether if other reductions later mandate different handling.
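For illustration only (not part of the patch), such a guard could be made explicit with a small per-op hook; the name getAccumulatorType and its placement are hypothetical, and the rule shown is just the one this patch implements:

// Hypothetical sketch, not in the patch: an explicit per-op hook for
// selecting the accumulator type, in case other reductions ever need
// different widening rules.
template <typename OpTy>
static Type getAccumulatorType(Type elementTy, PatternRewriter &rewriter) {
  if constexpr (std::is_same_v<OpTy, tosa::ReduceSumOp>) {
    // Same rule as the patch: accumulate bf16 reduce_sum in f32.
    if (isa<FloatType>(elementTy) && cast<FloatType>(elementTy).isBF16())
      return rewriter.getF32Type();
  }
  return elementTy;
}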

if (binaryArgs[0].getType() != accTy)
binaryArgs[0] = arith::ExtFOp::create(nestedBuilder, nestedLoc, accTy,
binaryArgs[0]);

auto result = createLinalgBodyCalculationForReduceOp(op, binaryArgs,
accTy, rewriter);
if (result)
didEncounterError = true;

@@ -1273,12 +1285,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,

// Create a tensor full of NaNs.
auto nanValueAttr = rewriter.getFloatAttr(
elementTy,
accTy,
APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
auto emptyNanTensor =
tensor::EmptyOp::create(rewriter, loc, reduceShape,
resultTy.getElementType(), dynDims)
tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
.getResult();
auto nanFilledTensor =
linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
@@ -1288,8 +1299,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
// Create an empty tensor; no need to fill it since it will be
// overwritten by the select.
auto finalEmptyTensor =
tensor::EmptyOp::create(rewriter, loc, reduceShape,
resultTy.getElementType(), dynDims)
tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
.getResult();

// Do a selection between the tensors akin to:
@@ -1304,9 +1314,32 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
linalgOp = linalgSelect;
}

// Truncate back to resultTy if needed
Value reducedRes = linalgOp->getResult(0);
if (widenAccTy) {
auto resEmptyOp =
tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
.getResult();

const unsigned reducedRank =
cast<ShapedType>(reducedRes.getType()).getRank();
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
reducedRes =
linalg::GenericOp::create(
rewriter, loc, resEmptyOp.getType(), ValueRange{reducedRes},
ValueRange{resEmptyOp},
ArrayRef<AffineMap>{identityMap, identityMap},
getNParallelLoopsAttrs(reducedRank),
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
Value truncf = arith::TruncFOp::create(nestedBuilder, nestedLoc,
elementTy, args[0]);
linalg::YieldOp::create(nestedBuilder, nestedLoc, truncf);
})
.getResults()[0];
}

SmallVector<ReassociationExprs, 4> reassociationMap;
uint64_t expandInputRank =
cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
uint64_t expandInputRank = cast<ShapedType>(reducedRes.getType()).getRank();
reassociationMap.resize(expandInputRank);

for (uint64_t i = 0; i < expandInputRank; i++) {
@@ -1324,8 +1357,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
// since here we know which dimension to expand, and `tosa::ReshapeOp` would
// not have access to such information. This matters when handling dynamically
// sized tensors.
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
op, resultTy, linalgOp->getResults()[0], reassociationMap);
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(op, resultTy, reducedRes,
reassociationMap);
[Review thread]
Contributor: Strange whitespace.

Author: Not sure; I just ran clang-format. The reassociationMap argument ends up on the line below, while the other arguments move above.

return success();
}

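Why widen at all? bf16 keeps only 8 bits of significand, so a running bf16 sum stops absorbing small addends early: once the accumulator reaches 256, adding 1.0 rounds straight back to 256. The following self-contained C++ sketch simulates bf16 by round-to-nearest-even truncation of f32 bits; it illustrates the numeric failure mode this lowering avoids and is not MLIR or patch code:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Round an f32 to the nearest bf16 value (round-to-nearest-even on the
// 16 dropped mantissa bits); NaN/Inf handling omitted for brevity.
static float roundToBF16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits += 0x7FFF + ((bits >> 16) & 1); // ties round to even
  bits &= 0xFFFF0000u;                 // keep the bf16-representable part
  float r;
  std::memcpy(&r, &bits, sizeof(r));
  return r;
}

int main() {
  float accBF16 = 0.0f, accF32 = 0.0f;
  for (int i = 0; i < 400; ++i) {
    accBF16 = roundToBF16(accBF16 + 1.0f); // bf16 accumulator: stalls at 256
    accF32 += 1.0f;                        // f32 accumulator: stays exact
  }
  std::printf("bf16 accumulation: %g\n", accBF16);             // prints 256
  std::printf("f32 then truncate: %g\n", roundToBF16(accF32)); // prints 400
}

Accumulating in f32 and truncating once at the end, as this lowering now does for tosa.reduce_sum on bf16, preserves the full sum.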
26 changes: 26 additions & 0 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -912,6 +912,32 @@ func.func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<

// -----

// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: @reduce_bf16
// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xbf16>
func.func @reduce_bf16(%arg0: tensor<5x4xbf16>) -> () {
// CHECK: [[INIT:%.+]] = tensor.empty() : tensor<4xf32>
// CHECK: [[CST0:%.+]] = arith.constant 0.0
// CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
// CHECK: [[REDUCE:%.+]] = linalg.reduce ins([[ARG0]] : tensor<5x4xbf16>) outs([[FILL]] : tensor<4xf32>) dimensions = [0]
// CHECK: (%[[ARG1:.*]]: bf16, %[[ARG2:.*]]: f32) {
// CHECK: [[EXTF:%.+]] = arith.extf %[[ARG1]] : bf16 to f32
// CHECK: [[ACC:%.+]] = arith.addf [[EXTF]], %[[ARG2]] : f32
// CHECK: linalg.yield [[ACC]] : f32
// CHECK: }
// CHECK: [[INIT_RES:%.+]] = tensor.empty() : tensor<4xbf16>
// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[REDUCE]] : tensor<4xf32>) outs([[INIT_RES]] : tensor<4xbf16>)
// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: bf16):
// CHECK: [[TRUNCF:%.+]] = arith.truncf %[[IN]] : f32 to bf16
// CHECK: linalg.yield [[TRUNCF]] : bf16
// CHECK: }
// CHECK: tensor.expand_shape [[RES]] {{\[}}[0, 1]] output_shape [1, 4] : tensor<4xbf16> into tensor<1x4xbf16>
%0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<5x4xbf16>) -> tensor<1x4xbf16>
return
}

// -----

// CHECK-LABEL: @reduce_float
// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
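As a usage note, the CHECK lines above are exercised by the file's existing RUN line, typically of the form mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s -o - | FileCheck %s (the exact flags in the checked-in test may differ), so the new @reduce_bf16 case runs together with the rest of the file.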