3232#include " llvm/ADT/Sequence.h"
3333
3434#include < numeric>
35+ #include < type_traits>
3536
3637using namespace mlir ;
3738using namespace mlir ::tosa;
3839
40+ // Helper function to materialize the semantically correct compare and select
41+ // operations given a binary operation with a specific NaN propagation mode.
42+ //
43+ // In the case of "PROPAGATE" semantics no compare and selection is required and
44+ // this function does nothing.
45+ //
46+ // In the case of "IGNORE" semantics this function materializes a comparison of
47+ // the current operands to the op which will return true for any NaN
48+ // argument and then selects between the non-NaN operation argument and the
49+ // calculated result based on whether the lhs or rhs is NaN or not. In pseudo
50+ // code:
51+ //
52+ // binary<op>(lhs, rhs):
53+ // result = op(lhs, rhs)
54+ // if lhs == NaN return rhs
55+ // if rhs == NaN return lhs
56+ // return result
57+ template <typename OpTy>
58+ static Value
59+ materializeBinaryNanCheckIfRequired (OpTy op, PatternRewriter &rewriter,
60+ Value lhs, Value rhs, Value result) {
61+ auto nanMode = op.getNanMode ();
62+ if (nanMode == " PROPAGATE" )
63+ return result;
64+
65+ // Unordered comparison of NaN against itself will always return true.
66+ Value lhsIsNaN = rewriter.create <arith::CmpFOp>(
67+ op.getLoc (), arith::CmpFPredicate::UNO, lhs, lhs);
68+ Value rhsIsNaN = rewriter.create <arith::CmpFOp>(
69+ op.getLoc (), arith::CmpFPredicate::UNO, rhs, rhs);
70+ Value rhsOrResult =
71+ rewriter.create <arith::SelectOp>(op.getLoc (), lhsIsNaN, rhs, result);
72+ return rewriter.create <arith::SelectOp>(op.getLoc (), rhsIsNaN, lhs,
73+ rhsOrResult);
74+ }
75+
3976template <typename T>
4077static arith::ConstantOp
4178createConstFromIntAttribute (Operation *op, const std::string &attrName,
@@ -367,7 +404,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
367404
368405 // tosa::MaximumOp
369406 if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
370- return rewriter.create <arith::MaximumFOp>(loc, args[0 ], args[1 ]);
407+ auto max = rewriter.create <arith::MaximumFOp>(loc, args[0 ], args[1 ]);
408+ return materializeBinaryNanCheckIfRequired (llvm::cast<tosa::MaximumOp>(op),
409+ rewriter, args[0 ], args[1 ], max);
371410 }
372411
373412 if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger ()) {
@@ -376,7 +415,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
376415
377416 // tosa::MinimumOp
378417 if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
379- return rewriter.create <arith::MinimumFOp>(loc, args[0 ], args[1 ]);
418+ auto min = rewriter.create <arith::MinimumFOp>(loc, args[0 ], args[1 ]);
419+ return materializeBinaryNanCheckIfRequired (llvm::cast<tosa::MinimumOp>(op),
420+ rewriter, args[0 ], args[1 ], min);
380421 }
381422
382423 if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger ()) {
@@ -404,7 +445,31 @@ static Value createLinalgBodyCalculationForElementwiseOp(
404445 loc, elementTy, rewriter.getFloatAttr (elementTy, minApf));
405446 auto max = rewriter.create <arith::ConstantOp>(
406447 loc, elementTy, rewriter.getFloatAttr (elementTy, maxApf));
407- return clampFloatHelper (loc, args[0 ], min, max, rewriter);
448+ auto result = clampFloatHelper (loc, args[0 ], min, max, rewriter);
449+
450+ auto clampOp = llvm::cast<tosa::ClampOp>(op);
451+ const auto nanMode = clampOp.getNanMode ();
452+ // In the case of "PROPAGATE" semantics no compare and selection is
453+ // required.
454+ if (nanMode == " PROPAGATE" )
455+ return result;
456+
457+ // In the case of "IGNORE" semantics materialize a comparison
458+ // of the current operand to the reduction which will return true for a NaN
459+ // argument and then selects between the initial reduction value and the
460+ // calculated result based on whether the argument is NaN or not. In pseudo
461+ // code:
462+ //
463+ // reduce<op>(x, init):
464+ // result = op(init, x)
465+ // return init if x == NaN else result
466+
467+ // Unordered comparison of NaN against itself will always return true.
468+ Value isNaN = rewriter.create <arith::CmpFOp>(
469+ op->getLoc (), arith::CmpFPredicate::UNO, args[0 ], args[0 ]);
470+ // TOSA specifies that in "ignore" NaN mode the result is "min" if the input
471+ // is NaN.
472+ return rewriter.create <arith::SelectOp>(op->getLoc (), isNaN, min, result);
408473 }
409474
410475 if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
@@ -1078,7 +1143,8 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
10781143// Performs the match and rewrite for reduction operations. This includes
10791144// declaring a correctly sized initial value, and the linalg.generic operation
10801145// that reduces across the specified axis.
1081- static LogicalResult reduceMatchAndRewriteHelper (Operation *op, uint64_t axis,
1146+ template <typename OpTy>
1147+ static LogicalResult reduceMatchAndRewriteHelper (OpTy op, uint64_t axis,
10821148 PatternRewriter &rewriter) {
10831149 auto loc = op->getLoc ();
10841150 auto inputTy = cast<ShapedType>(op->getOperand (0 ).getType ());
@@ -1096,6 +1162,9 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
10961162 }
10971163 }
10981164
1165+ SmallVector<Value> inputs, outputs;
1166+ inputs.push_back (input);
1167+
10991168 // First fill the output buffer with the init value.
11001169 auto emptyTensor =
11011170 rewriter
@@ -1113,26 +1182,127 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
11131182 .create <linalg::FillOp>(loc, ValueRange{fillValue},
11141183 ValueRange{emptyTensor})
11151184 .result ();
1185+ outputs.push_back (filledTensor);
1186+
1187+ bool isNanIgnoreMode = false ;
1188+ if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
1189+ std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
1190+ if (op.getNanMode () == " IGNORE" ) {
1191+ isNanIgnoreMode = true ;
1192+ // Because the TOSA spec requires the result be NaN iff all elements in
1193+ // the reduction are NaN we can't simply perform a compare and select.
1194+ // Additionally we have to keep track of whether we've seen any non-NaN
1195+ // values and then do a final select based on this predicate.
1196+ auto trueAttr = rewriter.getBoolAttr (true );
1197+ auto trueValue = rewriter.create <arith::ConstantOp>(loc, trueAttr);
1198+ auto emptyBoolTensor =
1199+ rewriter
1200+ .create <tensor::EmptyOp>(loc, reduceShape, trueValue.getType (),
1201+ dynDims)
1202+ .getResult ();
1203+ auto allResultsNaNTensor =
1204+ rewriter
1205+ .create <linalg::FillOp>(loc, ValueRange{trueValue},
1206+ ValueRange{emptyBoolTensor})
1207+ .result ();
1208+ // Note that because the linalg::ReduceOp has two variadic arguments
1209+ // (inputs and outputs) and it has the SameVariadicOperandSize trait we
1210+ // need to have the same number of inputs and outputs.
1211+ //
1212+ // The second input isn't actually used anywhere since the value used to
1213+ // update the NaN flag is calculated inside the body of the reduction and
1214+ // then used to update an out value.
1215+ // In order to satisfy type constraints we just pass another copy of the
1216+ // input here.
1217+ inputs.push_back (input);
1218+ outputs.push_back (allResultsNaNTensor);
1219+ }
1220+ }
11161221
11171222 bool didEncounterError = false ;
1118- auto linalgOp = rewriter.create <linalg::ReduceOp>(
1119- loc, input, filledTensor , axis,
1223+ linalg::LinalgOp linalgOp = rewriter.create <linalg::ReduceOp>(
1224+ loc, inputs, outputs , axis,
11201225 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
1226+ std::array<Value, 2 > binaryArgs{
1227+ blockArgs[0 ], isNanIgnoreMode ? blockArgs[2 ] : blockArgs[1 ]};
11211228 auto result = createLinalgBodyCalculationForReduceOp (
1122- op, blockArgs , elementTy, rewriter);
1229+ op, binaryArgs , elementTy, rewriter);
11231230 if (result)
11241231 didEncounterError = true ;
11251232
1126- nestedBuilder.create <linalg::YieldOp>(loc, result);
1233+ SmallVector<Value> resultsToYield;
1234+ if (isNanIgnoreMode) {
1235+ auto inputValue = blockArgs[0 ];
1236+ auto initialValue = blockArgs[2 ];
1237+ auto oldAllResultsNanFlagValue = blockArgs[3 ];
1238+
1239+ // Unordered comparison of NaN against itself will always return true.
1240+ Value isNaN = nestedBuilder.create <arith::CmpFOp>(
1241+ op->getLoc (), arith::CmpFPredicate::UNO, inputValue, inputValue);
1242+ // If we've encountered a NaN, take the non-NaN value.
1243+ auto selectOp = nestedBuilder.create <arith::SelectOp>(
1244+ op->getLoc (), isNaN, initialValue, result);
1245+ // Update the flag which keeps track of whether we have seen a non-NaN
1246+ // value.
1247+ auto newAllResultsNanFlagValue = nestedBuilder.create <arith::AndIOp>(
1248+ op->getLoc (), oldAllResultsNanFlagValue, isNaN);
1249+ resultsToYield.push_back (selectOp);
1250+ resultsToYield.push_back (newAllResultsNanFlagValue);
1251+ } else {
1252+ resultsToYield.push_back (result);
1253+ }
1254+ nestedBuilder.create <linalg::YieldOp>(loc, resultsToYield);
11271255 });
11281256
11291257 if (!didEncounterError)
11301258 return rewriter.notifyMatchFailure (
11311259 op, " unable to create linalg.generic body for reduce op" );
11321260
1261+ if (isNanIgnoreMode) {
1262+ // Materialize a check to see whether we encountered any non-NaN values, if
1263+ // we didn't we need to select a tensor of NaNs since the result will just
1264+ // be the initial identity value propagated through all the compares and
1265+ // selects inside the reduction.
1266+
1267+ // Create a tensor full of NaNs.
1268+ auto nanValueAttr = rewriter.getFloatAttr (
1269+ elementTy,
1270+ APFloat::getNaN (cast<FloatType>(elementTy).getFloatSemantics (), false ));
1271+ auto nanValue = rewriter.create <arith::ConstantOp>(loc, nanValueAttr);
1272+ auto emptyNanTensor =
1273+ rewriter
1274+ .create <tensor::EmptyOp>(loc, reduceShape,
1275+ resultTy.getElementType (), dynDims)
1276+ .getResult ();
1277+ auto nanFilledTensor =
1278+ rewriter
1279+ .create <linalg::FillOp>(loc, ValueRange{nanValue},
1280+ ValueRange{emptyNanTensor})
1281+ .result ();
1282+
1283+ // Create an empty tensor, non need to fill this since it will be
1284+ // overwritten by the select.
1285+ auto finalEmptyTensor =
1286+ rewriter
1287+ .create <tensor::EmptyOp>(loc, reduceShape,
1288+ resultTy.getElementType (), dynDims)
1289+ .getResult ();
1290+
1291+ // Do a selection between the tensors akin to:
1292+ // result = NaN if "all results NaN" else result.
1293+ SmallVector<Value> ins, outs;
1294+ ins.push_back (linalgOp->getOpResult (1 ));
1295+ ins.push_back (nanFilledTensor);
1296+ ins.push_back (linalgOp->getResult (0 ));
1297+ outs.push_back (finalEmptyTensor);
1298+ auto linalgSelect =
1299+ rewriter.create <linalg::SelectOp>(op->getLoc (), ins, outs);
1300+ linalgOp = linalgSelect;
1301+ }
1302+
11331303 SmallVector<ReassociationExprs, 4 > reassociationMap;
11341304 uint64_t expandInputRank =
1135- cast<ShapedType>(linalgOp. getResults ()[0 ].getType ()).getRank ();
1305+ cast<ShapedType>(linalgOp-> getResults ()[0 ].getType ()).getRank ();
11361306 reassociationMap.resize (expandInputRank);
11371307
11381308 for (uint64_t i = 0 ; i < expandInputRank; i++) {
@@ -1151,7 +1321,7 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
11511321 // not have access to such information. This matters when handling dynamically
11521322 // sized tensors.
11531323 rewriter.replaceOpWithNewOp <tensor::ExpandShapeOp>(
1154- op, resultTy, linalgOp. getResults ()[0 ], reassociationMap);
1324+ op, resultTy, linalgOp-> getResults ()[0 ], reassociationMap);
11551325 return success ();
11561326}
11571327
@@ -2097,6 +2267,27 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
20972267 nestedLoc, predicate, newValue, oldValue);
20982268 auto resultIndex = rewriter.create <arith::SelectOp>(
20992269 nestedLoc, predicate, newIndex, oldIndex);
2270+
2271+ // Check if we need to materialize compare and select for the given
2272+ // NaN propagation mode.
2273+
2274+ // "PROPAGATE" matches the default NaN propagation mode of the arith
2275+ // dialect so no compare and select is required.
2276+ //
2277+ // In the case "IGNORE" we check if the current argument is NaN and
2278+ // select the old index and value otherwise take the updated index and
2279+ // value.
2280+ if (const auto nanMode = argmaxOp.getNanMode (); nanMode == " IGNORE" ) {
2281+ // Unordered comparison of NaN against itself will always return
2282+ // true.
2283+ Value isNaN = rewriter.create <arith::CmpFOp>(
2284+ argmaxOp.getLoc (), arith::CmpFPredicate::UNO, newValue,
2285+ newValue);
2286+ resultMax = rewriter.create <arith::SelectOp>(nestedLoc, isNaN,
2287+ oldValue, resultMax);
2288+ resultIndex = rewriter.create <arith::SelectOp>(
2289+ nestedLoc, isNaN, oldIndex, resultIndex);
2290+ }
21002291 nestedBuilder.create <linalg::YieldOp>(
21012292 nestedLoc, ValueRange ({resultIndex, resultMax}));
21022293 });
0 commit comments