@@ -1160,6 +1160,12 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
   auto elementTy = resultTy.getElementType();
   Value input = op->getOperand(0);
 
+  // Figure out the accType if needed
+  bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
+                    isa<FloatType>(elementTy) &&
+                    cast<FloatType>(elementTy).isBF16();
+  Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;
+
   SmallVector<int64_t> reduceShape;
   SmallVector<Value> dynDims;
   for (unsigned i = 0; i < inputTy.getRank(); i++) {
@@ -1174,11 +1180,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
   inputs.push_back(input);
 
   // First fill the output buffer with the init value.
-  auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, reduceShape,
-                                             resultTy.getElementType(), dynDims)
-                         .getResult();
+  auto emptyTensor =
+      tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
+          .getResult();
 
-  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
+  auto fillValueAttr = createInitialValueForReduceOp(op, accTy, rewriter);
   if (!fillValueAttr)
     return rewriter.notifyMatchFailure(
         op, "No initial value found for reduction operation");
@@ -1231,8 +1237,14 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
         std::array<Value, 2> binaryArgs{
            blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
-        auto result = createLinalgBodyCalculationForReduceOp(
-            op, binaryArgs, elementTy, rewriter);
+
+        // If reduction type differs then extend (applicable to reduce_sum)
+        if (binaryArgs[0].getType() != accTy)
+          binaryArgs[0] = arith::ExtFOp::create(nestedBuilder, nestedLoc, accTy,
+                                                binaryArgs[0]);
+
+        auto result = createLinalgBodyCalculationForReduceOp(op, binaryArgs,
+                                                             accTy, rewriter);
         if (result)
           didEncounterError = true;
 
@@ -1273,12 +1285,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
 
     // Create a tensor full of NaNs.
     auto nanValueAttr = rewriter.getFloatAttr(
-        elementTy,
+        accTy,
         APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
     auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
     auto emptyNanTensor =
-        tensor::EmptyOp::create(rewriter, loc, reduceShape,
-                                resultTy.getElementType(), dynDims)
+        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
             .getResult();
     auto nanFilledTensor =
         linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
@@ -1288,8 +1299,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
     // Create an empty tensor, non need to fill this since it will be
     // overwritten by the select.
     auto finalEmptyTensor =
-        tensor::EmptyOp::create(rewriter, loc, reduceShape,
-                                resultTy.getElementType(), dynDims)
+        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
             .getResult();
 
     // Do a selection between the tensors akin to:
@@ -1304,9 +1314,32 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
     linalgOp = linalgSelect;
   }
 
+  // Truncate back to resultTy if needed
+  Value reducedRes = linalgOp->getResult(0);
+  if (widenAccTy) {
+    auto resEmptyOp =
+        tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
+            .getResult();
+
+    const unsigned reducedRank =
+        cast<ShapedType>(reducedRes.getType()).getRank();
+    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
+    reducedRes =
+        linalg::GenericOp::create(
+            rewriter, loc, resEmptyOp.getType(), ValueRange{reducedRes},
+            ValueRange{resEmptyOp},
+            ArrayRef<AffineMap>{identityMap, identityMap},
+            getNParallelLoopsAttrs(reducedRank),
+            [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+              Value truncf = arith::TruncFOp::create(nestedBuilder, nestedLoc,
+                                                     elementTy, args[0]);
+              linalg::YieldOp::create(nestedBuilder, nestedLoc, truncf);
+            })
+            .getResults()[0];
+  }
+
   SmallVector<ReassociationExprs, 4> reassociationMap;
-  uint64_t expandInputRank =
-      cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
+  uint64_t expandInputRank = cast<ShapedType>(reducedRes.getType()).getRank();
   reassociationMap.resize(expandInputRank);
 
   for (uint64_t i = 0; i < expandInputRank; i++) {
@@ -1324,8 +1357,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
   // since here we know which dimension to expand, and `tosa::ReshapeOp` would
   // not have access to such information. This matters when handling dynamically
   // sized tensors.
-  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-      op, resultTy, linalgOp->getResults()[0], reassociationMap);
+  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(op, resultTy, reducedRes,
+                                                     reassociationMap);
   return success();
 }
 
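For illustration, below is a hand-written sketch of the Linalg IR this change is intended to produce for a bf16 tosa.reduce_sum. The shapes, SSA names, and exact op layout are assumptions for the example, not taken from the commit or its tests: the accumulator is widened to f32 by an arith.extf inside the reduction body, and a trailing elementwise linalg.generic truncates back to bf16 before tensor.expand_shape restores the reduced dimension.

// Hypothetical lowering of: tosa.reduce_sum %arg0 {axis = 1 : i32}
//   : (tensor<2x3xbf16>) -> tensor<2x1xbf16>
%zero = arith.constant 0.0 : f32
%empty = tensor.empty() : tensor<2xf32>
// Fill the init tensor with the widened (f32) identity value.
%fill = linalg.fill ins(%zero : f32) outs(%empty : tensor<2xf32>) -> tensor<2xf32>
// Accumulate in f32: each bf16 element is extended before the add.
%sum = linalg.reduce ins(%arg0 : tensor<2x3xbf16>)
    outs(%fill : tensor<2xf32>) dimensions = [1]
  (%in: bf16, %acc: f32) {
    %ext = arith.extf %in : bf16 to f32
    %add = arith.addf %ext, %acc : f32
    linalg.yield %add : f32
  }
// Truncate the widened result back to the bf16 result element type.
%tempty = tensor.empty() : tensor<2xbf16>
%trunc = linalg.generic
    {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
     iterator_types = ["parallel"]}
    ins(%sum : tensor<2xf32>) outs(%tempty : tensor<2xbf16>) {
  ^bb0(%in: f32, %out: bf16):
    %t = arith.truncf %in : f32 to bf16
    linalg.yield %t : bf16
} -> tensor<2xbf16>
// Re-expand the reduced axis, as the final replaceOpWithNewOp does.
%res = tensor.expand_shape %trunc [[0, 1]] output_shape [2, 1]
    : tensor<2xbf16> into tensor<2x1xbf16>

The motivation for widening: bf16 carries only 7 mantissa bits, so a long reduction that accumulates in bf16 loses precision at every add, whereas accumulating in f32 and truncating once at the end keeps the intermediate sum precise.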